mirror of
https://github.com/qurator-spk/dinglehopper.git
synced 2025-06-07 19:05:13 +02:00
Add tooltips to fca report
This commit is contained in:
parent
53064bf833
commit
750ad00d1b
4 changed files with 85 additions and 17 deletions
|
@ -3,4 +3,8 @@ from .extracted_text import *
|
|||
from .character_error_rate import *
|
||||
from .word_error_rate import *
|
||||
from .align import *
|
||||
from .flexible_character_accuracy import flexible_character_accuracy, split_matches
|
||||
from .flexible_character_accuracy import (
|
||||
flexible_character_accuracy,
|
||||
split_matches,
|
||||
Match,
|
||||
)
|
||||
|
|
|
@ -14,7 +14,7 @@ from .ocr_files import extract
|
|||
from .config import Config
|
||||
|
||||
|
||||
def gen_diff_report(gt_in, ocr_in, css_prefix, joiner, none, ops=None):
|
||||
def gen_diff_report(gt_in, ocr_in, css_prefix, joiner, none, matches=None):
|
||||
gtx = ""
|
||||
ocrx = ""
|
||||
|
||||
|
@ -42,7 +42,27 @@ def gen_diff_report(gt_in, ocr_in, css_prefix, joiner, none, ops=None):
|
|||
else:
|
||||
return "{html_t}".format(html_t=html_t)
|
||||
|
||||
if isinstance(gt_in, ExtractedText):
|
||||
ops, ocr_ids = None, None
|
||||
if matches:
|
||||
gt_things, ocr_things, ops = split_matches(matches)
|
||||
# we have to reconstruct the order of the ocr because we mixed it for fca
|
||||
ocr_lines = [match.ocr for match in matches]
|
||||
ocr_lines_sorted = sorted(ocr_lines, key=lambda x: x.line + x.start / 10000)
|
||||
|
||||
ocr_line_region_id = {}
|
||||
pos = 0
|
||||
for ocr_line in ocr_lines_sorted:
|
||||
if ocr_line.line not in ocr_line_region_id.keys():
|
||||
ocr_line_region_id[ocr_line.line] = ocr_in.segment_id_for_pos(pos)
|
||||
pos += ocr_line.length
|
||||
|
||||
ocr_ids = {None: None}
|
||||
pos = 0
|
||||
for ocr_line in ocr_lines:
|
||||
for _ in ocr_line.text:
|
||||
ocr_ids[pos] = ocr_line_region_id[ocr_line.line]
|
||||
pos += 1
|
||||
elif isinstance(gt_in, ExtractedText):
|
||||
if not isinstance(ocr_in, ExtractedText):
|
||||
raise TypeError()
|
||||
# XXX splitting should be done in ExtractedText
|
||||
|
@ -61,10 +81,13 @@ def gen_diff_report(gt_in, ocr_in, css_prefix, joiner, none, ops=None):
|
|||
if g != o:
|
||||
css_classes = "{css_prefix}diff{k} diff".format(css_prefix=css_prefix, k=k)
|
||||
if isinstance(gt_in, ExtractedText):
|
||||
gt_id = gt_in.segment_id_for_pos(g_pos) if g is not None else None
|
||||
ocr_id = ocr_in.segment_id_for_pos(o_pos) if o is not None else None
|
||||
# Deletions and inserts only produce one id + None, UI must
|
||||
# support this, i.e. display for the one id produced
|
||||
gt_id = gt_in.segment_id_for_pos(g_pos) if g else None
|
||||
if ocr_ids:
|
||||
ocr_id = ocr_ids[o_pos]
|
||||
else:
|
||||
ocr_id = ocr_in.segment_id_for_pos(o_pos) if o else None
|
||||
|
||||
gtx += joiner + format_thing(g, css_classes, gt_id)
|
||||
ocrx += joiner + format_thing(o, css_classes, ocr_id)
|
||||
|
@ -111,15 +134,9 @@ def process(gt, ocr, report_prefix, *, metrics="cer,wer", textequiv_level="regio
|
|||
gt_words, ocr_words, css_prefix="w", joiner=" ", none="⋯"
|
||||
)
|
||||
if "fca" in metrics:
|
||||
fca, fca_matches = flexible_character_accuracy(gt_text.text, ocr_text.text)
|
||||
fca_gt_segments, fca_ocr_segments, ops = split_matches(fca_matches)
|
||||
fca, fca_matches = flexible_character_accuracy(gt_text, ocr_text)
|
||||
fca_diff_report = gen_diff_report(
|
||||
fca_gt_segments,
|
||||
fca_ocr_segments,
|
||||
css_prefix="c",
|
||||
joiner="",
|
||||
none="·",
|
||||
ops=ops,
|
||||
gt_text, ocr_text, css_prefix="c", joiner="", none="·", matches=fca_matches
|
||||
)
|
||||
|
||||
def json_float(value):
|
||||
|
|
|
@ -17,7 +17,9 @@ from functools import lru_cache, reduce
|
|||
from itertools import product, takewhile
|
||||
from typing import List, Tuple, Optional
|
||||
|
||||
from . import editops
|
||||
from multimethod import multimethod
|
||||
|
||||
from . import editops, ExtractedText
|
||||
|
||||
if sys.version_info.minor == 5:
|
||||
from .flexible_character_accuracy_ds_35 import (
|
||||
|
@ -35,6 +37,22 @@ else:
|
|||
)
|
||||
|
||||
|
||||
@multimethod
|
||||
def flexible_character_accuracy(
|
||||
gt: ExtractedText, ocr: ExtractedText
|
||||
) -> Tuple[float, List[Match]]:
|
||||
"""Calculate the flexible character accuracy.
|
||||
|
||||
Reference: contains steps 1-7 of the flexible character accuracy algorithm.
|
||||
|
||||
:param gt: The ground truth text.
|
||||
:param ocr: The text to compare the ground truth with.
|
||||
:return: Score between 0 and 1 and match objects.
|
||||
"""
|
||||
return flexible_character_accuracy(gt.text, ocr.text)
|
||||
|
||||
|
||||
@multimethod
|
||||
def flexible_character_accuracy(gt: str, ocr: str) -> Tuple[float, List[Match]]:
|
||||
"""Calculate the flexible character accuracy.
|
||||
|
||||
|
@ -359,7 +377,7 @@ def split_matches(matches: List[Match]) -> Tuple[List[str], List[str], List[List
|
|||
:param matches: List of match objects.
|
||||
:return: List of ground truth segments, ocr segments and editing operations.
|
||||
"""
|
||||
matches = sorted(matches, key=lambda x: x.gt.line + x.gt.start / 10000)
|
||||
matches = sorted(matches, key=lambda m: m.gt.line + m.gt.start / 10000)
|
||||
line = 0
|
||||
gt, ocr, ops = [], [], []
|
||||
for match in matches:
|
||||
|
@ -410,4 +428,4 @@ class Part(PartVersionSpecific):
|
|||
"""
|
||||
text = self.text[rel_start:rel_end]
|
||||
start = self.start + rel_start
|
||||
return Part(text=text, line=self.line, start=start)
|
||||
return Part(**{**self._asdict(), "text": text, "start": start})
|
||||
|
|
|
@ -10,6 +10,7 @@ DOI: 10.1016/j.patrec.2020.02.003
|
|||
"""
|
||||
|
||||
import pytest
|
||||
from lxml import etree as ET
|
||||
|
||||
from ..flexible_character_accuracy import *
|
||||
|
||||
|
@ -101,11 +102,39 @@ def extended_case_to_text(gt, ocr):
|
|||
|
||||
|
||||
@pytest.mark.parametrize(CASE_ARGS, [*SIMPLE_CASES, *COMPLEX_CASES])
|
||||
def test_flexible_character_accuracy_simple(gt, ocr, first_line_score, all_line_score):
|
||||
def test_flexible_character_accuracy_str(gt, ocr, first_line_score, all_line_score):
|
||||
score, _ = flexible_character_accuracy(gt, ocr)
|
||||
assert score == pytest.approx(all_line_score)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(CASE_ARGS, [*SIMPLE_CASES, *COMPLEX_CASES])
|
||||
def test_flexible_character_accuracy_xml(gt, ocr, first_line_score, all_line_score):
|
||||
def get_extracted_text(text: str):
|
||||
xml = '<?xml version="1.0"?>'
|
||||
ns = "http://schema.primaresearch.org/PAGE/gts/pagecontent/2018-07-15"
|
||||
|
||||
textline_tmpl = (
|
||||
'<TextLine id="l{0}"><TextEquiv><Unicode>{1}'
|
||||
"</Unicode></TextEquiv></TextLine>"
|
||||
)
|
||||
xml_tmpl = '{0}<TextRegion id="0" xmlns="{1}">{2}</TextRegion>'
|
||||
|
||||
textlines = [
|
||||
textline_tmpl.format(i, line) for i, line in enumerate(text.splitlines())
|
||||
]
|
||||
xml_text = xml_tmpl.format(xml, ns, "".join(textlines))
|
||||
root = ET.fromstring(xml_text)
|
||||
extracted_text = ExtractedText.from_text_segment(
|
||||
root, {"page": ns}, textequiv_level="line"
|
||||
)
|
||||
return extracted_text
|
||||
|
||||
gt_text = get_extracted_text(gt)
|
||||
ocr_text = get_extracted_text(ocr)
|
||||
score, _ = flexible_character_accuracy(gt_text, ocr_text)
|
||||
assert score == pytest.approx(all_line_score)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"config,ocr",
|
||||
[
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue