You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

113 lines
3.2 KiB

import unicodedata
from typing import Generator, Iterable, Tuple, TypeVar
import uniseg.wordbreak
from multimethod import multimethod
from rapidfuzz.distance import Levenshtein
from .extracted_text import ExtractedText
# Type variable for the generic sequence overload of word_error_rate_n and
# for word_error_rate.
T = TypeVar("T")

# Module-level flag: becomes True once patch_word_break() has replaced
# uniseg.wordbreak.word_break, so the patch need not be applied again.
word_break_patched: bool = False
def patch_word_break():
    """
    Patch uniseg.wordbreak.word_break to deal with our private use characters.

    Characters in the Private Use Area U+E000..U+F8FF are reported as
    ``uniseg.wordbreak.WordBreak.ALETTER``, so word segmentation treats them
    like letters. All other characters are delegated to the previously
    installed ``word_break`` function. Afterwards the module-level flag
    ``word_break_patched`` is set to True.

    This function is not idempotent by itself: every call wraps the current
    ``word_break`` once more. Check ``word_break_patched`` before calling.

    See also the Word_Break property defined in Unicode Standard Annex #29
    (WordBreakProperty.txt in the Unicode Character Database).
    """
    global word_break_patched

    old_word_break = uniseg.wordbreak.word_break

    def new_word_break(c, index=0):
        # NOTE(review): only the BMP Private Use Area is covered here; the
        # supplementary PUAs (U+F0000..U+10FFFD) fall through unchanged.
        if 0xE000 <= ord(c) <= 0xF8FF:  # Private Use Area
            return uniseg.wordbreak.WordBreak.ALETTER
        return old_word_break(c, index)

    uniseg.wordbreak.word_break = new_word_break
    word_break_patched = True
@multimethod
def words(s: str) -> Generator[str, None, None]:
    """Extract words from a string.

    The string is split on word boundaries following Unicode Standard Annex
    #29 (via ``uniseg.wordbreak.words``). Segments made up only of unwanted
    characters are skipped. Unwanted means Unicode general category M, P, Z
    or S, or subcategory Cc or Cf (whitespace, punctuation, marks, symbols,
    control/format characters). Every other segment is yielded unchanged.
    """
    if not word_break_patched:
        # Make private use characters count as letters before segmenting.
        patch_word_break()

    # Check if c is an unwanted character, i.e. whitespace, punctuation, or similar
    def unwanted(c):
        # See the Unicode General_Category values (Unicode Standard Annex #44).
        # NOTE(review): "O" is not a Unicode major category, so it never matches.
        unwanted_categories = "O", "M", "P", "Z", "S"
        unwanted_subcategories = "Cc", "Cf"

        subcat = unicodedata.category(c)
        cat = subcat[0]
        return cat in unwanted_categories or subcat in unwanted_subcategories

    # We follow Unicode Standard Annex #29 on Unicode Text Segmentation here: Split on
    # word boundaries using uniseg.wordbreak.words() and ignore all "words" that contain
    # only whitespace, punctuation "or similar characters."
    for word in uniseg.wordbreak.words(s):
        if not all(unwanted(c) for c in word):
            yield word


@words.register
def _(s: ExtractedText) -> Generator[str, None, None]:
    """Extract words from the text of an ExtractedText (see words(str))."""
    yield from words(s.text)
@multimethod
def words_normalized(s: str) -> Generator[str, None, None]:
    """Extract words from *s* after Unicode NFC normalization."""
    yield from words(unicodedata.normalize("NFC", s))


@words_normalized.register
def _(s: ExtractedText) -> Generator[str, None, None]:
    """Extract NFC-normalized words from the text of an ExtractedText."""
    yield from words_normalized(s.text)
@multimethod
def word_error_rate_n(reference: str, compared: str) -> Tuple[float, int]:
    """Compute the word error rate between two strings.

    Both strings are NFC-normalized and split into words with
    words_normalized(). The resulting word lists are then compared by the
    generic sequence overload of this multimethod.

    Returns a tuple ``(wer, n)`` where ``n`` is the number of words in
    *reference*.
    """
    reference_seq = list(words_normalized(reference))
    compared_seq = list(words_normalized(compared))
    wer, n = word_error_rate_n(reference_seq, compared_seq)
    return wer, n


@word_error_rate_n.register
def _(reference: ExtractedText, compared: ExtractedText) -> Tuple[float, int]:
    """Compute the word error rate between the texts of two ExtractedText objects."""
    wer, n = word_error_rate_n(reference.text, compared.text)
    return wer, n


@word_error_rate_n.register
def _(reference: Iterable[T], compared: Iterable[T]) -> Tuple[float, int]:
    """Compute the word error rate between two sequences of words.

    The rate is the Levenshtein distance between the two sequences divided
    by the length of *reference*:

    - If the sequences are identical, the rate is 0, even when *reference*
      is empty.
    - If *reference* is empty but *compared* is not, the rate is infinite.

    Returns a tuple ``(wer, n)`` where ``n`` is ``len(reference)``.
    """
    reference_seq = list(reference)
    compared_seq = list(compared)

    d = Levenshtein.distance(reference_seq, compared_seq)
    n = len(reference_seq)

    if d == 0:
        return 0, n
    if n == 0:
        # Any edit against an empty reference is an unbounded error rate.
        return float("inf"), n
    return d / n, n
def word_error_rate(reference: T, compared: T) -> float:
    """Return just the word error rate of *compared* against *reference*.

    Thin convenience wrapper over word_error_rate_n() that drops the
    reference word count from its ``(wer, n)`` result.
    """
    rate: float
    rate, _count = word_error_rate_n(reference, compared)
    return rate