diff --git a/qurator/dinglehopper/cli.py b/qurator/dinglehopper/cli.py index dbf9b28..7b95c99 100644 --- a/qurator/dinglehopper/cli.py +++ b/qurator/dinglehopper/cli.py @@ -1,5 +1,8 @@ import json import os +from collections import Counter +from functools import partial +from typing import Callable, List, Tuple import click from jinja2 import Environment, FileSystemLoader @@ -10,13 +13,28 @@ from uniseg.graphemecluster import grapheme_clusters from .align import seq_align from .config import Config from .extracted_text import ExtractedText -from .metrics import bag_of_chars_accuracy, bag_of_words_accuracy, character_accuracy, \ - word_accuracy -from .normalize import words_normalized +from .metrics import ( + bag_of_chars_accuracy, + bag_of_words_accuracy, + character_accuracy, + word_accuracy, +) +from .normalize import chars_normalized, words_normalized from .ocr_files import text -def gen_diff_report(gt_in, ocr_in, css_prefix, joiner, none): +def gen_count_report( + gt_text: str, ocr_text: str, split_fun: Callable[[str], Counter] +) -> List[Tuple[str, int, int]]: + gt_counter = Counter(split_fun(gt_text)) + ocr_counter = Counter(split_fun(ocr_text)) + return [ + ("".join(key), gt_counter[key], ocr_counter[key]) + for key in sorted({*gt_counter.keys(), *ocr_counter.keys()}) + ] + + +def gen_diff_report(gt_in, ocr_in, css_prefix, joiner, none) -> Tuple[str, str]: gtx = "" ocrx = "" @@ -36,7 +54,7 @@ def gen_diff_report(gt_in, ocr_in, css_prefix, joiner, none): html_custom_attrs += 'data-toggle="tooltip" title="{}"'.format(id_) if css_classes: - return f"{html_t}" + return f'{html_t}' else: return f"{html_t}" @@ -72,26 +90,30 @@ def gen_diff_report(gt_in, ocr_in, css_prefix, joiner, none): if o is not None: o_pos += len(o) - return """ -
-
{}
-
{}
-
- """.format( - gtx, ocrx - ) + return gtx, ocrx def generate_html_report(gt, ocr, gt_text, ocr_text, report_prefix, metrics_results): - char_diff_report = gen_diff_report( - gt_text, ocr_text, css_prefix="c", joiner="", none="·" - ) - gt_words = words_normalized(gt_text) - ocr_words = words_normalized(ocr_text) - word_diff_report = gen_diff_report( - gt_words, ocr_words, css_prefix="w", joiner=" ", none="⋯" - ) + metric_dict = { + "character_accuracy": partial( + gen_diff_report, css_prefix="c", joiner="", none="·" + ), + "word_accuracy": lambda gt_text, ocr_text: gen_diff_report( + words_normalized(gt_text), + words_normalized(ocr_text), + css_prefix="w", + joiner=" ", + none="⋯", + ), + "bag_of_chars_accuracy": partial(gen_count_report, split_fun=chars_normalized), + "bag_of_words_accuracy": partial(gen_count_report, split_fun=words_normalized), + } + metrics_reports = {} + for metric in metrics_results.keys(): + if metric not in metric_dict.keys(): + raise ValueError(f"Unknown metric '{metric}'.") + metrics_reports[metric] = metric_dict[metric](gt_text, ocr_text) env = Environment( loader=FileSystemLoader( @@ -107,15 +129,14 @@ def generate_html_report(gt, ocr, gt_text, ocr_text, report_prefix, metrics_resu template.stream( gt=gt, ocr=ocr, - char_diff_report=char_diff_report, - word_diff_report=word_diff_report, + metrics_reports=metrics_reports, metrics_results=metrics_results, ).dump(out_fn) def generate_json_report(gt, ocr, report_prefix, metrics_results): json_dict = {"gt": gt, "ocr": ocr} - for result in metrics_results: + for result in metrics_results.values(): json_dict[result.metric] = { key: value for key, value in result.get_dict().items() if key != "metric" } @@ -133,7 +154,7 @@ def process(gt, ocr, report_prefix, *, metrics="cer,wer", textequiv_level="regio gt_text = text(gt, textequiv_level=textequiv_level) ocr_text = text(ocr, textequiv_level=textequiv_level) - metrics_results = set() + metrics_results = {} if metrics: metric_dict = { "ca": character_accuracy, @@ -147,7 +168,8 @@ def process(gt, ocr, report_prefix, *, metrics="cer,wer", textequiv_level="regio metric = metric.strip() if metric not in metric_dict.keys(): raise ValueError(f"Unknown metric '{metric}'.") - metrics_results.add(metric_dict[metric](gt_text, ocr_text)) + result = metric_dict[metric](gt_text, ocr_text) + metrics_results[result.metric] = result generate_json_report(gt, ocr, report_prefix, metrics_results) generate_html_report(gt, ocr, gt_text, ocr_text, report_prefix, metrics_results) diff --git a/qurator/dinglehopper/metrics/bag_of_chars_accuracy.py b/qurator/dinglehopper/metrics/bag_of_chars_accuracy.py index e54d73c..3da143d 100644 --- a/qurator/dinglehopper/metrics/bag_of_chars_accuracy.py +++ b/qurator/dinglehopper/metrics/bag_of_chars_accuracy.py @@ -1,8 +1,6 @@ from collections import Counter -from unicodedata import normalize - -from uniseg.graphemecluster import grapheme_clusters +from ..normalize import chars_normalized from .utils import bag_accuracy, MetricResult, Weights @@ -19,8 +17,8 @@ def bag_of_chars_accuracy( :param weights: Weights/costs for editing operations. :return: Class representing the results of this metric. """ - reference_chars: Counter = Counter(grapheme_clusters(normalize("NFC", reference))) - compared_chars: Counter = Counter(grapheme_clusters(normalize("NFC", compared))) + reference_chars: Counter = Counter(chars_normalized(reference)) + compared_chars: Counter = Counter(chars_normalized(compared)) return bag_accuracy( reference_chars, compared_chars, weights, bag_of_chars_accuracy.__name__ ) diff --git a/qurator/dinglehopper/templates/report.html.j2 b/qurator/dinglehopper/templates/report.html.j2 index 46a952f..c1a3797 100644 --- a/qurator/dinglehopper/templates/report.html.j2 +++ b/qurator/dinglehopper/templates/report.html.j2 @@ -4,18 +4,18 @@ - +