You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
112 lines
2.9 KiB
Python
112 lines
2.9 KiB
Python
import unicodedata
|
|
from typing import Tuple, Iterable
|
|
from multimethod import multimethod
|
|
|
|
import uniseg.wordbreak
|
|
|
|
from rapidfuzz.distance import Levenshtein
|
|
from . import ExtractedText
|
|
|
|
|
|
# Set to True by patch_word_break() once uniseg.wordbreak.word_break has been
# replaced, so that words() only installs the patch a single time.
word_break_patched: bool = False
|
|
|
|
|
|
def patch_word_break():
    """
    Replace uniseg.wordbreak.word_break with a wrapper that reports the
    Word_Break value "ALetter" for characters in the Private Use Area
    (U+E000..U+F8FF) and defers to the original function for everything else.

    Afterwards the module-level flag word_break_patched is set to True.

    See also

    https://www.unicode.org/Public/UCD/latest/ucd/auxiliary/WordBreakProperty.txt
    """
    global word_break_patched

    original_word_break = uniseg.wordbreak.word_break

    def _pua_aware_word_break(c, index=0):
        # Anything outside the BMP Private Use Area keeps its normal property.
        if not 0xE000 <= ord(c) <= 0xF8FF:
            return original_word_break(c, index)
        # Private use characters are classified as letters so they stay
        # inside words instead of being split off.
        return "ALetter"

    uniseg.wordbreak.word_break = _pua_aware_word_break
    word_break_patched = True
|
|
|
|
|
|
@multimethod
def words(s: str):
    """
    Extract words from a string.

    Follows Unicode Standard Annex #29 on Unicode Text Segmentation: ``s`` is
    split on word boundaries with uniseg.wordbreak.words(), and every segment
    that consists only of whitespace, punctuation or similar characters (see
    the category lists below) is dropped.

    This is a generator function. On first iteration it installs the private
    use character patch via patch_word_break() if that has not happened yet.

    :param s: the text to split
    :return: generator yielding the remaining word strings
    """
    global word_break_patched
    if not word_break_patched:
        patch_word_break()

    # See https://www.fileformat.info/info/unicode/category/index.htm
    # and https://unicodebook.readthedocs.io/unicode.html#categories
    # Major classes (first letter of the general category) that are unwanted:
    # marks, punctuation, separators (whitespace) and symbols.
    unwanted_categories = "M", "P", "Z", "S"
    # From the "C" (Other) class only control and format characters are
    # unwanted; private use characters (Co) are deliberately kept, matching
    # patch_word_break(), which treats them as letters.
    unwanted_subcategories = "Cc", "Cf"

    def unwanted(c):
        """Return True if c is whitespace, punctuation, or a similar character."""
        subcat = unicodedata.category(c)
        return subcat[0] in unwanted_categories or subcat in unwanted_subcategories

    for word in uniseg.wordbreak.words(s):
        # Keep a segment as soon as it has at least one wanted character.
        if not all(unwanted(c) for c in word):
            yield word
|
|
|
|
|
|
@multimethod
def words(s: ExtractedText):
    """Extract words from the ``text`` attribute of an ExtractedText."""
    plain_text = s.text
    return words(plain_text)
|
|
|
|
|
|
@multimethod
def words_normalized(s: str):
    """Extract words from ``s`` after converting it to Unicode NFC form."""
    nfc_text = unicodedata.normalize("NFC", s)
    return words(nfc_text)
|
|
|
|
|
|
@multimethod
def words_normalized(s: ExtractedText):
    """Extract NFC-normalized words from the ``text`` attribute of an ExtractedText."""
    plain_text = s.text
    return words_normalized(plain_text)
|
|
|
|
|
|
@multimethod
def word_error_rate_n(reference: str, compared: str) -> Tuple[float, int]:
    """
    Compute the word error rate between two strings.

    Both strings are NFC-normalized and split into words with
    words_normalized(); the resulting word lists are then compared.

    :return: tuple (word error rate, number of words in reference)
    """
    return word_error_rate_n(
        list(words_normalized(reference)),
        list(words_normalized(compared)),
    )
|
|
|
|
|
|
@multimethod
def word_error_rate_n(
    reference: ExtractedText, compared: ExtractedText
) -> Tuple[float, int]:
    """
    Compute the word error rate between the ``text`` attributes of two
    ExtractedText objects.

    :return: tuple (word error rate, number of words in reference)
    """
    reference_text, compared_text = reference.text, compared.text
    return word_error_rate_n(reference_text, compared_text)
|
|
|
|
|
|
@multimethod
def word_error_rate_n(reference: Iterable, compared: Iterable) -> Tuple[float, int]:
    """
    Compute the word error rate between two sequences of words.

    The rate is the Levenshtein (edit) distance between the two word
    sequences divided by the number of words in ``reference``.

    :param reference: the reference word sequence
    :param compared: the word sequence to compare against the reference
    :return: tuple (word error rate, number of words in reference). The rate
        is 0.0 when the sequences are identical (including both empty) and
        float("inf") when reference is empty but compared is not.
    """
    reference_seq = list(reference)
    compared_seq = list(compared)

    d = Levenshtein.distance(reference_seq, compared_seq)
    n = len(reference_seq)

    if d == 0:
        # Return a float so the result type matches the annotation in all cases.
        return 0.0, n
    if n == 0:
        # Edits against an empty reference have no finite relative error.
        return float("inf"), n
    return d / n, n
|
|
|
|
|
|
def word_error_rate(reference, compared) -> float:
    """
    Return just the word error rate between reference and compared.

    Accepts any argument types supported by word_error_rate_n() and discards
    the word count it also returns.
    """
    return word_error_rate_n(reference, compared)[0]
|