Updated reports and dependencies: collect metric results in a dict keyed by metric name, add bag-of-chars/bag-of-words count reports via a new gen_count_report() helper, render all per-metric reports from the template, and bump Bootstrap from 4.3.1 to 5.0.1 (jQuery 3.6.0, bundled Popper).

pull/60/head
Benjamin Rosemann 4 years ago
parent 40f23b8482
commit e8ccffb275

@@ -1,5 +1,8 @@
 import json
 import os
+from collections import Counter
+from functools import partial
+from typing import Callable, List, Tuple
 
 import click
 from jinja2 import Environment, FileSystemLoader
@@ -10,13 +13,28 @@ from uniseg.graphemecluster import grapheme_clusters
 from .align import seq_align
 from .config import Config
 from .extracted_text import ExtractedText
-from .metrics import bag_of_chars_accuracy, bag_of_words_accuracy, character_accuracy, \
-    word_accuracy
-from .normalize import words_normalized
+from .metrics import (
+    bag_of_chars_accuracy,
+    bag_of_words_accuracy,
+    character_accuracy,
+    word_accuracy,
+)
+from .normalize import chars_normalized, words_normalized
 from .ocr_files import text
 
 
-def gen_diff_report(gt_in, ocr_in, css_prefix, joiner, none):
+def gen_count_report(
+    gt_text: str, ocr_text: str, split_fun: Callable[[str], Counter]
+) -> List[Tuple[str, int, int]]:
+    gt_counter = Counter(split_fun(gt_text))
+    ocr_counter = Counter(split_fun(ocr_text))
+    return [
+        ("".join(key), gt_counter[key], ocr_counter[key])
+        for key in sorted({*gt_counter.keys(), *ocr_counter.keys()})
+    ]
+
+
+def gen_diff_report(gt_in, ocr_in, css_prefix, joiner, none) -> Tuple[str, str]:
     gtx = ""
     ocrx = ""
@@ -36,7 +54,7 @@ def gen_diff_report(gt_in, ocr_in, css_prefix, joiner, none):
             html_custom_attrs += 'data-toggle="tooltip" title="{}"'.format(id_)
 
         if css_classes:
-            return f"<span class=\"{css_classes}\" {html_custom_attrs}>{html_t}</span>"
+            return f'<span class="{css_classes}" {html_custom_attrs}>{html_t}</span>'
         else:
             return f"{html_t}"
@@ -72,26 +90,30 @@ def gen_diff_report(gt_in, ocr_in, css_prefix, joiner, none):
         if o is not None:
             o_pos += len(o)
 
-    return """
-        <div class="row">
-           <div class="col-md-6 gt">{}</div>
-           <div class="col-md-6 ocr">{}</div>
-        </div>
-        """.format(
-        gtx, ocrx
-    )
+    return gtx, ocrx
 
 
 def generate_html_report(gt, ocr, gt_text, ocr_text, report_prefix, metrics_results):
-    char_diff_report = gen_diff_report(
-        gt_text, ocr_text, css_prefix="c", joiner="", none="·"
-    )
-
-    gt_words = words_normalized(gt_text)
-    ocr_words = words_normalized(ocr_text)
-    word_diff_report = gen_diff_report(
-        gt_words, ocr_words, css_prefix="w", joiner=" ", none=""
-    )
+    metric_dict = {
+        "character_accuracy": partial(
+            gen_diff_report, css_prefix="c", joiner="", none="·"
+        ),
+        "word_accuracy": lambda gt_text, ocr_text: gen_diff_report(
+            words_normalized(gt_text),
+            words_normalized(ocr_text),
+            css_prefix="w",
+            joiner=" ",
+            none="",
+        ),
+        "bag_of_chars_accuracy": partial(gen_count_report, split_fun=chars_normalized),
+        "bag_of_words_accuracy": partial(gen_count_report, split_fun=words_normalized),
+    }
+    metrics_reports = {}
+    for metric in metrics_results.keys():
+        if metric not in metric_dict.keys():
+            raise ValueError(f"Unknown metric '{metric}'.")
+        metrics_reports[metric] = metric_dict[metric](gt_text, ocr_text)
 
     env = Environment(
         loader=FileSystemLoader(
@@ -107,15 +129,14 @@ def generate_html_report(gt, ocr, gt_text, ocr_text, report_prefix, metrics_resu
     template.stream(
         gt=gt,
         ocr=ocr,
-        char_diff_report=char_diff_report,
-        word_diff_report=word_diff_report,
+        metrics_reports=metrics_reports,
         metrics_results=metrics_results,
     ).dump(out_fn)
 
 
 def generate_json_report(gt, ocr, report_prefix, metrics_results):
     json_dict = {"gt": gt, "ocr": ocr}
-    for result in metrics_results:
+    for result in metrics_results.values():
         json_dict[result.metric] = {
             key: value for key, value in result.get_dict().items() if key != "metric"
         }
@@ -133,7 +154,7 @@ def process(gt, ocr, report_prefix, *, metrics="cer,wer", textequiv_level="regio
     gt_text = text(gt, textequiv_level=textequiv_level)
     ocr_text = text(ocr, textequiv_level=textequiv_level)
 
-    metrics_results = set()
+    metrics_results = {}
     if metrics:
         metric_dict = {
             "ca": character_accuracy,
@@ -147,7 +168,8 @@ def process(gt, ocr, report_prefix, *, metrics="cer,wer", textequiv_level="regio
             metric = metric.strip()
             if metric not in metric_dict.keys():
                 raise ValueError(f"Unknown metric '{metric}'.")
-            metrics_results.add(metric_dict[metric](gt_text, ocr_text))
+            result = metric_dict[metric](gt_text, ocr_text)
+            metrics_results[result.metric] = result
 
     generate_json_report(gt, ocr, report_prefix, metrics_results)
     generate_html_report(gt, ocr, gt_text, ocr_text, report_prefix, metrics_results)
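Aside: a minimal, self-contained sketch of the gen_count_report() contract introduced above, which returns (element, reference count, compared count) triples for the report tables. str.split is only a stand-in for the project's words_normalized helper, and the sample strings are made up.

    # Illustrative sketch only: mirrors the new gen_count_report() contract.
    # str.split is a stand-in for the project's words_normalized helper.
    from collections import Counter
    from typing import Callable, Iterable, List, Tuple


    def gen_count_report(
        gt_text: str, ocr_text: str, split_fun: Callable[[str], Iterable[str]]
    ) -> List[Tuple[str, int, int]]:
        gt_counter = Counter(split_fun(gt_text))
        ocr_counter = Counter(split_fun(ocr_text))
        return [
            ("".join(key), gt_counter[key], ocr_counter[key])
            for key in sorted({*gt_counter.keys(), *ocr_counter.keys()})
        ]


    print(gen_count_report("the quick fox", "the quiet fox", split_fun=str.split))
    # -> [('fox', 1, 1), ('quick', 1, 0), ('quiet', 0, 1), ('the', 1, 1)]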

@@ -1,8 +1,6 @@
 from collections import Counter
-from unicodedata import normalize
-
-from uniseg.graphemecluster import grapheme_clusters
 
+from ..normalize import chars_normalized
 from .utils import bag_accuracy, MetricResult, Weights
@@ -19,8 +17,8 @@ def bag_of_chars_accuracy(
     :param weights: Weights/costs for editing operations.
     :return: Class representing the results of this metric.
     """
-    reference_chars: Counter = Counter(grapheme_clusters(normalize("NFC", reference)))
-    compared_chars: Counter = Counter(grapheme_clusters(normalize("NFC", compared)))
+    reference_chars: Counter = Counter(chars_normalized(reference))
+    compared_chars: Counter = Counter(chars_normalized(compared))
     return bag_accuracy(
         reference_chars, compared_chars, weights, bag_of_chars_accuracy.__name__
     )
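Aside: the change above swaps the inline normalization for chars_normalized(); the refactor is behavior-preserving only if that helper matches the replaced expression. A small sketch of the assumed equivalence (sample string made up):

    # Sketch of the assumed equivalence behind this refactor (not from the commit):
    # chars_normalized(s) is expected to behave like the replaced expression below.
    from collections import Counter
    from unicodedata import normalize

    from uniseg.graphemecluster import grapheme_clusters

    sample = "Straße"  # made-up sample text
    old_style = Counter(grapheme_clusters(normalize("NFC", sample)))
    print(old_style)
    # Counter({'S': 1, 't': 1, 'r': 1, 'a': 1, 'ß': 1, 'e': 1})
    # If chars_normalized(sample) yields the same grapheme clusters,
    # bag_of_chars_accuracy() receives identical inputs before and after the change.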

@@ -4,18 +4,18 @@
     <meta charset="utf-8">
     <meta name="viewport" content="width=device-width, initial-scale=1, shrink-to-fit=no">
-    <link rel="stylesheet" href="https://stackpath.bootstrapcdn.com/bootstrap/4.3.1/css/bootstrap.min.css" integrity="sha384-ggOyR0iXCbMQv3Xipma34MD+dH/1fQ784/j6cY/iJTQUOhcWr7x9JvoRxT2MZw1T" crossorigin="anonymous">
+    <link href="https://cdn.jsdelivr.net/npm/bootstrap@5.0.1/dist/css/bootstrap.min.css" rel="stylesheet" integrity="sha384-+0n0xVW2eSR5OomGNYDnhzAbDsOXxcvSN1TPprVMTNDbiYZCxYbOOl7+AMvyTG2x" crossorigin="anonymous">
 
     <style type="text/css">
         {% if metrics_results %}
         .gt .diff {
-            color: green;
+            color: #198754;
         }
         .ocr .diff {
-            color: red;
+            color: #dc3545;
         }
         {% else %}
         .gt .diff, .ocr .diff {
-            color: blue;
+            color: #0d6efd;
         }
         {% endif %}
         .ellipsis {
@@ -30,73 +30,94 @@
 </head>
 <body>
 <div class="container">
-    {{ gt }}<br>
-    {{ ocr }}
+    <dl class="row bg-secondary text-white">
+        <dt class="col-sm-3">Reference</dt>
+        <dd class="col-sm-9">{{ gt }}</dd>
+        <dt class="col-sm-3">Compared</dt>
+        <dd class="col-sm-9">{{ ocr }}</dd>
+    </dl>
 
     {% if metrics_results %}
-    <h2>Metrics</h2>
-    <table class="table">
+    <table class="table table-hover table-sm">
+        <caption>Legend: Acc = Accuracy, ER = Error Rate, Ref = Reference, Cmp = Compared</caption>
         <thead>
         <tr>
-            <th scope="col" rowspan="2">#</th>
+            <th scope="col" rowspan="2"><p class="fs-1">Metrics</p></th>
             <th scope="col" colspan="2">Results</th>
             <th scope="col" colspan="2">Elements</th>
             <th scope="col" colspan="2">Calculation</th>
         </tr>
         <tr>
             <th scope="col">Acc</th>
             <th scope="col">ER</th>
             <th scope="col">Ref</th>
             <th scope="col">Cmp</th>
             <th scope="col">Errors</th>
             <th scope="col">Weights</th>
         </tr>
         </thead>
         <tbody>
-        {% for result in metrics_results %}
+        {% for result in metrics_results.values() %}
         <tr>
             <th scope="row">{{ result.metric.replace("_", " ") }}</th>
             <td>{{ result.accuracy }}</td>
             <td>{{ result.error_rate }}</td>
             <td>{{ result.reference_elements }}</td>
             <td>{{ result.compared_elements }}</td>
             <td>{{ result.weighted_errors }}</td>
             <td>d={{ result.weights.deletes }},i={{ result.weights.inserts }},r={{ result.weights.replacements }}</td>
         </tr>
         {% endfor %}
         </tbody>
-        <tfoot>
-        <tr>
-            <td colspan="7">Legend: Acc. = Accuracy, ER = Error Rate, Ref. = Reference, Cmp. = Compared</td>
-        </tr>
-        </tfoot>
-    </table>
+    </table>
     {% endif %}
 
-    <h2>Character differences</h2>
-    {{ char_diff_report }}
-
-    <h2>Word differences</h2>
-    {{ word_diff_report }}
+    <div class="row">
+    {% for heading, metric in (
+        ("Character differences", "character_accuracy"),
+        ("Word differences", "word_accuracy"),
+        ("Bag of Chars", "bag_of_chars_accuracy"),
+        ("Bag of Words", "bag_of_words_accuracy")) %}
+        {% if metric in metrics_reports.keys() %}
+            {% if metric in ("bag_of_chars_accuracy", "bag_of_words_accuracy") %}
+            <div class="col-6">
+                <h2>{{ heading }}</h2>
+                <table class="table table-striped table-hover table-sm">
+                    <thead>
+                    <tr>
+                        <th scope="col">#</th>
+                        <th scope="col">Ref</th>
+                        <th scope="col">Cmp</th>
+                    </tr>
+                    </thead>
+                    <tbody>
+                    {% for (text, count_gt, count_ocr) in metrics_reports[metric] %}
+                    <tr>
+                        <td>{{ text }}</td>
+                        <td>{{ count_gt }}</td>
+                        <td>{{ count_ocr }}</td>
+                    </tr>
+                    {% endfor %}
+                    </tbody>
+                </table>
+            </div>
+            {% else %}
+            <h2 class="col-12">{{ heading }}</h2>
+            <div class="col-6 gt border rounded-start">{{ metrics_reports[metric][0] }}</div>
+            <div class="col-6 ocr border rounded-end">{{ metrics_reports[metric][1] }}</div>
+            {% endif %}
+        {% endif %}
+    {% endfor %}
+    </div>
 </div>
-<script src="https://code.jquery.com/jquery-3.3.1.slim.min.js" integrity="sha384-q8i/X+965DzO0rT7abK41JStQIAqVgRVzpbzo5smXKp4YfRvH+8abtTE1Pi6jizo" crossorigin="anonymous"></script>
-<script src="https://cdnjs.cloudflare.com/ajax/libs/popper.js/1.14.7/umd/popper.min.js" integrity="sha384-UO2eT0CpHqdSJQ6hJty5KVphtPhzWj9WO1clHTMGa3JDZwrnQq4sF86dIHNDz0W1" crossorigin="anonymous"></script>
-<script src="https://stackpath.bootstrapcdn.com/bootstrap/4.3.1/js/bootstrap.min.js" integrity="sha384-JjSmVgyd0p3pXB1rRibZUAYoIIy6OrQ6VrjIEaFf/nJGzIxFDsf4x0xIM+B07jRM" crossorigin="anonymous"></script>
+<script src="https://cdn.jsdelivr.net/npm/jquery@3.6.0/dist/jquery.min.js" integrity="sha256-/xUj+3OJU5yExlq6GSYGSHk7tPXikynS7ogEvDej/m4=" crossorigin="anonymous"></script>
+<script src="https://cdn.jsdelivr.net/npm/bootstrap@5.0.1/dist/js/bootstrap.bundle.min.js" integrity="sha384-gtEjrD/SeCtmISkJkNUaaKMoLD0//ElJ19smozuHV6z3Iehds+3Ulb9Bn9Plx0x4" crossorigin="anonymous"></script>
 <script>
     {% include 'report.html.js' %}
 </script>
 </body>
 </html>
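Aside: after this change the template expects metrics_reports, where character_accuracy and word_accuracy map to a (gt_html, ocr_html) pair and the bag metrics map to lists of (text, gt_count, ocr_count) triples. A reduced, self-contained rendering sketch; the inline template and sample data are simplified stand-ins for report.html.j2:

    # Reduced stand-in for report.html.j2 (illustrative only): renders the same
    # metrics_reports shapes the real template expects; sample data is made up.
    from jinja2 import Template

    template = Template(
        "{% for heading, metric in sections %}"
        "{% if metric in metrics_reports %}{{ heading }}\n"
        "{% if metric == 'bag_of_words_accuracy' %}"
        "{% for text, count_gt, count_ocr in metrics_reports[metric] %}"
        "  {{ text }}: {{ count_gt }} vs {{ count_ocr }}\n"
        "{% endfor %}"
        "{% else %}"
        "  GT:  {{ metrics_reports[metric][0] }}\n"
        "  OCR: {{ metrics_reports[metric][1] }}\n"
        "{% endif %}{% endif %}{% endfor %}"
    )

    print(
        template.render(
            sections=[
                ("Character differences", "character_accuracy"),
                ("Bag of Words", "bag_of_words_accuracy"),
            ],
            metrics_reports={
                "character_accuracy": ("<span>foo</span>", "<span>fo0</span>"),
                "bag_of_words_accuracy": [("fo0", 0, 1), ("foo", 1, 0)],
            },
        )
    )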
