Merge remote-tracking branch 'bertsky/normalized-cer' into v3-api

kba 1 week ago
commit e01071d936

@ -20,14 +20,7 @@ def character_error_rate_n(
:return: character error rate and length of the reference
"""
d = distance(reference, compared)
n = len(reference)
if d == 0:
return 0, n
if n == 0:
return float("inf"), n
return d / n, n
return distance(reference, compared), len(reference)
# XXX Should we really count newlines here?

@ -9,18 +9,18 @@ from .extracted_text import ExtractedText
@multimethod
def distance(seq1: List[str], seq2: List[str]) -> int:
def distance(seq1: List[str], seq2: List[str]) -> float:
"""Compute the Levenshtein edit distance between two lists of grapheme clusters.
This assumes that the grapheme clusters are already normalized.
Use distance(str, str) instead if you need to compare two Unicode strings.
"""
return Levenshtein.distance(seq1, seq2)
return Levenshtein.normalized_distance(seq1, seq2)
@distance.register
def _(s1: str, s2: str) -> int:
def _(s1: str, s2: str) -> float:
"""Compute the Levenshtein edit distance between two Unicode strings
Note that this is different from levenshtein() as this function knows about Unicode
@ -29,12 +29,12 @@ def _(s1: str, s2: str) -> int:
"""
seq1 = list(grapheme_clusters(unicodedata.normalize("NFC", s1)))
seq2 = list(grapheme_clusters(unicodedata.normalize("NFC", s2)))
return Levenshtein.distance(seq1, seq2)
return Levenshtein.normalized_distance(seq1, seq2)
@distance.register
def _(s1: ExtractedText, s2: ExtractedText) -> int:
return Levenshtein.distance(s1.grapheme_clusters, s2.grapheme_clusters)
def _(s1: ExtractedText, s2: ExtractedText) -> float:
return Levenshtein.normalized_distance(s1.grapheme_clusters, s2.grapheme_clusters)
def editops(word1, word2):

@ -14,9 +14,9 @@ def test_character_error_rate():
assert character_error_rate("Foo", "") == 3 / 3
assert character_error_rate("", "") == 0
assert math.isinf(character_error_rate("", "Foo"))
assert character_error_rate("", "Foo") == 3 / 3
assert character_error_rate("Foo", "Food") == 1 / 3
assert character_error_rate("Foo", "Food") == 1 / 4
assert character_error_rate("Fnord", "Food") == 2 / 5
assert character_error_rate("Müll", "Mull") == 1 / 4
assert character_error_rate("Abstand", "Sand") == 4 / 7

@ -6,8 +6,8 @@ from .. import distance
def test_distance():
assert distance("Fnord", "Food") == 2
assert distance("Müll", "Mull") == 1
assert distance("Fnord", "Food") == 2 / 5
assert distance("Müll", "Mull") == 1 / 4
word1 = unicodedata.normalize("NFC", "Schlyñ")
word2 = unicodedata.normalize("NFD", "Schlyñ") # Different, decomposed!
@ -21,4 +21,4 @@ def test_distance():
assert (
len(word2) == 7
) # This, OTOH, ends with LATIN SMALL LETTER M + COMBINING TILDE, 7 code points
assert distance(word1, word2) == 1
assert distance(word1, word2) == 1 / 6

@ -56,4 +56,4 @@ def test_character_error_rate_between_page_alto_2():
)
)
assert character_error_rate(gt, ocr) == 8 / 591 # Manually verified
assert character_error_rate(gt, ocr) == 8 / 594 # Manually verified

@ -32,11 +32,11 @@ def test_cli_json_cer_is_infinity(tmp_path):
with working_directory(tmp_path):
with open("gt.txt", "w") as gtf:
gtf.write("") # Empty to yield CER == inf
gtf.write("")
with open("ocr.txt", "w") as ocrf:
ocrf.write("Not important")
process("gt.txt", "ocr.txt", "report")
with open("report.json", "r") as jsonf:
j = json.load(jsonf)
assert j["cer"] == pytest.approx(float("inf"))
assert j["cer"] == pytest.approx(1.0)

@ -17,7 +17,7 @@ def test_distance_between_page_files():
# → 2 differences
gt = page_text(ET.parse(os.path.join(data_dir, "test-gt.page2018.xml")))
ocr = page_text(ET.parse(os.path.join(data_dir, "test-fake-ocr.page2018.xml")))
assert distance(gt, ocr) == 2
assert distance(gt, ocr) == 2 / 827
@pytest.mark.integration
@ -52,4 +52,4 @@ def test_distance_between_page_alto_2():
)
)
assert distance(gt, ocr) == 8 # Manually verified
assert distance(gt, ocr) == 8 / 594 # Manually verified

@ -12,9 +12,9 @@ from .util import working_directory
@pytest.mark.parametrize(
"gt_file_content,ocr_file_content,cer_expected",
[
("", "Lorem ipsum", math.inf),
("", "Lorem ipsum", 1.0),
("Lorem ipsum", "", 1.0),
("\ufeff", "Lorem ipsum", math.inf),
("\ufeff", "Lorem ipsum", 1.0),
("Lorem ipsum", "\ufeff", 1.0),
("", "", 0.0),
("\ufeff", "", 0.0),

@ -64,5 +64,5 @@ def test_word_error_rate_between_page_alto_2():
)
assert (
word_error_rate(gt, ocr) == 7 / gt_word_count
word_error_rate(gt, ocr) == 7 / (gt_word_count + 1)
) # Manually verified, 6 words are wrong, 1 got split (=2 errors)

@ -76,7 +76,7 @@ def test_word_error_rate():
)
assert word_error_rate("Dies ist ein Beispielsatz!", "") == 4 / 4
assert math.isinf(word_error_rate("", "Dies ist ein Beispielsatz!"))
assert word_error_rate("", "Dies ist ein Beispielsatz!") == 4 / 4
assert word_error_rate("", "") == 0
assert (

@ -96,15 +96,10 @@ def _(reference: Iterable[T], compared: Iterable[T]) -> Tuple[float, int]:
reference_seq = list(reference)
compared_seq = list(compared)
d = Levenshtein.distance(reference_seq, compared_seq)
d = Levenshtein.normalized_distance(reference_seq, compared_seq)
n = len(reference_seq)
if d == 0:
return 0, n
if n == 0:
return float("inf"), n
return d / n, n
return d, n
def word_error_rate(reference: T, compared: T) -> float:
wer: float

Loading…
Cancel
Save