From e256526ea1d33e3673eb8bff466d8599277928ad Mon Sep 17 00:00:00 2001 From: Mike Gerber Date: Fri, 27 Oct 2023 20:55:37 +0200 Subject: [PATCH] =?UTF-8?q?=F0=9F=90=9B=20Fix=20calculation=20of=20score?= =?UTF-8?q?=5Fhint=20for=20edge=20cases,=20e.g.=20when=20CER=20is=20infini?= =?UTF-8?q?te?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit If the CER is infinite, we can't calculate a score_hint as an int. Fall back to None in this case. --- qurator/dinglehopper/align.py | 19 +++++++++++++++++++ qurator/dinglehopper/cli.py | 8 +++++--- qurator/dinglehopper/cli_line_dirs.py | 6 ++++-- qurator/dinglehopper/tests/test_align.py | 7 ++++++- 4 files changed, 34 insertions(+), 6 deletions(-) diff --git a/qurator/dinglehopper/align.py b/qurator/dinglehopper/align.py index 968d931..07cbc8f 100644 --- a/qurator/dinglehopper/align.py +++ b/qurator/dinglehopper/align.py @@ -1,3 +1,6 @@ +import math +from math import ceil + from .edit_distance import * from rapidfuzz.distance import Levenshtein @@ -8,6 +11,22 @@ def align(t1, t2): return seq_align(s1, s2) +def score_hint(er: float, n: int) -> int | None: + """Calculate RapidFuzz score hint for a given error rate and count. + + Gives the score hint for the distance functions (= expected distance) or None if + the error rate is inf. + """ + assert not math.isnan(er) + try: + score_hint = int(ceil(er * n)) + except (OverflowError, ValueError): + # ceil(er * n) can be inf or NaN (for n == 0), so int() can throw an + # OverflowError and a ValueError. 
+ score_hint = None + return score_hint + + def seq_align(s1, s2, score_hint=None): """Align general sequences.""" s1 = list(s1) diff --git a/qurator/dinglehopper/cli.py b/qurator/dinglehopper/cli.py index ef101a4..4d4349c 100644 --- a/qurator/dinglehopper/cli.py +++ b/qurator/dinglehopper/cli.py @@ -8,7 +8,7 @@ from math import ceil from .character_error_rate import character_error_rate_n from .word_error_rate import word_error_rate_n, words_normalized -from .align import seq_align +from .align import seq_align, score_hint from .extracted_text import ExtractedText from .ocr_files import extract from .config import Config @@ -110,12 +110,14 @@ def process(gt, ocr, report_prefix, *, metrics=True, textequiv_level="region"): cer, n_characters = character_error_rate_n(gt_text, ocr_text) char_diff_report = gen_diff_report( - gt_text, ocr_text, css_prefix="c", joiner="", none="·", score_hint=int(ceil(cer * n_characters)) + gt_text, ocr_text, css_prefix="c", joiner="", none="·", + score_hint=score_hint(cer, n_characters) ) wer, n_words = word_error_rate_n(gt_words, ocr_words) word_diff_report = gen_diff_report( - gt_words, ocr_words, css_prefix="w", joiner=" ", none="⋯", score_hint=int(ceil(wer * n_words)) + gt_words, ocr_words, css_prefix="w", joiner=" ", none="⋯", + score_hint=score_hint(wer, n_words) ) env = Environment( diff --git a/qurator/dinglehopper/cli_line_dirs.py b/qurator/dinglehopper/cli_line_dirs.py index 06bbe39..01ba959 100644 --- a/qurator/dinglehopper/cli_line_dirs.py +++ b/qurator/dinglehopper/cli_line_dirs.py @@ -75,10 +75,12 @@ def process(gt_dir, ocr_dir, report_prefix, *, metrics=True): # Generate diff reports char_diff_report += gen_diff_report( - gt_text, ocr_text, css_prefix="l{0}-c".format(k), joiner="", none="·", score_hint=int(ceil(l_cer * l_n_characters)) + gt_text, ocr_text, css_prefix="l{0}-c".format(k), joiner="", none="·", + score_hint=score_hint(l_cer, l_n_characters) ) word_diff_report += gen_diff_report( - gt_words, ocr_words, 
css_prefix="l{0}-w".format(k), joiner=" ", none="⋯", score_hint=int(ceil(l_wer * l_n_words)) + gt_words, ocr_words, css_prefix="l{0}-w".format(k), joiner=" ", none="⋯", + score_hint=score_hint(l_wer, l_n_words) ) env = Environment( diff --git a/qurator/dinglehopper/tests/test_align.py b/qurator/dinglehopper/tests/test_align.py index 96fc3c2..8e254e6 100644 --- a/qurator/dinglehopper/tests/test_align.py +++ b/qurator/dinglehopper/tests/test_align.py @@ -1,6 +1,7 @@ +import math import pytest from .util import unzip -from .. import align, seq_align, distance +from .. import align, seq_align, distance, score_hint def test_left_empty(): @@ -181,3 +182,7 @@ def test_lines_similar(): # Test __eq__ (i.e. is it a substitution or a similar string?) assert list(left)[0] == list(right)[0] + +def test_score_hint(): + assert score_hint(0.5, 23) == 12 # int(ceil()) + assert score_hint(math.inf, 12345) is None