Add tooltips to fca report

pull/47/head
Benjamin Rosemann 5 years ago
parent 53064bf833
commit 750ad00d1b

@ -3,4 +3,8 @@ from .extracted_text import *
from .character_error_rate import * from .character_error_rate import *
from .word_error_rate import * from .word_error_rate import *
from .align import * from .align import *
from .flexible_character_accuracy import flexible_character_accuracy, split_matches from .flexible_character_accuracy import (
flexible_character_accuracy,
split_matches,
Match,
)

@ -14,7 +14,7 @@ from .ocr_files import extract
from .config import Config from .config import Config
def gen_diff_report(gt_in, ocr_in, css_prefix, joiner, none, ops=None): def gen_diff_report(gt_in, ocr_in, css_prefix, joiner, none, matches=None):
gtx = "" gtx = ""
ocrx = "" ocrx = ""
@ -42,7 +42,27 @@ def gen_diff_report(gt_in, ocr_in, css_prefix, joiner, none, ops=None):
else: else:
return "{html_t}".format(html_t=html_t) return "{html_t}".format(html_t=html_t)
if isinstance(gt_in, ExtractedText): ops, ocr_ids = None, None
if matches:
gt_things, ocr_things, ops = split_matches(matches)
# we have to reconstruct the order of the ocr because we mixed it for fca
ocr_lines = [match.ocr for match in matches]
ocr_lines_sorted = sorted(ocr_lines, key=lambda x: x.line + x.start / 10000)
ocr_line_region_id = {}
pos = 0
for ocr_line in ocr_lines_sorted:
if ocr_line.line not in ocr_line_region_id.keys():
ocr_line_region_id[ocr_line.line] = ocr_in.segment_id_for_pos(pos)
pos += ocr_line.length
ocr_ids = {None: None}
pos = 0
for ocr_line in ocr_lines:
for _ in ocr_line.text:
ocr_ids[pos] = ocr_line_region_id[ocr_line.line]
pos += 1
elif isinstance(gt_in, ExtractedText):
if not isinstance(ocr_in, ExtractedText): if not isinstance(ocr_in, ExtractedText):
raise TypeError() raise TypeError()
# XXX splitting should be done in ExtractedText # XXX splitting should be done in ExtractedText
@ -61,10 +81,13 @@ def gen_diff_report(gt_in, ocr_in, css_prefix, joiner, none, ops=None):
if g != o: if g != o:
css_classes = "{css_prefix}diff{k} diff".format(css_prefix=css_prefix, k=k) css_classes = "{css_prefix}diff{k} diff".format(css_prefix=css_prefix, k=k)
if isinstance(gt_in, ExtractedText): if isinstance(gt_in, ExtractedText):
gt_id = gt_in.segment_id_for_pos(g_pos) if g is not None else None
ocr_id = ocr_in.segment_id_for_pos(o_pos) if o is not None else None
# Deletions and inserts only produce one id + None, UI must # Deletions and inserts only produce one id + None, UI must
# support this, i.e. display for the one id produced # support this, i.e. display for the one id produced
gt_id = gt_in.segment_id_for_pos(g_pos) if g else None
if ocr_ids:
ocr_id = ocr_ids[o_pos]
else:
ocr_id = ocr_in.segment_id_for_pos(o_pos) if o else None
gtx += joiner + format_thing(g, css_classes, gt_id) gtx += joiner + format_thing(g, css_classes, gt_id)
ocrx += joiner + format_thing(o, css_classes, ocr_id) ocrx += joiner + format_thing(o, css_classes, ocr_id)
@ -111,15 +134,9 @@ def process(gt, ocr, report_prefix, *, metrics="cer,wer", textequiv_level="regio
gt_words, ocr_words, css_prefix="w", joiner=" ", none="" gt_words, ocr_words, css_prefix="w", joiner=" ", none=""
) )
if "fca" in metrics: if "fca" in metrics:
fca, fca_matches = flexible_character_accuracy(gt_text.text, ocr_text.text) fca, fca_matches = flexible_character_accuracy(gt_text, ocr_text)
fca_gt_segments, fca_ocr_segments, ops = split_matches(fca_matches)
fca_diff_report = gen_diff_report( fca_diff_report = gen_diff_report(
fca_gt_segments, gt_text, ocr_text, css_prefix="c", joiner="", none="·", matches=fca_matches
fca_ocr_segments,
css_prefix="c",
joiner="",
none="·",
ops=ops,
) )
def json_float(value): def json_float(value):

@ -17,7 +17,9 @@ from functools import lru_cache, reduce
from itertools import product, takewhile from itertools import product, takewhile
from typing import List, Tuple, Optional from typing import List, Tuple, Optional
from . import editops from multimethod import multimethod
from . import editops, ExtractedText
if sys.version_info.minor == 5: if sys.version_info.minor == 5:
from .flexible_character_accuracy_ds_35 import ( from .flexible_character_accuracy_ds_35 import (
@ -35,6 +37,22 @@ else:
) )
@multimethod
def flexible_character_accuracy(
gt: ExtractedText, ocr: ExtractedText
) -> Tuple[float, List[Match]]:
"""Calculate the flexible character accuracy.
Reference: contains steps 1-7 of the flexible character accuracy algorithm.
:param gt: The ground truth text.
:param ocr: The text to compare the ground truth with.
:return: Score between 0 and 1 and match objects.
"""
return flexible_character_accuracy(gt.text, ocr.text)
@multimethod
def flexible_character_accuracy(gt: str, ocr: str) -> Tuple[float, List[Match]]: def flexible_character_accuracy(gt: str, ocr: str) -> Tuple[float, List[Match]]:
"""Calculate the flexible character accuracy. """Calculate the flexible character accuracy.
@ -359,7 +377,7 @@ def split_matches(matches: List[Match]) -> Tuple[List[str], List[str], List[List
:param matches: List of match objects. :param matches: List of match objects.
:return: List of ground truth segments, ocr segments and editing operations. :return: List of ground truth segments, ocr segments and editing operations.
""" """
matches = sorted(matches, key=lambda x: x.gt.line + x.gt.start / 10000) matches = sorted(matches, key=lambda m: m.gt.line + m.gt.start / 10000)
line = 0 line = 0
gt, ocr, ops = [], [], [] gt, ocr, ops = [], [], []
for match in matches: for match in matches:
@ -410,4 +428,4 @@ class Part(PartVersionSpecific):
""" """
text = self.text[rel_start:rel_end] text = self.text[rel_start:rel_end]
start = self.start + rel_start start = self.start + rel_start
return Part(text=text, line=self.line, start=start) return Part(**{**self._asdict(), "text": text, "start": start})

@ -10,6 +10,7 @@ DOI: 10.1016/j.patrec.2020.02.003
""" """
import pytest import pytest
from lxml import etree as ET
from ..flexible_character_accuracy import * from ..flexible_character_accuracy import *
@ -101,11 +102,39 @@ def extended_case_to_text(gt, ocr):
@pytest.mark.parametrize(CASE_ARGS, [*SIMPLE_CASES, *COMPLEX_CASES]) @pytest.mark.parametrize(CASE_ARGS, [*SIMPLE_CASES, *COMPLEX_CASES])
def test_flexible_character_accuracy_simple(gt, ocr, first_line_score, all_line_score): def test_flexible_character_accuracy_str(gt, ocr, first_line_score, all_line_score):
score, _ = flexible_character_accuracy(gt, ocr) score, _ = flexible_character_accuracy(gt, ocr)
assert score == pytest.approx(all_line_score) assert score == pytest.approx(all_line_score)
@pytest.mark.parametrize(CASE_ARGS, [*SIMPLE_CASES, *COMPLEX_CASES])
def test_flexible_character_accuracy_xml(gt, ocr, first_line_score, all_line_score):
def get_extracted_text(text: str):
xml = '<?xml version="1.0"?>'
ns = "http://schema.primaresearch.org/PAGE/gts/pagecontent/2018-07-15"
textline_tmpl = (
'<TextLine id="l{0}"><TextEquiv><Unicode>{1}'
"</Unicode></TextEquiv></TextLine>"
)
xml_tmpl = '{0}<TextRegion id="0" xmlns="{1}">{2}</TextRegion>'
textlines = [
textline_tmpl.format(i, line) for i, line in enumerate(text.splitlines())
]
xml_text = xml_tmpl.format(xml, ns, "".join(textlines))
root = ET.fromstring(xml_text)
extracted_text = ExtractedText.from_text_segment(
root, {"page": ns}, textequiv_level="line"
)
return extracted_text
gt_text = get_extracted_text(gt)
ocr_text = get_extracted_text(ocr)
score, _ = flexible_character_accuracy(gt_text, ocr_text)
assert score == pytest.approx(all_line_score)
@pytest.mark.parametrize( @pytest.mark.parametrize(
"config,ocr", "config,ocr",
[ [

Loading…
Cancel
Save