Add tooltips to fca report
parent 53064bf833
commit 750ad00d1b

4 changed files with 85 additions and 17 deletions
@@ -3,4 +3,8 @@ from .extracted_text import *
 from .character_error_rate import *
 from .word_error_rate import *
 from .align import *
-from .flexible_character_accuracy import flexible_character_accuracy, split_matches
+from .flexible_character_accuracy import (
+    flexible_character_accuracy,
+    split_matches,
+    Match,
+)
@@ -14,7 +14,7 @@ from .ocr_files import extract
 from .config import Config
 
 
-def gen_diff_report(gt_in, ocr_in, css_prefix, joiner, none, ops=None):
+def gen_diff_report(gt_in, ocr_in, css_prefix, joiner, none, matches=None):
     gtx = ""
     ocrx = ""
 
@@ -42,7 +42,27 @@ def gen_diff_report(gt_in, ocr_in, css_prefix, joiner, none, ops=None):
         else:
             return "{html_t}".format(html_t=html_t)
 
-    if isinstance(gt_in, ExtractedText):
+    ops, ocr_ids = None, None
+    if matches:
+        gt_things, ocr_things, ops = split_matches(matches)
+        # we have to reconstruct the order of the ocr because we mixed it for fca
+        ocr_lines = [match.ocr for match in matches]
+        ocr_lines_sorted = sorted(ocr_lines, key=lambda x: x.line + x.start / 10000)
+
+        ocr_line_region_id = {}
+        pos = 0
+        for ocr_line in ocr_lines_sorted:
+            if ocr_line.line not in ocr_line_region_id.keys():
+                ocr_line_region_id[ocr_line.line] = ocr_in.segment_id_for_pos(pos)
+            pos += ocr_line.length
+
+        ocr_ids = {None: None}
+        pos = 0
+        for ocr_line in ocr_lines:
+            for _ in ocr_line.text:
+                ocr_ids[pos] = ocr_line_region_id[ocr_line.line]
+                pos += 1
+    elif isinstance(gt_in, ExtractedText):
         if not isinstance(ocr_in, ExtractedText):
             raise TypeError()
         # XXX splitting should be done in ExtractedText
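Note on the block above: the tooltips need a segment id for every character position in the OCR column, but the FCA matches may visit OCR lines out of reading order. The code therefore first restores reading order to learn which segment each original line belongs to, then replays the match order to build a position-to-id table. A minimal, self-contained sketch of that two-pass idea; the Line tuple and the segment_id_for_pos callback are stand-ins for illustration, not dinglehopper's real types:

from collections import namedtuple

# Stand-in for the Part objects carried by the FCA matches (hypothetical).
Line = namedtuple("Line", "line start text")


def build_ocr_ids(ocr_lines, segment_id_for_pos):
    """Map each character position of the match-ordered OCR text to the id of
    the segment that character originally came from."""
    # Pass 1: restore reading order (line number first, start offset second)
    # and note which segment id each original line starts in.
    ordered = sorted(ocr_lines, key=lambda l: l.line + l.start / 10000)
    line_to_id = {}
    pos = 0
    for line in ordered:
        line_to_id.setdefault(line.line, segment_id_for_pos(pos))
        pos += len(line.text)

    # Pass 2: replay the match order and record one id per character.
    ocr_ids = {None: None}
    pos = 0
    for line in ocr_lines:
        for _ in line.text:
            ocr_ids[pos] = line_to_id[line.line]
            pos += 1
    return ocr_ids


# Two lines that the matches happen to report in swapped order.
lines = [Line(line=1, start=0, text="world"), Line(line=0, start=0, text="hello ")]
ids = build_ocr_ids(lines, lambda pos: "l0" if pos < 6 else "l1")
print(ids[0])  # l1 -- the first character in match order comes from the second line
print(ids[5])  # l0 -- position 5 is where the first line's text begins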
@@ -61,10 +81,13 @@ def gen_diff_report(gt_in, ocr_in, css_prefix, joiner, none, ops=None):
         if g != o:
             css_classes = "{css_prefix}diff{k} diff".format(css_prefix=css_prefix, k=k)
             if isinstance(gt_in, ExtractedText):
-                gt_id = gt_in.segment_id_for_pos(g_pos) if g is not None else None
-                ocr_id = ocr_in.segment_id_for_pos(o_pos) if o is not None else None
                 # Deletions and inserts only produce one id + None, UI must
                 # support this, i.e. display for the one id produced
+                gt_id = gt_in.segment_id_for_pos(g_pos) if g else None
+                if ocr_ids:
+                    ocr_id = ocr_ids[o_pos]
+                else:
+                    ocr_id = ocr_in.segment_id_for_pos(o_pos) if o else None
 
         gtx += joiner + format_thing(g, css_classes, gt_id)
         ocrx += joiner + format_thing(o, css_classes, ocr_id)
@@ -111,15 +134,9 @@ def process(gt, ocr, report_prefix, *, metrics="cer,wer", textequiv_level="region")
         gt_words, ocr_words, css_prefix="w", joiner=" ", none="⋯"
     )
     if "fca" in metrics:
-        fca, fca_matches = flexible_character_accuracy(gt_text.text, ocr_text.text)
-        fca_gt_segments, fca_ocr_segments, ops = split_matches(fca_matches)
+        fca, fca_matches = flexible_character_accuracy(gt_text, ocr_text)
         fca_diff_report = gen_diff_report(
-            fca_gt_segments,
-            fca_ocr_segments,
-            css_prefix="c",
-            joiner="",
-            none="·",
-            ops=ops,
+            gt_text, ocr_text, css_prefix="c", joiner="", none="·", matches=fca_matches
         )
 
 def json_float(value):
@@ -17,7 +17,9 @@ from functools import lru_cache, reduce
 from itertools import product, takewhile
 from typing import List, Tuple, Optional
 
-from . import editops
+from multimethod import multimethod
+
+from . import editops, ExtractedText
 
 if sys.version_info.minor == 5:
     from .flexible_character_accuracy_ds_35 import (
@@ -35,6 +37,22 @@ else:
     )
 
 
+@multimethod
+def flexible_character_accuracy(
+    gt: ExtractedText, ocr: ExtractedText
+) -> Tuple[float, List[Match]]:
+    """Calculate the flexible character accuracy.
+
+    Reference: contains steps 1-7 of the flexible character accuracy algorithm.
+
+    :param gt: The ground truth text.
+    :param ocr: The text to compare the ground truth with.
+    :return: Score between 0 and 1 and match objects.
+    """
+    return flexible_character_accuracy(gt.text, ocr.text)
+
+
+@multimethod
 def flexible_character_accuracy(gt: str, ocr: str) -> Tuple[float, List[Match]]:
     """Calculate the flexible character accuracy.
 
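Note on the new overload: the multimethod package dispatches on the annotated parameter types, so flexible_character_accuracy(gt, ocr) now routes ExtractedText arguments through the wrapper above (which unwraps .text) and plain strings straight to the existing implementation. A minimal sketch of that dispatch pattern, using toy function names rather than the project's API:

from multimethod import multimethod


@multimethod
def describe(value: int) -> str:
    # Selected when the argument is an int.
    return "int: {}".format(value)


@multimethod
def describe(value: str) -> str:
    # Selected when the argument is a str; the int overload above stays registered.
    return "str: {!r}".format(value)


print(describe(42))    # int: 42
print(describe("42"))  # str: '42'

Calling the multimethod recursively from inside an overload, as the diff does with return flexible_character_accuracy(gt.text, ocr.text), simply re-dispatches on the new argument types.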
@@ -359,7 +377,7 @@ def split_matches(matches: List[Match]) -> Tuple[List[str], List[str], List[List
     :param matches: List of match objects.
     :return: List of ground truth segments, ocr segments and editing operations.
     """
-    matches = sorted(matches, key=lambda x: x.gt.line + x.gt.start / 10000)
+    matches = sorted(matches, key=lambda m: m.gt.line + m.gt.start / 10000)
     line = 0
     gt, ocr, ops = [], [], []
     for match in matches:
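Note on the sort key: lambda m: m.gt.line + m.gt.start / 10000 folds line number and start offset into one float, so matches order by line first and by offset within the line, which only holds while offsets stay below 10000. A small sketch showing the key agrees with an explicit (line, start) tuple key under that assumption; Seg is a hypothetical stand-in for match.gt:

from collections import namedtuple

Seg = namedtuple("Seg", "line start")  # stand-in for match.gt

segs = [Seg(line=1, start=0), Seg(line=0, start=42), Seg(line=0, start=7)]

# Composite float key as used in the diff ...
by_float = sorted(segs, key=lambda s: s.line + s.start / 10000)
# ... versus the explicit lexicographic key; identical while start < 10000.
by_tuple = sorted(segs, key=lambda s: (s.line, s.start))

assert by_float == by_tuple
print(by_float)  # [Seg(line=0, start=7), Seg(line=0, start=42), Seg(line=1, start=0)]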
@@ -410,4 +428,4 @@ class Part(PartVersionSpecific):
         """
         text = self.text[rel_start:rel_end]
         start = self.start + rel_start
-        return Part(text=text, line=self.line, start=start)
+        return Part(**{**self._asdict(), "text": text, "start": start})
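Note on the Part change: Part appears to be a NamedTuple-style class (it exposes _asdict), so the new return copies every existing field and overrides only text and start instead of re-listing fields by hand; fields added to Part later are carried along automatically. A sketch of the idiom with a stand-in class:

from typing import NamedTuple


class Piece(NamedTuple):  # stand-in for Part, not the real class
    text: str
    line: int
    start: int = 0


p = Piece(text="hello world", line=3, start=10)
rel_start, rel_end = 6, 11

# Copy all fields, then override the ones that change -- same idiom as the diff.
clipped = Piece(**{**p._asdict(), "text": p.text[rel_start:rel_end], "start": p.start + rel_start})
print(clipped)  # Piece(text='world', line=3, start=16)

NamedTuple's _replace(text=..., start=...) would express the same thing; the dict-merge spelling is shown because it is what the commit uses.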
@@ -10,6 +10,7 @@ DOI: 10.1016/j.patrec.2020.02.003
 """
 
 import pytest
+from lxml import etree as ET
 
 from ..flexible_character_accuracy import *
 
@@ -101,11 +102,39 @@ def extended_case_to_text(gt, ocr):
 
 
 @pytest.mark.parametrize(CASE_ARGS, [*SIMPLE_CASES, *COMPLEX_CASES])
-def test_flexible_character_accuracy_simple(gt, ocr, first_line_score, all_line_score):
+def test_flexible_character_accuracy_str(gt, ocr, first_line_score, all_line_score):
     score, _ = flexible_character_accuracy(gt, ocr)
     assert score == pytest.approx(all_line_score)
 
 
+@pytest.mark.parametrize(CASE_ARGS, [*SIMPLE_CASES, *COMPLEX_CASES])
+def test_flexible_character_accuracy_xml(gt, ocr, first_line_score, all_line_score):
+    def get_extracted_text(text: str):
+        xml = '<?xml version="1.0"?>'
+        ns = "http://schema.primaresearch.org/PAGE/gts/pagecontent/2018-07-15"
+
+        textline_tmpl = (
+            '<TextLine id="l{0}"><TextEquiv><Unicode>{1}'
+            "</Unicode></TextEquiv></TextLine>"
+        )
+        xml_tmpl = '{0}<TextRegion id="0" xmlns="{1}">{2}</TextRegion>'
+
+        textlines = [
+            textline_tmpl.format(i, line) for i, line in enumerate(text.splitlines())
+        ]
+        xml_text = xml_tmpl.format(xml, ns, "".join(textlines))
+        root = ET.fromstring(xml_text)
+        extracted_text = ExtractedText.from_text_segment(
+            root, {"page": ns}, textequiv_level="line"
+        )
+        return extracted_text
+
+    gt_text = get_extracted_text(gt)
+    ocr_text = get_extracted_text(ocr)
+    score, _ = flexible_character_accuracy(gt_text, ocr_text)
+    assert score == pytest.approx(all_line_score)
+
+
 @pytest.mark.parametrize(
     "config,ocr",
     [
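Note on the new test: get_extracted_text wraps each input line in a PAGE-XML TextLine inside a single TextRegion, so the ExtractedText code path is exercised with real line-level segment ids. Running just the template part of the helper for a hypothetical two-line input (template strings copied from the test; line breaks in the comment added for readability, the real output is one line) gives roughly:

ns = "http://schema.primaresearch.org/PAGE/gts/pagecontent/2018-07-15"
textline_tmpl = (
    '<TextLine id="l{0}"><TextEquiv><Unicode>{1}'
    "</Unicode></TextEquiv></TextLine>"
)
xml_tmpl = '{0}<TextRegion id="0" xmlns="{1}">{2}</TextRegion>'

text = "first line\nsecond line"
textlines = [textline_tmpl.format(i, line) for i, line in enumerate(text.splitlines())]
print(xml_tmpl.format('<?xml version="1.0"?>', ns, "".join(textlines)))
# <?xml version="1.0"?>
# <TextRegion id="0" xmlns="http://schema.primaresearch.org/PAGE/gts/pagecontent/2018-07-15">
#   <TextLine id="l0"><TextEquiv><Unicode>first line</Unicode></TextEquiv></TextLine>
#   <TextLine id="l1"><TextEquiv><Unicode>second line</Unicode></TextEquiv></TextLine>
# </TextRegion>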