only call `words_normalized` once

pull/72/head
Max Bachmann 2 years ago committed by GitHub
parent dcc10c5389
commit f3825cdeb6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@@ -106,16 +106,15 @@ def process(gt, ocr, report_prefix, *, metrics=True, textequiv_level="region"):
gt_text = extract(gt, textequiv_level=textequiv_level) gt_text = extract(gt, textequiv_level=textequiv_level)
ocr_text = extract(ocr, textequiv_level=textequiv_level) ocr_text = extract(ocr, textequiv_level=textequiv_level)
gt_words = words_normalized(gt_text)
ocr_words = words_normalized(ocr_text)
cer, n_characters = character_error_rate_n(gt_text, ocr_text) cer, n_characters = character_error_rate_n(gt_text, ocr_text)
wer, n_words = word_error_rate_n(gt_text, ocr_text)
char_diff_report = gen_diff_report( char_diff_report = gen_diff_report(
gt_text, ocr_text, css_prefix="c", joiner="", none="·" gt_text, ocr_text, css_prefix="c", joiner="", none="·"
) )
gt_words = words_normalized(gt_text) wer, n_words = word_error_rate_n(gt_words, ocr_words)
ocr_words = words_normalized(ocr_text)
word_diff_report = gen_diff_report( word_diff_report = gen_diff_report(
gt_words, ocr_words, css_prefix="w", joiner=" ", none="" gt_words, ocr_words, css_prefix="w", joiner=" ", none=""
) )

@@ -53,6 +53,8 @@ def process(gt_dir, ocr_dir, report_prefix, *, metrics=True):
gt_text = plain_extract(os.path.join(gt_dir, gt), include_filename_in_id=True) gt_text = plain_extract(os.path.join(gt_dir, gt), include_filename_in_id=True)
ocr_text = plain_extract(os.path.join(ocr_dir, ocr), include_filename_in_id=True) ocr_text = plain_extract(os.path.join(ocr_dir, ocr), include_filename_in_id=True)
gt_words = words_normalized(gt_text)
ocr_words = words_normalized(ocr_text)
# Compute CER # Compute CER
l_cer, l_n_characters = character_error_rate_n(gt_text, ocr_text) l_cer, l_n_characters = character_error_rate_n(gt_text, ocr_text)
@@ -64,7 +66,7 @@ def process(gt_dir, ocr_dir, report_prefix, *, metrics=True):
n_characters = n_characters + l_n_characters n_characters = n_characters + l_n_characters
# Compute WER # Compute WER
l_wer, l_n_words = word_error_rate_n(gt_text, ocr_text) l_wer, l_n_words = word_error_rate_n(gt_words, ocr_words)
if wer is None: if wer is None:
wer, n_words = l_wer, l_n_words wer, n_words = l_wer, l_n_words
else: else:
@@ -76,8 +78,6 @@ def process(gt_dir, ocr_dir, report_prefix, *, metrics=True):
char_diff_report += gen_diff_report( char_diff_report += gen_diff_report(
gt_text, ocr_text, css_prefix="l{0}-c".format(k), joiner="", none="·" gt_text, ocr_text, css_prefix="l{0}-c".format(k), joiner="", none="·"
) )
gt_words = words_normalized(gt_text)
ocr_words = words_normalized(ocr_text)
word_diff_report += gen_diff_report( word_diff_report += gen_diff_report(
gt_words, ocr_words, css_prefix="l{0}-w".format(k), joiner=" ", none="" gt_words, ocr_words, css_prefix="l{0}-w".format(k), joiner=" ", none=""
) )

Loading…
Cancel
Save