You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
dinglehopper/qurator/dinglehopper/word_error_rate.py

114 lines
3.0 KiB
Python

from __future__ import division
import unicodedata
from typing import Tuple, Iterable
from multimethod import multimethod
import uniseg.wordbreak
from rapidfuzz.distance import Levenshtein
from . import ExtractedText
# Did we patch uniseg.wordbreak.word_break already?
# Set to True by patch_word_break(); words() checks this flag so that the
# patch is only requested while it has not been applied yet.
word_break_patched = False
def patch_word_break():
    """
    Patch uniseg.wordbreak.word_break to deal with our private use characters.

    Code points in the Private Use Area (U+E000..U+F8FF) are reported with the
    Word_Break property "ALetter"; every other code point is delegated to the
    original uniseg implementation.

    The patch is installed at most once: if the module-level flag
    word_break_patched is already set, this function returns immediately, so
    repeated calls never stack wrappers on top of each other.

    See also
    https://www.unicode.org/Public/UCD/latest/ucd/auxiliary/WordBreakProperty.txt
    """
    global word_break_patched
    if word_break_patched:
        return

    old_word_break = uniseg.wordbreak.word_break

    def new_word_break(c, index=0):
        # Classify the code point at c[index] rather than the whole string:
        # uniseg's word_break() accepts a longer string together with an
        # index, and ord() on such a string would raise TypeError.
        if 0xE000 <= ord(c[index]) <= 0xF8FF:  # Private Use Area
            return "ALetter"
        return old_word_break(c, index)

    uniseg.wordbreak.word_break = new_word_break
    word_break_patched = True
@multimethod
def words(s: str):
    """
    Extract words from a string.

    We follow Unicode Standard Annex #29 on Unicode Text Segmentation here:
    the string is split on word boundaries using uniseg.wordbreak.words(), and
    all "words" that contain only whitespace, punctuation or similar
    characters are ignored.

    This is a generator function: patching uniseg and segmenting the text only
    happen once the result is iterated. Yields the remaining words in order.
    """
    # Reading the module-level flag needs no `global` statement.
    if not word_break_patched:
        patch_word_break()

    # See https://www.fileformat.info/info/unicode/category/index.htm
    # and https://unicodebook.readthedocs.io/unicode.html#categories
    # Major classes Mark, Punctuation, Separator and Symbol are unwanted.
    # (A former "O" entry was dropped: no Unicode general category starts with
    # "O", so it could never match.)
    unwanted_categories = ("M", "P", "Z", "S")
    # Control and format characters are unwanted as well.
    unwanted_subcategories = ("Cc", "Cf")

    def unwanted(c):
        """Check if c is an unwanted character, i.e. whitespace, punctuation, or similar"""
        subcat = unicodedata.category(c)
        return subcat[0] in unwanted_categories or subcat in unwanted_subcategories

    for word in uniseg.wordbreak.words(s):
        # Keep a segment as soon as it contains at least one wanted character;
        # empty segments are skipped, exactly as before.
        if not all(unwanted(c) for c in word):
            yield word
@multimethod
def words(s: ExtractedText):
    """Extract words from the text attribute of an ExtractedText object"""
    text = s.text
    return words(text)
@multimethod
def words_normalized(s: str):
    """Extract words from a string after normalizing it to Unicode form NFC"""
    nfc_text = unicodedata.normalize("NFC", s)
    return words(nfc_text)
@multimethod
def words_normalized(s: ExtractedText):
    """Extract normalized words from the text attribute of an ExtractedText object"""
    text = s.text
    return words_normalized(text)
@multimethod
def word_error_rate_n(reference: str, compared: str) -> Tuple[float, int]:
    """
    Compute the word error rate between two strings.

    Both strings are split into lists of normalized words via
    words_normalized(); the result is whatever word_error_rate_n() returns
    for those two lists.
    """
    ref_words = [*words_normalized(reference)]
    cmp_words = [*words_normalized(compared)]
    return word_error_rate_n(ref_words, cmp_words)
@multimethod
def word_error_rate_n(
    reference: ExtractedText, compared: ExtractedText
) -> Tuple[float, int]:
    """Compute the word error rate between the texts of two ExtractedText objects"""
    ref_text = reference.text
    cmp_text = compared.text
    return word_error_rate_n(ref_text, cmp_text)
@multimethod
def word_error_rate_n(reference: Iterable, compared: Iterable) -> Tuple[float, int]:
    """
    Compute the word error rate between two iterables of words.

    Returns a tuple (rate, n) where n is the number of reference items and
    rate is the Levenshtein distance between both sequences divided by n.
    A distance of zero yields a rate of 0; a non-zero distance against an
    empty reference yields infinity.
    """
    ref_words = list(reference)
    cmp_words = list(compared)
    n_ref = len(ref_words)
    edits = Levenshtein.distance(ref_words, cmp_words)

    # Identical sequences (including two empty ones): no errors at all.
    if not edits:
        return 0, n_ref

    # Errors against an empty reference cannot be expressed as a finite ratio.
    rate = edits / n_ref if n_ref else float("inf")
    return rate, n_ref
def word_error_rate(reference, compared) -> float:
    """Return only the rate component of word_error_rate_n(reference, compared)"""
    return word_error_rate_n(reference, compared)[0]