diff --git a/qurator/dinglehopper/__init__.py b/qurator/dinglehopper/__init__.py index fd309dc..dc45a8f 100644 --- a/qurator/dinglehopper/__init__.py +++ b/qurator/dinglehopper/__init__.py @@ -3,4 +3,8 @@ from .extracted_text import * from .character_error_rate import * from .word_error_rate import * from .align import * -from .flexible_character_accuracy import flexible_character_accuracy, split_matches +from .flexible_character_accuracy import ( + flexible_character_accuracy, + split_matches, + Match, +) diff --git a/qurator/dinglehopper/cli.py b/qurator/dinglehopper/cli.py index b717618..46fc0b0 100644 --- a/qurator/dinglehopper/cli.py +++ b/qurator/dinglehopper/cli.py @@ -14,7 +14,7 @@ from .ocr_files import extract from .config import Config -def gen_diff_report(gt_in, ocr_in, css_prefix, joiner, none, ops=None): +def gen_diff_report(gt_in, ocr_in, css_prefix, joiner, none, matches=None): gtx = "" ocrx = "" @@ -42,7 +42,27 @@ def gen_diff_report(gt_in, ocr_in, css_prefix, joiner, none, ops=None): else: return "{html_t}".format(html_t=html_t) - if isinstance(gt_in, ExtractedText): + ops, ocr_ids = None, None + if matches: + gt_things, ocr_things, ops = split_matches(matches) + # we have to reconstruct the order of the ocr because we mixed it for fca + ocr_lines = [match.ocr for match in matches] + ocr_lines_sorted = sorted(ocr_lines, key=lambda x: x.line + x.start / 10000) + + ocr_line_region_id = {} + pos = 0 + for ocr_line in ocr_lines_sorted: + if ocr_line.line not in ocr_line_region_id.keys(): + ocr_line_region_id[ocr_line.line] = ocr_in.segment_id_for_pos(pos) + pos += ocr_line.length + + ocr_ids = {None: None} + pos = 0 + for ocr_line in ocr_lines: + for _ in ocr_line.text: + ocr_ids[pos] = ocr_line_region_id[ocr_line.line] + pos += 1 + elif isinstance(gt_in, ExtractedText): if not isinstance(ocr_in, ExtractedText): raise TypeError() # XXX splitting should be done in ExtractedText @@ -61,10 +81,13 @@ def gen_diff_report(gt_in, ocr_in, css_prefix, joiner, none, ops=None): if g != o: css_classes = "{css_prefix}diff{k} diff".format(css_prefix=css_prefix, k=k) if isinstance(gt_in, ExtractedText): - gt_id = gt_in.segment_id_for_pos(g_pos) if g is not None else None - ocr_id = ocr_in.segment_id_for_pos(o_pos) if o is not None else None # Deletions and inserts only produce one id + None, UI must # support this, i.e. display for the one id produced + gt_id = gt_in.segment_id_for_pos(g_pos) if g else None + if ocr_ids: + ocr_id = ocr_ids[o_pos] + else: + ocr_id = ocr_in.segment_id_for_pos(o_pos) if o else None gtx += joiner + format_thing(g, css_classes, gt_id) ocrx += joiner + format_thing(o, css_classes, ocr_id) @@ -111,15 +134,9 @@ def process(gt, ocr, report_prefix, *, metrics="cer,wer", textequiv_level="regio gt_words, ocr_words, css_prefix="w", joiner=" ", none="⋯" ) if "fca" in metrics: - fca, fca_matches = flexible_character_accuracy(gt_text.text, ocr_text.text) - fca_gt_segments, fca_ocr_segments, ops = split_matches(fca_matches) + fca, fca_matches = flexible_character_accuracy(gt_text, ocr_text) fca_diff_report = gen_diff_report( - fca_gt_segments, - fca_ocr_segments, - css_prefix="c", - joiner="", - none="·", - ops=ops, + gt_text, ocr_text, css_prefix="c", joiner="", none="·", matches=fca_matches ) def json_float(value): diff --git a/qurator/dinglehopper/flexible_character_accuracy.py b/qurator/dinglehopper/flexible_character_accuracy.py index 7865dd1..349384c 100644 --- a/qurator/dinglehopper/flexible_character_accuracy.py +++ b/qurator/dinglehopper/flexible_character_accuracy.py @@ -17,7 +17,9 @@ from functools import lru_cache, reduce from itertools import product, takewhile from typing import List, Tuple, Optional -from . import editops +from multimethod import multimethod + +from . import editops, ExtractedText if sys.version_info.minor == 5: from .flexible_character_accuracy_ds_35 import ( @@ -35,6 +37,22 @@ else: ) +@multimethod +def flexible_character_accuracy( + gt: ExtractedText, ocr: ExtractedText +) -> Tuple[float, List[Match]]: + """Calculate the flexible character accuracy. + + Reference: contains steps 1-7 of the flexible character accuracy algorithm. + + :param gt: The ground truth text. + :param ocr: The text to compare the ground truth with. + :return: Score between 0 and 1 and match objects. + """ + return flexible_character_accuracy(gt.text, ocr.text) + + +@multimethod def flexible_character_accuracy(gt: str, ocr: str) -> Tuple[float, List[Match]]: """Calculate the flexible character accuracy. @@ -359,7 +377,7 @@ def split_matches(matches: List[Match]) -> Tuple[List[str], List[str], List[List :param matches: List of match objects. :return: List of ground truth segments, ocr segments and editing operations. """ - matches = sorted(matches, key=lambda x: x.gt.line + x.gt.start / 10000) + matches = sorted(matches, key=lambda m: m.gt.line + m.gt.start / 10000) line = 0 gt, ocr, ops = [], [], [] for match in matches: @@ -410,4 +428,4 @@ class Part(PartVersionSpecific): """ text = self.text[rel_start:rel_end] start = self.start + rel_start - return Part(text=text, line=self.line, start=start) + return Part(**{**self._asdict(), "text": text, "start": start}) diff --git a/qurator/dinglehopper/tests/test_flexible_character_accuracy.py b/qurator/dinglehopper/tests/test_flexible_character_accuracy.py index 2f6d702..3ade597 100644 --- a/qurator/dinglehopper/tests/test_flexible_character_accuracy.py +++ b/qurator/dinglehopper/tests/test_flexible_character_accuracy.py @@ -10,6 +10,7 @@ DOI: 10.1016/j.patrec.2020.02.003 """ import pytest +from lxml import etree as ET from ..flexible_character_accuracy import * @@ -101,11 +102,39 @@ def extended_case_to_text(gt, ocr): @pytest.mark.parametrize(CASE_ARGS, [*SIMPLE_CASES, *COMPLEX_CASES]) -def test_flexible_character_accuracy_simple(gt, ocr, first_line_score, all_line_score): +def test_flexible_character_accuracy_str(gt, ocr, first_line_score, all_line_score): score, _ = flexible_character_accuracy(gt, ocr) assert score == pytest.approx(all_line_score) +@pytest.mark.parametrize(CASE_ARGS, [*SIMPLE_CASES, *COMPLEX_CASES]) +def test_flexible_character_accuracy_xml(gt, ocr, first_line_score, all_line_score): + def get_extracted_text(text: str): + xml = '' + ns = "http://schema.primaresearch.org/PAGE/gts/pagecontent/2018-07-15" + + textline_tmpl = ( + '{1}' + "" + ) + xml_tmpl = '{0}{2}' + + textlines = [ + textline_tmpl.format(i, line) for i, line in enumerate(text.splitlines()) + ] + xml_text = xml_tmpl.format(xml, ns, "".join(textlines)) + root = ET.fromstring(xml_text) + extracted_text = ExtractedText.from_text_segment( + root, {"page": ns}, textequiv_level="line" + ) + return extracted_text + + gt_text = get_extracted_text(gt) + ocr_text = get_extracted_text(ocr) + score, _ = flexible_character_accuracy(gt_text, ocr_text) + assert score == pytest.approx(all_line_score) + + @pytest.mark.parametrize( "config,ocr", [