diff --git a/qurator/dinglehopper/character_error_rate.py b/qurator/dinglehopper/character_error_rate.py index f63a15f..05cc931 100644 --- a/qurator/dinglehopper/character_error_rate.py +++ b/qurator/dinglehopper/character_error_rate.py @@ -1,21 +1,36 @@ from __future__ import division import unicodedata +from typing import Tuple from uniseg.graphemecluster import grapheme_clusters from qurator.dinglehopper.edit_distance import distance -def character_error_rate(reference, compared): - d = distance(reference, compared) - if d == 0: - return 0 +def character_error_rate_n(reference, compared) -> Tuple[float, int]: + """ + Compute character error rate. + :return: character error rate and length of the reference + """ + d = distance(reference, compared) n = len(list(grapheme_clusters(unicodedata.normalize('NFC', reference)))) - if n == 0: - return float('inf') - return d/n + if d == 0: + return 0, n + if n == 0: + return float('inf'), n + return d/n, n # XXX Should we really count newlines here? + + +def character_error_rate(reference, compared) -> float: + """ + Compute character error rate. + + :return: character error rate + """ + cer, _ = character_error_rate_n(reference, compared) + return cer diff --git a/qurator/dinglehopper/cli.py b/qurator/dinglehopper/cli.py index 7f9ea8f..08a36bb 100644 --- a/qurator/dinglehopper/cli.py +++ b/qurator/dinglehopper/cli.py @@ -57,8 +57,8 @@ def process(gt, ocr, report_prefix): gt_text = substitute_equivalences(gt_text) ocr_text = substitute_equivalences(ocr_text) - cer = character_error_rate(gt_text, ocr_text) - wer = word_error_rate(gt_text, ocr_text) + cer, n_characters = character_error_rate_n(gt_text, ocr_text) + wer, n_words = word_error_rate_n(gt_text, ocr_text) char_diff_report = gen_diff_report(gt_text, ocr_text, css_prefix='c', joiner='', none='ยท', align=align) @@ -88,7 +88,8 @@ def process(gt, ocr, report_prefix): template = env.get_template(template_fn) template.stream( gt=gt, ocr=ocr, - cer=cer, wer=wer, + cer=cer, n_characters=n_characters, + wer=wer, n_words=n_words, char_diff_report=char_diff_report, word_diff_report=word_diff_report ).dump(out_fn) diff --git a/qurator/dinglehopper/templates/report.json.j2 b/qurator/dinglehopper/templates/report.json.j2 index 62d3f77..62a242d 100644 --- a/qurator/dinglehopper/templates/report.json.j2 +++ b/qurator/dinglehopper/templates/report.json.j2 @@ -2,5 +2,7 @@ "gt": "{{ gt }}", "ocr": "{{ ocr }}", "cer": {{ cer|json_float }}, - "wer": {{ wer|json_float }} + "wer": {{ wer|json_float }}, + "n_characters": {{ n_characters }}, + "n_words": {{ n_words }} } diff --git a/qurator/dinglehopper/word_error_rate.py b/qurator/dinglehopper/word_error_rate.py index 2425200..7ed56e4 100644 --- a/qurator/dinglehopper/word_error_rate.py +++ b/qurator/dinglehopper/word_error_rate.py @@ -1,6 +1,7 @@ from __future__ import division import unicodedata +from typing import Tuple import uniseg.wordbreak @@ -44,7 +45,7 @@ def words_normalized(s): return words(unicodedata.normalize('NFC', s)) -def word_error_rate(reference, compared): +def word_error_rate_n(reference, compared) -> Tuple[float, int]: if isinstance(reference, str): reference_seq = list(words_normalized(reference)) compared_seq = list(words_normalized(compared)) @@ -53,11 +54,15 @@ def word_error_rate(reference, compared): compared_seq = list(compared) d = levenshtein(reference_seq, compared_seq) - if d == 0: - return 0 - n = len(reference_seq) + + if d == 0: + return 0, n if n == 0: - return float('inf') + return float('inf'), n + return d / n, n + - return d / n +def word_error_rate(reference, compared) -> float: + wer, _ = word_error_rate_n(reference, compared) + return wer