only call `words_normalized` once

pull/72/head
Max Bachmann committed by GitHub
parent dcc10c5389
commit f3825cdeb6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -106,16 +106,15 @@ def process(gt, ocr, report_prefix, *, metrics=True, textequiv_level="region"):
gt_text = extract(gt, textequiv_level=textequiv_level)
ocr_text = extract(ocr, textequiv_level=textequiv_level)
gt_words = words_normalized(gt_text)
ocr_words = words_normalized(ocr_text)
cer, n_characters = character_error_rate_n(gt_text, ocr_text)
wer, n_words = word_error_rate_n(gt_text, ocr_text)
char_diff_report = gen_diff_report(
gt_text, ocr_text, css_prefix="c", joiner="", none="·"
)
gt_words = words_normalized(gt_text)
ocr_words = words_normalized(ocr_text)
wer, n_words = word_error_rate_n(gt_words, ocr_words)
word_diff_report = gen_diff_report(
gt_words, ocr_words, css_prefix="w", joiner=" ", none=""
)

@ -53,6 +53,8 @@ def process(gt_dir, ocr_dir, report_prefix, *, metrics=True):
gt_text = plain_extract(os.path.join(gt_dir, gt), include_filename_in_id=True)
ocr_text = plain_extract(os.path.join(ocr_dir, ocr), include_filename_in_id=True)
gt_words = words_normalized(gt_text)
ocr_words = words_normalized(ocr_text)
# Compute CER
l_cer, l_n_characters = character_error_rate_n(gt_text, ocr_text)
@ -64,7 +66,7 @@ def process(gt_dir, ocr_dir, report_prefix, *, metrics=True):
n_characters = n_characters + l_n_characters
# Compute WER
l_wer, l_n_words = word_error_rate_n(gt_text, ocr_text)
l_wer, l_n_words = word_error_rate_n(gt_words, ocr_words)
if wer is None:
wer, n_words = l_wer, l_n_words
else:
@ -76,8 +78,6 @@ def process(gt_dir, ocr_dir, report_prefix, *, metrics=True):
char_diff_report += gen_diff_report(
gt_text, ocr_text, css_prefix="l{0}-c".format(k), joiner="", none="·"
)
gt_words = words_normalized(gt_text)
ocr_words = words_normalized(ocr_text)
word_diff_report += gen_diff_report(
gt_words, ocr_words, css_prefix="l{0}-w".format(k), joiner=" ", none=""
)

Loading…
Cancel
Save