diff --git a/qurator/dinglehopper/__init__.py b/qurator/dinglehopper/__init__.py
index fd309dc..dc45a8f 100644
--- a/qurator/dinglehopper/__init__.py
+++ b/qurator/dinglehopper/__init__.py
@@ -3,4 +3,8 @@ from .extracted_text import *
from .character_error_rate import *
from .word_error_rate import *
from .align import *
-from .flexible_character_accuracy import flexible_character_accuracy, split_matches
+from .flexible_character_accuracy import (
+ flexible_character_accuracy,
+ split_matches,
+ Match,
+)
diff --git a/qurator/dinglehopper/cli.py b/qurator/dinglehopper/cli.py
index b717618..46fc0b0 100644
--- a/qurator/dinglehopper/cli.py
+++ b/qurator/dinglehopper/cli.py
@@ -14,7 +14,7 @@ from .ocr_files import extract
from .config import Config
-def gen_diff_report(gt_in, ocr_in, css_prefix, joiner, none, ops=None):
+def gen_diff_report(gt_in, ocr_in, css_prefix, joiner, none, matches=None):
gtx = ""
ocrx = ""
@@ -42,7 +42,27 @@ def gen_diff_report(gt_in, ocr_in, css_prefix, joiner, none, ops=None):
else:
return "{html_t}".format(html_t=html_t)
- if isinstance(gt_in, ExtractedText):
+ ops, ocr_ids = None, None
+ if matches:
+ gt_things, ocr_things, ops = split_matches(matches)
+ # we have to reconstruct the order of the ocr because we mixed it for fca
+ ocr_lines = [match.ocr for match in matches]
+ ocr_lines_sorted = sorted(ocr_lines, key=lambda x: x.line + x.start / 10000)
+
+ ocr_line_region_id = {}
+ pos = 0
+ for ocr_line in ocr_lines_sorted:
+ if ocr_line.line not in ocr_line_region_id.keys():
+ ocr_line_region_id[ocr_line.line] = ocr_in.segment_id_for_pos(pos)
+ pos += ocr_line.length
+
+ ocr_ids = {None: None}
+ pos = 0
+ for ocr_line in ocr_lines:
+ for _ in ocr_line.text:
+ ocr_ids[pos] = ocr_line_region_id[ocr_line.line]
+ pos += 1
+ elif isinstance(gt_in, ExtractedText):
if not isinstance(ocr_in, ExtractedText):
raise TypeError()
# XXX splitting should be done in ExtractedText
@@ -61,10 +81,13 @@ def gen_diff_report(gt_in, ocr_in, css_prefix, joiner, none, ops=None):
if g != o:
css_classes = "{css_prefix}diff{k} diff".format(css_prefix=css_prefix, k=k)
if isinstance(gt_in, ExtractedText):
- gt_id = gt_in.segment_id_for_pos(g_pos) if g is not None else None
- ocr_id = ocr_in.segment_id_for_pos(o_pos) if o is not None else None
# Deletions and inserts only produce one id + None, UI must
# support this, i.e. display for the one id produced
+ gt_id = gt_in.segment_id_for_pos(g_pos) if g else None
+ if ocr_ids:
+ ocr_id = ocr_ids[o_pos]
+ else:
+ ocr_id = ocr_in.segment_id_for_pos(o_pos) if o else None
gtx += joiner + format_thing(g, css_classes, gt_id)
ocrx += joiner + format_thing(o, css_classes, ocr_id)
@@ -111,15 +134,9 @@ def process(gt, ocr, report_prefix, *, metrics="cer,wer", textequiv_level="regio
gt_words, ocr_words, css_prefix="w", joiner=" ", none="⋯"
)
if "fca" in metrics:
- fca, fca_matches = flexible_character_accuracy(gt_text.text, ocr_text.text)
- fca_gt_segments, fca_ocr_segments, ops = split_matches(fca_matches)
+ fca, fca_matches = flexible_character_accuracy(gt_text, ocr_text)
fca_diff_report = gen_diff_report(
- fca_gt_segments,
- fca_ocr_segments,
- css_prefix="c",
- joiner="",
- none="·",
- ops=ops,
+ gt_text, ocr_text, css_prefix="c", joiner="", none="·", matches=fca_matches
)
def json_float(value):
diff --git a/qurator/dinglehopper/flexible_character_accuracy.py b/qurator/dinglehopper/flexible_character_accuracy.py
index 7865dd1..349384c 100644
--- a/qurator/dinglehopper/flexible_character_accuracy.py
+++ b/qurator/dinglehopper/flexible_character_accuracy.py
@@ -17,7 +17,9 @@ from functools import lru_cache, reduce
from itertools import product, takewhile
from typing import List, Tuple, Optional
-from . import editops
+from multimethod import multimethod
+
+from . import editops, ExtractedText
if sys.version_info.minor == 5:
from .flexible_character_accuracy_ds_35 import (
@@ -35,6 +37,22 @@ else:
)
+@multimethod
+def flexible_character_accuracy(
+ gt: ExtractedText, ocr: ExtractedText
+) -> Tuple[float, List[Match]]:
+ """Calculate the flexible character accuracy.
+
+ Reference: contains steps 1-7 of the flexible character accuracy algorithm.
+
+ :param gt: The ground truth text.
+ :param ocr: The text to compare the ground truth with.
+ :return: Score between 0 and 1 and match objects.
+ """
+ return flexible_character_accuracy(gt.text, ocr.text)
+
+
+@multimethod
def flexible_character_accuracy(gt: str, ocr: str) -> Tuple[float, List[Match]]:
"""Calculate the flexible character accuracy.
@@ -359,7 +377,7 @@ def split_matches(matches: List[Match]) -> Tuple[List[str], List[str], List[List
:param matches: List of match objects.
:return: List of ground truth segments, ocr segments and editing operations.
"""
- matches = sorted(matches, key=lambda x: x.gt.line + x.gt.start / 10000)
+ matches = sorted(matches, key=lambda m: m.gt.line + m.gt.start / 10000)
line = 0
gt, ocr, ops = [], [], []
for match in matches:
@@ -410,4 +428,4 @@ class Part(PartVersionSpecific):
"""
text = self.text[rel_start:rel_end]
start = self.start + rel_start
- return Part(text=text, line=self.line, start=start)
+ return Part(**{**self._asdict(), "text": text, "start": start})
diff --git a/qurator/dinglehopper/tests/test_flexible_character_accuracy.py b/qurator/dinglehopper/tests/test_flexible_character_accuracy.py
index 2f6d702..3ade597 100644
--- a/qurator/dinglehopper/tests/test_flexible_character_accuracy.py
+++ b/qurator/dinglehopper/tests/test_flexible_character_accuracy.py
@@ -10,6 +10,7 @@ DOI: 10.1016/j.patrec.2020.02.003
"""
import pytest
+from lxml import etree as ET
from ..flexible_character_accuracy import *
@@ -101,11 +102,39 @@ def extended_case_to_text(gt, ocr):
@pytest.mark.parametrize(CASE_ARGS, [*SIMPLE_CASES, *COMPLEX_CASES])
-def test_flexible_character_accuracy_simple(gt, ocr, first_line_score, all_line_score):
+def test_flexible_character_accuracy_str(gt, ocr, first_line_score, all_line_score):
score, _ = flexible_character_accuracy(gt, ocr)
assert score == pytest.approx(all_line_score)
+@pytest.mark.parametrize(CASE_ARGS, [*SIMPLE_CASES, *COMPLEX_CASES])
+def test_flexible_character_accuracy_xml(gt, ocr, first_line_score, all_line_score):
+ def get_extracted_text(text: str):
+ xml = ''
+ ns = "http://schema.primaresearch.org/PAGE/gts/pagecontent/2018-07-15"
+
+ textline_tmpl = (
+ '{1}'
+ ""
+ )
+ xml_tmpl = '{0}{2}'
+
+ textlines = [
+ textline_tmpl.format(i, line) for i, line in enumerate(text.splitlines())
+ ]
+ xml_text = xml_tmpl.format(xml, ns, "".join(textlines))
+ root = ET.fromstring(xml_text)
+ extracted_text = ExtractedText.from_text_segment(
+ root, {"page": ns}, textequiv_level="line"
+ )
+ return extracted_text
+
+ gt_text = get_extracted_text(gt)
+ ocr_text = get_extracted_text(ocr)
+ score, _ = flexible_character_accuracy(gt_text, ocr_text)
+ assert score == pytest.approx(all_line_score)
+
+
@pytest.mark.parametrize(
"config,ocr",
[