You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
dinglehopper/qurator/dinglehopper/word_error_rate.py

84 lines
2.5 KiB
Python

from __future__ import division
import unicodedata
from typing import Tuple
import uniseg.wordbreak
from .edit_distance import levenshtein
def _patch_uniseg_word_break():
    """Make uniseg treat Private Use Area characters as letters.

    Replaces uniseg.wordbreak.word_break with a wrapper that reports 'ALetter' for
    U+E000..U+F8FF and defers to the previous implementation otherwise. See also
    https://www.unicode.org/Public/UCD/latest/ucd/auxiliary/WordBreakProperty.txt

    The wrapper is installed at most once. Re-patching on every call would wrap the
    previous wrapper again each time, so the call chain (and recursion depth) for each
    character lookup would grow without bound.
    """
    current_word_break = uniseg.wordbreak.word_break
    if getattr(current_word_break, '_dinglehopper_pua_patch', False):
        return

    def new_word_break(c, index=0):
        if 0xE000 <= ord(c) <= 0xF8FF:  # Private Use Area
            return 'ALetter'
        return current_word_break(c, index)

    new_word_break._dinglehopper_pua_patch = True
    uniseg.wordbreak.word_break = new_word_break


def _is_unwanted_char(c):
    """Return True if c is whitespace, punctuation, or a similar non-word character.

    See https://www.fileformat.info/info/unicode/category/index.htm
    and https://unicodebook.readthedocs.io/unicode.html#categories
    """
    subcat = unicodedata.category(c)
    # Major categories: marks (M), punctuation (P), separators (Z), symbols (S);
    # plus the control (Cc) and format (Cf) subcategories of "Other".
    # ('O' is not a Unicode major category, so it is not listed.)
    return subcat[0] in ('M', 'P', 'Z', 'S') or subcat in ('Cc', 'Cf')


def words(s):
    """Yield the words of s.

    We follow Unicode Standard Annex #29 on Unicode Text Segmentation here: split on
    word boundaries using uniseg.wordbreak.words() and ignore all "words" that contain
    only whitespace, punctuation "or similar characters" (see _is_unwanted_char).
    Private Use Area characters are treated as letters.

    :param s: a str, or an ExtractedText whose .text is segmented
    """
    _patch_uniseg_word_break()

    # XXX Imported locally rather than at module level (presumably to avoid an import
    # cycle with .cli — TODO confirm).
    from .cli import ExtractedText
    if isinstance(s, ExtractedText):
        s = s.text

    for word in uniseg.wordbreak.words(s):
        if not all(_is_unwanted_char(c) for c in word):
            yield word
def words_normalized(s):
    """Return the words of s after NFC normalization.

    Accepts a str or an ExtractedText (whose .text is used). The text is normalized to
    Unicode NFC and then passed to words(), whose generator is returned.
    """
    # XXX Imported locally rather than at module level.
    from .cli import ExtractedText
    text = s.text if isinstance(s, ExtractedText) else s
    nfc_text = unicodedata.normalize('NFC', text)
    return words(nfc_text)
def _word_sequence(text):
    """Return text as a list of words.

    An ExtractedText is unwrapped to its .text. A str is tokenized with
    words_normalized(). Any other iterable is assumed to already be a word sequence
    and is copied into a list.
    """
    # XXX Imported locally rather than at module level.
    from .cli import ExtractedText
    if isinstance(text, ExtractedText):
        text = text.text
    if isinstance(text, str):
        return list(words_normalized(text))
    return list(text)


def word_error_rate_n(reference, compared) -> Tuple[float, int]:
    """Compute the word error rate of compared against reference.

    Each argument may be an ExtractedText, a str (NFC-normalized and split into words
    per UAX #29 via words_normalized()), or an iterable that already holds the word
    sequence. Each side is converted independently, so a str may be compared against
    a pre-tokenized sequence and vice versa.

    :return: (wer, n), where n is the number of reference words and wer is the
        Levenshtein distance between the two word sequences divided by n. wer is 0.0
        when the sequences are identical (including both empty) and inf when the
        reference is empty but compared is not.
    """
    reference_seq = _word_sequence(reference)
    compared_seq = _word_sequence(compared)

    d = levenshtein(reference_seq, compared_seq)
    n = len(reference_seq)

    if d == 0:
        return 0.0, n
    if n == 0:
        # Any edit against an empty reference is an unbounded relative error.
        return float('inf'), n
    return d / n, n
def word_error_rate(reference, compared) -> float:
    """Return the word error rate of compared against reference.

    Convenience wrapper around word_error_rate_n() that drops the reference word count.
    """
    return word_error_rate_n(reference, compared)[0]