Reintroduce tooltips in report.

pull/60/head
Benjamin Rosemann 4 years ago
parent 12dcdb81da
commit 9f8f88df1f

Binary file not shown.

Before

Width:  |  Height:  |  Size: 265 KiB

After

Width:  |  Height:  |  Size: 115 KiB

@ -2,13 +2,12 @@ import json
import os
from collections import Counter
from functools import partial
from typing import Callable, List, Tuple
from typing import Any, Callable, Dict, List, Tuple
import click
from jinja2 import Environment, FileSystemLoader
from markupsafe import escape
from ocrd_utils import initLogging
from uniseg.graphemecluster import grapheme_clusters
from .align import seq_align
from .config import Config
@ -20,21 +19,28 @@ from .metrics import (
word_accuracy,
)
from .normalize import chars_normalized, words_normalized
from .ocr_files import text
from .ocr_files import extract
def gen_count_report(
gt_text: str, ocr_text: str, split_fun: Callable[[str], Counter]
gt_text: ExtractedText, ocr_text: ExtractedText, split_fun: Callable[[str], Counter]
) -> List[Tuple[str, int, int]]:
gt_counter = Counter(split_fun(gt_text))
ocr_counter = Counter(split_fun(ocr_text))
gt_counter = Counter(split_fun(gt_text.text))
ocr_counter = Counter(split_fun(ocr_text.text))
return [
("".join(key), gt_counter[key], ocr_counter[key])
for key in sorted({*gt_counter.keys(), *ocr_counter.keys()})
]
def gen_diff_report(gt_in, ocr_in, css_prefix, joiner, none) -> Tuple[str, str]:
def gen_diff_report(
gt_in: ExtractedText,
ocr_in: ExtractedText,
css_prefix: str = "c",
joiner: str = "",
none: str = "·",
split_fun=chars_normalized,
) -> Tuple[str, str]:
gtx = ""
ocrx = ""
@ -58,15 +64,8 @@ def gen_diff_report(gt_in, ocr_in, css_prefix, joiner, none) -> Tuple[str, str]:
else:
return f"{html_t}"
if isinstance(gt_in, ExtractedText):
if not isinstance(ocr_in, ExtractedText):
raise TypeError()
# XXX splitting should be done in ExtractedText
gt_things = list(grapheme_clusters(gt_in.text))
ocr_things = list(grapheme_clusters(ocr_in.text))
else:
gt_things = gt_in
ocr_things = ocr_in
gt_things = split_fun(gt_in.text)
ocr_things = split_fun(ocr_in.text)
g_pos = 0
o_pos = 0
@ -76,7 +75,6 @@ def gen_diff_report(gt_in, ocr_in, css_prefix, joiner, none) -> Tuple[str, str]:
ocr_id = None
if g != o:
css_classes = "{css_prefix}diff{k} diff".format(css_prefix=css_prefix, k=k)
if isinstance(gt_in, ExtractedText):
gt_id = gt_in.segment_id_for_pos(g_pos) if g is not None else None
ocr_id = ocr_in.segment_id_for_pos(o_pos) if o is not None else None
# Deletions and inserts only produce one id + None, UI must
@ -93,18 +91,29 @@ def gen_diff_report(gt_in, ocr_in, css_prefix, joiner, none) -> Tuple[str, str]:
return gtx, ocrx
def generate_html_report(gt, ocr, gt_text, ocr_text, report_prefix, metrics_results):
def generate_html_report(
gt: str,
ocr: str,
gt_text: ExtractedText,
ocr_text: ExtractedText,
report_prefix: str,
metrics_results: Dict,
):
metric_dict = {
metric_dict: Dict[str, Callable] = {
"character_accuracy": partial(
gen_diff_report, css_prefix="c", joiner="", none="·"
gen_diff_report,
css_prefix="c",
joiner="",
none="·",
split_fun=chars_normalized,
),
"word_accuracy": lambda gt_text, ocr_text: gen_diff_report(
words_normalized(gt_text),
words_normalized(ocr_text),
"word_accuracy": partial(
gen_diff_report,
css_prefix="w",
joiner=" ",
none="",
split_fun=words_normalized,
),
"bag_of_chars_accuracy": partial(gen_count_report, split_fun=chars_normalized),
"bag_of_words_accuracy": partial(gen_count_report, split_fun=words_normalized),
@ -134,8 +143,8 @@ def generate_html_report(gt, ocr, gt_text, ocr_text, report_prefix, metrics_resu
).dump(out_fn)
def generate_json_report(gt, ocr, report_prefix, metrics_results):
json_dict = {"gt": gt, "ocr": ocr}
def generate_json_report(gt: str, ocr: str, report_prefix: str, metrics_results: Dict):
json_dict: Dict[str, Any] = {"gt": gt, "ocr": ocr}
for result in metrics_results.values():
json_dict[result.metric] = {
key: value for key, value in result.get_dict().items() if key != "metric"
@ -153,8 +162,8 @@ def process(
so we keep this undecorated version and use Click on a wrapper.
"""
gt_text = text(gt, textequiv_level=textequiv_level)
ocr_text = text(ocr, textequiv_level=textequiv_level)
gt_text = extract(gt, textequiv_level=textequiv_level)
ocr_text = extract(ocr, textequiv_level=textequiv_level)
metrics_results = {}
if metrics:
@ -170,7 +179,7 @@ def process(
metric = metric.strip()
if metric not in metric_dict.keys():
raise ValueError(f"Unknown metric '{metric}'.")
result = metric_dict[metric](gt_text, ocr_text)
result = metric_dict[metric](gt_text.text, ocr_text.text)
metrics_results[result.metric] = result
generate_json_report(gt, ocr, report_prefix, metrics_results)

@ -147,11 +147,7 @@ def plain_extract(filename):
# XXX hardcoded SBB normalization
def plain_text(filename):
return plain_extract(filename).text
def extract(filename, *, textequiv_level="region"):
def extract(filename, *, textequiv_level="region") -> ExtractedText:
"""Extract the text from the given file.
Supports PAGE, ALTO and falls back to plain text.

@ -5,8 +5,7 @@ import textwrap
import lxml.etree as ET
from .util import working_directory
from ..ocr_files import alto_namespace, alto_text, page_namespace, page_text, \
plain_text, text
from ..ocr_files import alto_namespace, alto_text, page_namespace, page_text, text
data_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "data")
@ -179,6 +178,6 @@ def test_plain(tmp_path):
with open("ocr.txt", "w") as ocrf:
ocrf.write("AAAAB")
result = plain_text("ocr.txt")
result = text("ocr.txt")
expected = "AAAAB"
assert result == expected

Loading…
Cancel
Save