diff --git a/qurator/dinglehopper/cli.py b/qurator/dinglehopper/cli.py
index dbf9b28..7b95c99 100644
--- a/qurator/dinglehopper/cli.py
+++ b/qurator/dinglehopper/cli.py
@@ -1,5 +1,8 @@
import json
import os
+from collections import Counter
+from functools import partial
+from typing import Callable, List, Tuple
import click
from jinja2 import Environment, FileSystemLoader
@@ -10,13 +13,28 @@ from uniseg.graphemecluster import grapheme_clusters
from .align import seq_align
from .config import Config
from .extracted_text import ExtractedText
-from .metrics import bag_of_chars_accuracy, bag_of_words_accuracy, character_accuracy, \
- word_accuracy
-from .normalize import words_normalized
+from .metrics import (
+ bag_of_chars_accuracy,
+ bag_of_words_accuracy,
+ character_accuracy,
+ word_accuracy,
+)
+from .normalize import chars_normalized, words_normalized
from .ocr_files import text
-def gen_diff_report(gt_in, ocr_in, css_prefix, joiner, none):
+def gen_count_report(
+ gt_text: str, ocr_text: str, split_fun: Callable[[str], Counter]
+) -> List[Tuple[str, int, int]]:
+ gt_counter = Counter(split_fun(gt_text))
+ ocr_counter = Counter(split_fun(ocr_text))
+ return [
+ ("".join(key), gt_counter[key], ocr_counter[key])
+ for key in sorted({*gt_counter.keys(), *ocr_counter.keys()})
+ ]
+
+
+def gen_diff_report(gt_in, ocr_in, css_prefix, joiner, none) -> Tuple[str, str]:
gtx = ""
ocrx = ""
@@ -36,7 +54,7 @@ def gen_diff_report(gt_in, ocr_in, css_prefix, joiner, none):
html_custom_attrs += 'data-toggle="tooltip" title="{}"'.format(id_)
if css_classes:
- return f"{html_t}"
+ return f'{html_t}'
else:
return f"{html_t}"
@@ -72,26 +90,30 @@ def gen_diff_report(gt_in, ocr_in, css_prefix, joiner, none):
if o is not None:
o_pos += len(o)
- return """
-
- """.format(
- gtx, ocrx
- )
+ return gtx, ocrx
def generate_html_report(gt, ocr, gt_text, ocr_text, report_prefix, metrics_results):
- char_diff_report = gen_diff_report(
- gt_text, ocr_text, css_prefix="c", joiner="", none="·"
- )
- gt_words = words_normalized(gt_text)
- ocr_words = words_normalized(ocr_text)
- word_diff_report = gen_diff_report(
- gt_words, ocr_words, css_prefix="w", joiner=" ", none="⋯"
- )
+ metric_dict = {
+ "character_accuracy": partial(
+ gen_diff_report, css_prefix="c", joiner="", none="·"
+ ),
+ "word_accuracy": lambda gt_text, ocr_text: gen_diff_report(
+ words_normalized(gt_text),
+ words_normalized(ocr_text),
+ css_prefix="w",
+ joiner=" ",
+ none="⋯",
+ ),
+ "bag_of_chars_accuracy": partial(gen_count_report, split_fun=chars_normalized),
+ "bag_of_words_accuracy": partial(gen_count_report, split_fun=words_normalized),
+ }
+ metrics_reports = {}
+ for metric in metrics_results.keys():
+ if metric not in metric_dict.keys():
+ raise ValueError(f"Unknown metric '{metric}'.")
+ metrics_reports[metric] = metric_dict[metric](gt_text, ocr_text)
env = Environment(
loader=FileSystemLoader(
@@ -107,15 +129,14 @@ def generate_html_report(gt, ocr, gt_text, ocr_text, report_prefix, metrics_resu
template.stream(
gt=gt,
ocr=ocr,
- char_diff_report=char_diff_report,
- word_diff_report=word_diff_report,
+ metrics_reports=metrics_reports,
metrics_results=metrics_results,
).dump(out_fn)
def generate_json_report(gt, ocr, report_prefix, metrics_results):
json_dict = {"gt": gt, "ocr": ocr}
- for result in metrics_results:
+ for result in metrics_results.values():
json_dict[result.metric] = {
key: value for key, value in result.get_dict().items() if key != "metric"
}
@@ -133,7 +154,7 @@ def process(gt, ocr, report_prefix, *, metrics="cer,wer", textequiv_level="regio
gt_text = text(gt, textequiv_level=textequiv_level)
ocr_text = text(ocr, textequiv_level=textequiv_level)
- metrics_results = set()
+ metrics_results = {}
if metrics:
metric_dict = {
"ca": character_accuracy,
@@ -147,7 +168,8 @@ def process(gt, ocr, report_prefix, *, metrics="cer,wer", textequiv_level="regio
metric = metric.strip()
if metric not in metric_dict.keys():
raise ValueError(f"Unknown metric '{metric}'.")
- metrics_results.add(metric_dict[metric](gt_text, ocr_text))
+ result = metric_dict[metric](gt_text, ocr_text)
+ metrics_results[result.metric] = result
generate_json_report(gt, ocr, report_prefix, metrics_results)
generate_html_report(gt, ocr, gt_text, ocr_text, report_prefix, metrics_results)
diff --git a/qurator/dinglehopper/metrics/bag_of_chars_accuracy.py b/qurator/dinglehopper/metrics/bag_of_chars_accuracy.py
index e54d73c..3da143d 100644
--- a/qurator/dinglehopper/metrics/bag_of_chars_accuracy.py
+++ b/qurator/dinglehopper/metrics/bag_of_chars_accuracy.py
@@ -1,8 +1,6 @@
from collections import Counter
-from unicodedata import normalize
-
-from uniseg.graphemecluster import grapheme_clusters
+from ..normalize import chars_normalized
from .utils import bag_accuracy, MetricResult, Weights
@@ -19,8 +17,8 @@ def bag_of_chars_accuracy(
:param weights: Weights/costs for editing operations.
:return: Class representing the results of this metric.
"""
- reference_chars: Counter = Counter(grapheme_clusters(normalize("NFC", reference)))
- compared_chars: Counter = Counter(grapheme_clusters(normalize("NFC", compared)))
+ reference_chars: Counter = Counter(chars_normalized(reference))
+ compared_chars: Counter = Counter(chars_normalized(compared))
return bag_accuracy(
reference_chars, compared_chars, weights, bag_of_chars_accuracy.__name__
)
diff --git a/qurator/dinglehopper/templates/report.html.j2 b/qurator/dinglehopper/templates/report.html.j2
index 46a952f..c1a3797 100644
--- a/qurator/dinglehopper/templates/report.html.j2
+++ b/qurator/dinglehopper/templates/report.html.j2
@@ -4,18 +4,18 @@
-
+