diff --git a/.screenshots/dinglehopper.png b/.screenshots/dinglehopper.png index df8bc04..9aa7d0e 100644 Binary files a/.screenshots/dinglehopper.png and b/.screenshots/dinglehopper.png differ diff --git a/qurator/dinglehopper/cli.py b/qurator/dinglehopper/cli.py index c360c4e..ac3bbab 100644 --- a/qurator/dinglehopper/cli.py +++ b/qurator/dinglehopper/cli.py @@ -2,13 +2,12 @@ import json import os from collections import Counter from functools import partial -from typing import Callable, List, Tuple +from typing import Any, Callable, Dict, List, Tuple import click from jinja2 import Environment, FileSystemLoader from markupsafe import escape from ocrd_utils import initLogging -from uniseg.graphemecluster import grapheme_clusters from .align import seq_align from .config import Config @@ -20,21 +19,28 @@ from .metrics import ( word_accuracy, ) from .normalize import chars_normalized, words_normalized -from .ocr_files import text +from .ocr_files import extract def gen_count_report( - gt_text: str, ocr_text: str, split_fun: Callable[[str], Counter] + gt_text: ExtractedText, ocr_text: ExtractedText, split_fun: Callable[[str], Counter] ) -> List[Tuple[str, int, int]]: - gt_counter = Counter(split_fun(gt_text)) - ocr_counter = Counter(split_fun(ocr_text)) + gt_counter = Counter(split_fun(gt_text.text)) + ocr_counter = Counter(split_fun(ocr_text.text)) return [ ("".join(key), gt_counter[key], ocr_counter[key]) for key in sorted({*gt_counter.keys(), *ocr_counter.keys()}) ] -def gen_diff_report(gt_in, ocr_in, css_prefix, joiner, none) -> Tuple[str, str]: +def gen_diff_report( + gt_in: ExtractedText, + ocr_in: ExtractedText, + css_prefix: str = "c", + joiner: str = "", + none: str = "·", + split_fun=chars_normalized, +) -> Tuple[str, str]: gtx = "" ocrx = "" @@ -58,15 +64,8 @@ def gen_diff_report(gt_in, ocr_in, css_prefix, joiner, none) -> Tuple[str, str]: else: return f"{html_t}" - if isinstance(gt_in, ExtractedText): - if not isinstance(ocr_in, ExtractedText): - raise TypeError() - # XXX splitting should be done in ExtractedText - gt_things = list(grapheme_clusters(gt_in.text)) - ocr_things = list(grapheme_clusters(ocr_in.text)) - else: - gt_things = gt_in - ocr_things = ocr_in + gt_things = split_fun(gt_in.text) + ocr_things = split_fun(ocr_in.text) g_pos = 0 o_pos = 0 @@ -76,11 +75,10 @@ def gen_diff_report(gt_in, ocr_in, css_prefix, joiner, none) -> Tuple[str, str]: ocr_id = None if g != o: css_classes = "{css_prefix}diff{k} diff".format(css_prefix=css_prefix, k=k) - if isinstance(gt_in, ExtractedText): - gt_id = gt_in.segment_id_for_pos(g_pos) if g is not None else None - ocr_id = ocr_in.segment_id_for_pos(o_pos) if o is not None else None - # Deletions and inserts only produce one id + None, UI must - # support this, i.e. display for the one id produced + gt_id = gt_in.segment_id_for_pos(g_pos) if g is not None else None + ocr_id = ocr_in.segment_id_for_pos(o_pos) if o is not None else None + # Deletions and inserts only produce one id + None, UI must + # support this, i.e. display for the one id produced gtx += joiner + format_thing(g, css_classes, gt_id) ocrx += joiner + format_thing(o, css_classes, ocr_id) @@ -93,18 +91,29 @@ def gen_diff_report(gt_in, ocr_in, css_prefix, joiner, none) -> Tuple[str, str]: return gtx, ocrx -def generate_html_report(gt, ocr, gt_text, ocr_text, report_prefix, metrics_results): +def generate_html_report( + gt: str, + ocr: str, + gt_text: ExtractedText, + ocr_text: ExtractedText, + report_prefix: str, + metrics_results: Dict, +): - metric_dict = { + metric_dict: Dict[str, Callable] = { "character_accuracy": partial( - gen_diff_report, css_prefix="c", joiner="", none="·" + gen_diff_report, + css_prefix="c", + joiner="", + none="·", + split_fun=chars_normalized, ), - "word_accuracy": lambda gt_text, ocr_text: gen_diff_report( - words_normalized(gt_text), - words_normalized(ocr_text), + "word_accuracy": partial( + gen_diff_report, css_prefix="w", joiner=" ", none="⋯", + split_fun=words_normalized, ), "bag_of_chars_accuracy": partial(gen_count_report, split_fun=chars_normalized), "bag_of_words_accuracy": partial(gen_count_report, split_fun=words_normalized), @@ -134,8 +143,8 @@ def generate_html_report(gt, ocr, gt_text, ocr_text, report_prefix, metrics_resu ).dump(out_fn) -def generate_json_report(gt, ocr, report_prefix, metrics_results): - json_dict = {"gt": gt, "ocr": ocr} +def generate_json_report(gt: str, ocr: str, report_prefix: str, metrics_results: Dict): + json_dict: Dict[str, Any] = {"gt": gt, "ocr": ocr} for result in metrics_results.values(): json_dict[result.metric] = { key: value for key, value in result.get_dict().items() if key != "metric" @@ -153,8 +162,8 @@ def process( so we keep this undecorated version and use Click on a wrapper. """ - gt_text = text(gt, textequiv_level=textequiv_level) - ocr_text = text(ocr, textequiv_level=textequiv_level) + gt_text = extract(gt, textequiv_level=textequiv_level) + ocr_text = extract(ocr, textequiv_level=textequiv_level) metrics_results = {} if metrics: @@ -170,7 +179,7 @@ def process( metric = metric.strip() if metric not in metric_dict.keys(): raise ValueError(f"Unknown metric '{metric}'.") - result = metric_dict[metric](gt_text, ocr_text) + result = metric_dict[metric](gt_text.text, ocr_text.text) metrics_results[result.metric] = result generate_json_report(gt, ocr, report_prefix, metrics_results) diff --git a/qurator/dinglehopper/ocr_files.py b/qurator/dinglehopper/ocr_files.py index 0eab0c5..6584172 100644 --- a/qurator/dinglehopper/ocr_files.py +++ b/qurator/dinglehopper/ocr_files.py @@ -147,11 +147,7 @@ def plain_extract(filename): # XXX hardcoded SBB normalization -def plain_text(filename): - return plain_extract(filename).text - - -def extract(filename, *, textequiv_level="region"): +def extract(filename, *, textequiv_level="region") -> ExtractedText: """Extract the text from the given file. Supports PAGE, ALTO and falls back to plain text. diff --git a/qurator/dinglehopper/tests/test_ocr_files.py b/qurator/dinglehopper/tests/test_ocr_files.py index f48c59c..fbfd2ae 100644 --- a/qurator/dinglehopper/tests/test_ocr_files.py +++ b/qurator/dinglehopper/tests/test_ocr_files.py @@ -5,8 +5,7 @@ import textwrap import lxml.etree as ET from .util import working_directory -from ..ocr_files import alto_namespace, alto_text, page_namespace, page_text, \ - plain_text, text +from ..ocr_files import alto_namespace, alto_text, page_namespace, page_text, text data_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "data") @@ -179,6 +178,6 @@ def test_plain(tmp_path): with open("ocr.txt", "w") as ocrf: ocrf.write("AAAAB") - result = plain_text("ocr.txt") + result = text("ocr.txt") expected = "AAAAB" assert result == expected