Updated reports and dependencies.

pull/60/head
Benjamin Rosemann 4 years ago
parent 40f23b8482
commit e8ccffb275

@ -1,5 +1,8 @@
import json
import os
from collections import Counter
from functools import partial
from typing import Callable, Iterable, List, Tuple

import click
from jinja2 import Environment, FileSystemLoader
@ -10,13 +13,28 @@ from uniseg.graphemecluster import grapheme_clusters
from .align import seq_align
from .config import Config
from .extracted_text import ExtractedText
from .metrics import bag_of_chars_accuracy, bag_of_words_accuracy, character_accuracy, \
word_accuracy
from .normalize import words_normalized
from .metrics import (
bag_of_chars_accuracy,
bag_of_words_accuracy,
character_accuracy,
word_accuracy,
)
from .normalize import chars_normalized, words_normalized
from .ocr_files import text
def gen_diff_report(gt_in, ocr_in, css_prefix, joiner, none):
def gen_count_report(
    gt_text: str, ocr_text: str, split_fun: Callable[[str], Iterable[str]]
) -> List[Tuple[str, int, int]]:
    """Count the elements of both texts and pair up the counts per element.

    :param gt_text: Reference (ground truth) text.
    :param ocr_text: Compared (OCR) text.
    :param split_fun: Function splitting a text into the elements to count.
        Its result is passed to ``Counter``, so any iterable of hashable,
        sortable elements works.
    :return: One ``(element, count in gt_text, count in ocr_text)`` tuple for
        every element occurring in at least one of the texts, sorted by element.
    """
    gt_counter = Counter(split_fun(gt_text))
    ocr_counter = Counter(split_fun(ocr_text))
    # A Counter yields 0 for missing keys, so elements present in only one of
    # the texts still produce a complete row.
    # "".join(key) leaves plain string keys unchanged.
    return [
        ("".join(key), gt_counter[key], ocr_counter[key])
        for key in sorted(gt_counter.keys() | ocr_counter.keys())
    ]
def gen_diff_report(gt_in, ocr_in, css_prefix, joiner, none) -> Tuple[str, str]:
gtx = ""
ocrx = ""
@ -36,7 +54,7 @@ def gen_diff_report(gt_in, ocr_in, css_prefix, joiner, none):
html_custom_attrs += 'data-toggle="tooltip" title="{}"'.format(id_)
if css_classes:
return f"<span class=\"{css_classes}\" {html_custom_attrs}>{html_t}</span>"
return f'<span class="{css_classes}" {html_custom_attrs}>{html_t}</span>'
else:
return f"{html_t}"
@ -72,26 +90,30 @@ def gen_diff_report(gt_in, ocr_in, css_prefix, joiner, none):
if o is not None:
o_pos += len(o)
return """
<div class="row">
<div class="col-md-6 gt">{}</div>
<div class="col-md-6 ocr">{}</div>
</div>
""".format(
gtx, ocrx
)
return gtx, ocrx
def generate_html_report(gt, ocr, gt_text, ocr_text, report_prefix, metrics_results):
char_diff_report = gen_diff_report(
gt_text, ocr_text, css_prefix="c", joiner="", none="·"
)
gt_words = words_normalized(gt_text)
ocr_words = words_normalized(ocr_text)
word_diff_report = gen_diff_report(
gt_words, ocr_words, css_prefix="w", joiner=" ", none=""
)
metric_dict = {
"character_accuracy": partial(
gen_diff_report, css_prefix="c", joiner="", none="·"
),
"word_accuracy": lambda gt_text, ocr_text: gen_diff_report(
words_normalized(gt_text),
words_normalized(ocr_text),
css_prefix="w",
joiner=" ",
none="",
),
"bag_of_chars_accuracy": partial(gen_count_report, split_fun=chars_normalized),
"bag_of_words_accuracy": partial(gen_count_report, split_fun=words_normalized),
}
metrics_reports = {}
for metric in metrics_results.keys():
if metric not in metric_dict.keys():
raise ValueError(f"Unknown metric '{metric}'.")
metrics_reports[metric] = metric_dict[metric](gt_text, ocr_text)
env = Environment(
loader=FileSystemLoader(
@ -107,15 +129,14 @@ def generate_html_report(gt, ocr, gt_text, ocr_text, report_prefix, metrics_resu
template.stream(
gt=gt,
ocr=ocr,
char_diff_report=char_diff_report,
word_diff_report=word_diff_report,
metrics_reports=metrics_reports,
metrics_results=metrics_results,
).dump(out_fn)
def generate_json_report(gt, ocr, report_prefix, metrics_results):
json_dict = {"gt": gt, "ocr": ocr}
for result in metrics_results:
for result in metrics_results.values():
json_dict[result.metric] = {
key: value for key, value in result.get_dict().items() if key != "metric"
}
@ -133,7 +154,7 @@ def process(gt, ocr, report_prefix, *, metrics="cer,wer", textequiv_level="regio
gt_text = text(gt, textequiv_level=textequiv_level)
ocr_text = text(ocr, textequiv_level=textequiv_level)
metrics_results = set()
metrics_results = {}
if metrics:
metric_dict = {
"ca": character_accuracy,
@ -147,7 +168,8 @@ def process(gt, ocr, report_prefix, *, metrics="cer,wer", textequiv_level="regio
metric = metric.strip()
if metric not in metric_dict.keys():
raise ValueError(f"Unknown metric '{metric}'.")
metrics_results.add(metric_dict[metric](gt_text, ocr_text))
result = metric_dict[metric](gt_text, ocr_text)
metrics_results[result.metric] = result
generate_json_report(gt, ocr, report_prefix, metrics_results)
generate_html_report(gt, ocr, gt_text, ocr_text, report_prefix, metrics_results)

@ -1,8 +1,6 @@
from collections import Counter
from unicodedata import normalize
from uniseg.graphemecluster import grapheme_clusters
from ..normalize import chars_normalized
from .utils import bag_accuracy, MetricResult, Weights
@ -19,8 +17,8 @@ def bag_of_chars_accuracy(
:param weights: Weights/costs for editing operations.
:return: Class representing the results of this metric.
"""
reference_chars: Counter = Counter(grapheme_clusters(normalize("NFC", reference)))
compared_chars: Counter = Counter(grapheme_clusters(normalize("NFC", compared)))
reference_chars: Counter = Counter(chars_normalized(reference))
compared_chars: Counter = Counter(chars_normalized(compared))
return bag_accuracy(
reference_chars, compared_chars, weights, bag_of_chars_accuracy.__name__
)

@ -4,18 +4,18 @@
<meta charset="utf-8">
<meta name="viewport" content="width=device-width, initial-scale=1, shrink-to-fit=no">
<link rel="stylesheet" href="https://stackpath.bootstrapcdn.com/bootstrap/4.3.1/css/bootstrap.min.css" integrity="sha384-ggOyR0iXCbMQv3Xipma34MD+dH/1fQ784/j6cY/iJTQUOhcWr7x9JvoRxT2MZw1T" crossorigin="anonymous">
<link href="https://cdn.jsdelivr.net/npm/bootstrap@5.0.1/dist/css/bootstrap.min.css" rel="stylesheet" integrity="sha384-+0n0xVW2eSR5OomGNYDnhzAbDsOXxcvSN1TPprVMTNDbiYZCxYbOOl7+AMvyTG2x" crossorigin="anonymous">
<style type="text/css">
{% if metrics_results %}
.gt .diff {
color: green;
color: #198754;
}
.ocr .diff {
color: red;
color: #dc3545;
}
{% else %}
.gt .diff, .ocr .diff {
color: blue;
color: #0d6efd;
}
{% endif %}
.ellipsis {
@ -30,20 +30,20 @@
</head>
<body>
<div class="container">
{{ gt }}<br>
{{ ocr }}
<dl class="row bg-secondary text-white">
<dt class="col-sm-3">Reference</dt>
<dd class="col-sm-9">{{ gt }}</dd>
<dt class="col-sm-3">Compared</dt>
<dd class="col-sm-9">{{ ocr }}</dd>
</dl>
{% if metrics_results %}
<h2>Metrics</h2>
<table class="table">
<table class="table table-hover table-sm">
<caption>Legend: Acc = Accuracy, ER = Error Rate, Ref = Reference, Cmp = Compared</caption>
<thead>
<tr>
<th scope="col" rowspan="2">#</th>
<th scope="col" rowspan="2"><p class="fs-1">Metrics</p></th>
<th scope="col" colspan="2">Results</th>
<th scope="col" colspan="2">Elements</th>
<th scope="col" colspan="2">Calculation</th>
@ -58,7 +58,7 @@
</tr>
</thead>
<tbody>
{% for result in metrics_results %}
{% for result in metrics_results.values() %}
<tr>
<th scope="row">{{ result.metric.replace("_", " ") }}</th>
<td>{{ result.accuracy }}</td>
@ -70,33 +70,54 @@
</tr>
{% endfor %}
</tbody>
<tfoot>
</table>
{% endif %}
<div class="row">
{% for heading, metric in (
("Character differences", "character_accuracy"),
("Word differences", "word_accuracy"),
("Bag of Chars", "bag_of_chars_accuracy"),
("Bag of Words", "bag_of_words_accuracy")) %}
{% if metric in metrics_reports.keys() %}
{% if metric in ("bag_of_chars_accuracy", "bag_of_words_accuracy") %}
<div class="col-6">
<h2>{{ heading }}</h2>
<table class="table table-striped table-hover table-sm">
<thead>
<tr>
<th scope="col">#</th>
<th scope="col">Ref</th>
<th scope="col">Cmp</th>
</tr>
</thead>
<tbody>
{% for (text, count_gt, count_ocr) in metrics_reports[metric] %}
<tr>
<td colspan="7">Legend: Acc. = Accuracy, ER = Error Rate, Ref. = Reference, Cmp. = Compared</td>
<td>{{ text }}</td>
<td>{{ count_gt }}</td>
<td>{{ count_ocr }}</td>
</tr>
</tfoot>
{% endfor %}
</tbody>
</table>
</div>
{% else %}
<h2 class="col-12">{{ heading }}</h2>
<div class="col-6 gt border rounded-start">{{ metrics_reports[metric][0] }}</div>
<div class="col-6 ocr border rounded-end">{{ metrics_reports[metric][1] }}</div>
{% endif %}
<h2>Character differences</h2>
{{ char_diff_report }}
<h2>Word differences</h2>
{{ word_diff_report }}
{% endif %}
{% endfor %}
</div>
</div>
<script src="https://code.jquery.com/jquery-3.3.1.slim.min.js" integrity="sha384-q8i/X+965DzO0rT7abK41JStQIAqVgRVzpbzo5smXKp4YfRvH+8abtTE1Pi6jizo" crossorigin="anonymous"></script>
<script src="https://cdnjs.cloudflare.com/ajax/libs/popper.js/1.14.7/umd/popper.min.js" integrity="sha384-UO2eT0CpHqdSJQ6hJty5KVphtPhzWj9WO1clHTMGa3JDZwrnQq4sF86dIHNDz0W1" crossorigin="anonymous"></script>
<script src="https://stackpath.bootstrapcdn.com/bootstrap/4.3.1/js/bootstrap.min.js" integrity="sha384-JjSmVgyd0p3pXB1rRibZUAYoIIy6OrQ6VrjIEaFf/nJGzIxFDsf4x0xIM+B07jRM" crossorigin="anonymous"></script>
<script src="https://cdn.jsdelivr.net/npm/jquery@3.6.0/dist/jquery.min.js" integrity="sha256-/xUj+3OJU5yExlq6GSYGSHk7tPXikynS7ogEvDej/m4=" crossorigin="anonymous"></script>
<script src="https://cdn.jsdelivr.net/npm/bootstrap@5.0.1/dist/js/bootstrap.bundle.min.js" integrity="sha384-gtEjrD/SeCtmISkJkNUaaKMoLD0//ElJ19smozuHV6z3Iehds+3Ulb9Bn9Plx0x4" crossorigin="anonymous"></script>
<script>
{% include 'report.html.js' %}
</script>
</body>
</html>

Loading…
Cancel
Save