|
|
|
@ -1,5 +1,8 @@
|
|
|
|
|
import json
|
|
|
|
|
import os
|
|
|
|
|
from collections import Counter
|
|
|
|
|
from functools import partial
|
|
|
|
|
from typing import Callable, List, Tuple
|
|
|
|
|
|
|
|
|
|
import click
|
|
|
|
|
from jinja2 import Environment, FileSystemLoader
|
|
|
|
@ -10,13 +13,28 @@ from uniseg.graphemecluster import grapheme_clusters
|
|
|
|
|
from .align import seq_align
|
|
|
|
|
from .config import Config
|
|
|
|
|
from .extracted_text import ExtractedText
|
|
|
|
|
from .metrics import bag_of_chars_accuracy, bag_of_words_accuracy, character_accuracy, \
|
|
|
|
|
word_accuracy
|
|
|
|
|
from .normalize import words_normalized
|
|
|
|
|
from .metrics import (
|
|
|
|
|
bag_of_chars_accuracy,
|
|
|
|
|
bag_of_words_accuracy,
|
|
|
|
|
character_accuracy,
|
|
|
|
|
word_accuracy,
|
|
|
|
|
)
|
|
|
|
|
from .normalize import chars_normalized, words_normalized
|
|
|
|
|
from .ocr_files import text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def gen_diff_report(gt_in, ocr_in, css_prefix, joiner, none):
|
|
|
|
|
def gen_count_report(
    gt_text: str, ocr_text: str, split_fun: Callable[[str], Counter]
) -> List[Tuple[str, int, int]]:
    """Build a per-token occurrence table for ground truth versus OCR text.

    Both texts are split with ``split_fun`` and the resulting tokens are
    counted. Every token that occurs in at least one of the two texts yields
    one row.

    :param gt_text: ground truth text.
    :param ocr_text: OCR text to compare against the ground truth.
    :param split_fun: callable that turns a text into the tokens to count.
    :return: list of ``(token, count_in_gt, count_in_ocr)`` tuples, ordered
        by token; a token missing from one side is reported with count 0.
    """
    gt_counts = Counter(split_fun(gt_text))
    ocr_counts = Counter(split_fun(ocr_text))

    # Union of the tokens seen on either side, in a stable sorted order.
    all_tokens = sorted(set(gt_counts) | set(ocr_counts))

    report = []
    for token in all_tokens:
        # Counter returns 0 for unknown keys, so one-sided tokens are safe.
        report.append(("".join(token), gt_counts[token], ocr_counts[token]))
    return report
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def gen_diff_report(gt_in, ocr_in, css_prefix, joiner, none) -> Tuple[str, str]:
|
|
|
|
|
gtx = ""
|
|
|
|
|
ocrx = ""
|
|
|
|
|
|
|
|
|
@ -36,7 +54,7 @@ def gen_diff_report(gt_in, ocr_in, css_prefix, joiner, none):
|
|
|
|
|
html_custom_attrs += 'data-toggle="tooltip" title="{}"'.format(id_)
|
|
|
|
|
|
|
|
|
|
if css_classes:
|
|
|
|
|
return f"<span class=\"{css_classes}\" {html_custom_attrs}>{html_t}</span>"
|
|
|
|
|
return f'<span class="{css_classes}" {html_custom_attrs}>{html_t}</span>'
|
|
|
|
|
else:
|
|
|
|
|
return f"{html_t}"
|
|
|
|
|
|
|
|
|
@ -72,26 +90,30 @@ def gen_diff_report(gt_in, ocr_in, css_prefix, joiner, none):
|
|
|
|
|
if o is not None:
|
|
|
|
|
o_pos += len(o)
|
|
|
|
|
|
|
|
|
|
return """
|
|
|
|
|
<div class="row">
|
|
|
|
|
<div class="col-md-6 gt">{}</div>
|
|
|
|
|
<div class="col-md-6 ocr">{}</div>
|
|
|
|
|
</div>
|
|
|
|
|
""".format(
|
|
|
|
|
gtx, ocrx
|
|
|
|
|
)
|
|
|
|
|
return gtx, ocrx
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def generate_html_report(gt, ocr, gt_text, ocr_text, report_prefix, metrics_results):
|
|
|
|
|
char_diff_report = gen_diff_report(
|
|
|
|
|
gt_text, ocr_text, css_prefix="c", joiner="", none="·"
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
gt_words = words_normalized(gt_text)
|
|
|
|
|
ocr_words = words_normalized(ocr_text)
|
|
|
|
|
word_diff_report = gen_diff_report(
|
|
|
|
|
gt_words, ocr_words, css_prefix="w", joiner=" ", none="⋯"
|
|
|
|
|
)
|
|
|
|
|
metric_dict = {
|
|
|
|
|
"character_accuracy": partial(
|
|
|
|
|
gen_diff_report, css_prefix="c", joiner="", none="·"
|
|
|
|
|
),
|
|
|
|
|
"word_accuracy": lambda gt_text, ocr_text: gen_diff_report(
|
|
|
|
|
words_normalized(gt_text),
|
|
|
|
|
words_normalized(ocr_text),
|
|
|
|
|
css_prefix="w",
|
|
|
|
|
joiner=" ",
|
|
|
|
|
none="⋯",
|
|
|
|
|
),
|
|
|
|
|
"bag_of_chars_accuracy": partial(gen_count_report, split_fun=chars_normalized),
|
|
|
|
|
"bag_of_words_accuracy": partial(gen_count_report, split_fun=words_normalized),
|
|
|
|
|
}
|
|
|
|
|
metrics_reports = {}
|
|
|
|
|
for metric in metrics_results.keys():
|
|
|
|
|
if metric not in metric_dict.keys():
|
|
|
|
|
raise ValueError(f"Unknown metric '{metric}'.")
|
|
|
|
|
metrics_reports[metric] = metric_dict[metric](gt_text, ocr_text)
|
|
|
|
|
|
|
|
|
|
env = Environment(
|
|
|
|
|
loader=FileSystemLoader(
|
|
|
|
@ -107,15 +129,14 @@ def generate_html_report(gt, ocr, gt_text, ocr_text, report_prefix, metrics_resu
|
|
|
|
|
template.stream(
|
|
|
|
|
gt=gt,
|
|
|
|
|
ocr=ocr,
|
|
|
|
|
char_diff_report=char_diff_report,
|
|
|
|
|
word_diff_report=word_diff_report,
|
|
|
|
|
metrics_reports=metrics_reports,
|
|
|
|
|
metrics_results=metrics_results,
|
|
|
|
|
).dump(out_fn)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def generate_json_report(gt, ocr, report_prefix, metrics_results):
|
|
|
|
|
json_dict = {"gt": gt, "ocr": ocr}
|
|
|
|
|
for result in metrics_results:
|
|
|
|
|
for result in metrics_results.values():
|
|
|
|
|
json_dict[result.metric] = {
|
|
|
|
|
key: value for key, value in result.get_dict().items() if key != "metric"
|
|
|
|
|
}
|
|
|
|
@ -133,7 +154,7 @@ def process(gt, ocr, report_prefix, *, metrics="cer,wer", textequiv_level="regio
|
|
|
|
|
gt_text = text(gt, textequiv_level=textequiv_level)
|
|
|
|
|
ocr_text = text(ocr, textequiv_level=textequiv_level)
|
|
|
|
|
|
|
|
|
|
metrics_results = set()
|
|
|
|
|
metrics_results = {}
|
|
|
|
|
if metrics:
|
|
|
|
|
metric_dict = {
|
|
|
|
|
"ca": character_accuracy,
|
|
|
|
@ -147,7 +168,8 @@ def process(gt, ocr, report_prefix, *, metrics="cer,wer", textequiv_level="regio
|
|
|
|
|
metric = metric.strip()
|
|
|
|
|
if metric not in metric_dict.keys():
|
|
|
|
|
raise ValueError(f"Unknown metric '{metric}'.")
|
|
|
|
|
metrics_results.add(metric_dict[metric](gt_text, ocr_text))
|
|
|
|
|
result = metric_dict[metric](gt_text, ocr_text)
|
|
|
|
|
metrics_results[result.metric] = result
|
|
|
|
|
|
|
|
|
|
generate_json_report(gt, ocr, report_prefix, metrics_results)
|
|
|
|
|
generate_html_report(gt, ocr, gt_text, ocr_text, report_prefix, metrics_results)
|
|
|
|
|