🎨 Reformat using Black

pull/90/head
Mike Gerber 1 year ago
parent d50d624554
commit bea56117ae

@ -70,7 +70,7 @@ def gen_diff_report(gt_in, ocr_in, css_prefix, joiner, none, differences=False):
# support this, i.e. display for the one id produced
if differences:
found_differences.append(f'{g} :: {o}')
found_differences.append(f"{g} :: {o}")
gtx += joiner + format_thing(g, css_classes, gt_id)
ocrx += joiner + format_thing(o, css_classes, ocr_id)
@ -82,14 +82,17 @@ def gen_diff_report(gt_in, ocr_in, css_prefix, joiner, none, differences=False):
found_differences = dict(Counter(elem for elem in found_differences))
return """
return (
"""
<div class="row">
<div class="col-md-6 gt">{}</div>
<div class="col-md-6 ocr">{}</div>
</div>
""".format(
gtx, ocrx
), found_differences
),
found_differences,
)
def json_float(value):
@ -105,8 +108,16 @@ def json_float(value):
return str(value)
def process(gt, ocr, report_prefix, reports_folder='.', *, metrics=True,
differences=False, textequiv_level="region"):
def process(
gt,
ocr,
report_prefix,
reports_folder=".",
*,
metrics=True,
differences=False,
textequiv_level="region",
):
"""Check OCR result against GT.
The @click decorators change the signature of the decorated functions, so we keep this undecorated version and use
@ -119,15 +130,19 @@ def process(gt, ocr, report_prefix, reports_folder='.', *, metrics=True,
cer, n_characters = character_error_rate_n(gt_text, ocr_text)
wer, n_words = word_error_rate_n(gt_text, ocr_text)
char_diff_report, diff_c = gen_diff_report(gt_text, ocr_text, css_prefix="c",
joiner="",
none="·", differences=differences)
char_diff_report, diff_c = gen_diff_report(
gt_text, ocr_text, css_prefix="c", joiner="", none="·", differences=differences
)
gt_words = words_normalized(gt_text)
ocr_words = words_normalized(ocr_text)
word_diff_report, diff_w = gen_diff_report(
gt_words, ocr_words, css_prefix="w", joiner=" ", none="",
differences=differences
gt_words,
ocr_words,
css_prefix="w",
joiner=" ",
none="",
differences=differences,
)
env = Environment(
@ -162,19 +177,23 @@ def process(gt, ocr, report_prefix, reports_folder='.', *, metrics=True,
).dump(out_fn)
def process_dir(gt, ocr, report_prefix, reports_folder, metrics, differences,
textequiv_level):
def process_dir(
gt, ocr, report_prefix, reports_folder, metrics, differences, textequiv_level
):
for gt_file in os.listdir(gt):
gt_file_path = os.path.join(gt, gt_file)
ocr_file_path = os.path.join(ocr, gt_file)
if os.path.isfile(gt_file_path) and os.path.isfile(ocr_file_path):
process(gt_file_path, ocr_file_path,
process(
gt_file_path,
ocr_file_path,
f"{gt_file}-{report_prefix}",
reports_folder=reports_folder,
metrics=metrics,
differences=differences,
textequiv_level=textequiv_level)
textequiv_level=textequiv_level,
)
else:
print("Skipping {0} and {1}".format(gt_file_path, ocr_file_path))
@ -190,7 +209,7 @@ def process_dir(gt, ocr, report_prefix, reports_folder, metrics, differences,
@click.option(
"--differences",
default=False,
help="Enable reporting character and word level differences"
help="Enable reporting character and word level differences",
)
@click.option(
"--textequiv-level",
@ -199,8 +218,16 @@ def process_dir(gt, ocr, report_prefix, reports_folder, metrics, differences,
metavar="LEVEL",
)
@click.option("--progress", default=False, is_flag=True, help="Show progress bar")
def main(gt, ocr, report_prefix, reports_folder, metrics, differences, textequiv_level,
progress):
def main(
gt,
ocr,
report_prefix,
reports_folder,
metrics,
differences,
textequiv_level,
progress,
):
"""
Compare the PAGE/ALTO/text document GT against the document OCR.
@ -228,11 +255,25 @@ def main(gt, ocr, report_prefix, reports_folder, metrics, differences, textequiv
"OCR must be a directory if GT is a directory", param_hint="ocr"
)
else:
process_dir(gt, ocr, report_prefix, reports_folder, metrics,
differences, textequiv_level)
process_dir(
gt,
ocr,
report_prefix,
reports_folder,
metrics,
differences,
textequiv_level,
)
else:
process(gt, ocr, report_prefix, reports_folder, metrics=metrics,
differences=differences, textequiv_level=textequiv_level)
process(
gt,
ocr,
report_prefix,
reports_folder,
metrics=metrics,
differences=differences,
textequiv_level=textequiv_level,
)
if __name__ == "__main__":

@ -1,4 +1,3 @@
import click
from ocrd_utils import initLogging

@ -23,7 +23,8 @@ def process(reports_folder, occurrences_threshold=1):
if "cer" not in report_data or "wer" not in report_data:
click.echo(
f"Skipping {report} because it does not contain CER and WER")
f"Skipping {report} because it does not contain CER and WER"
)
continue
cer = report_data["cer"]
@ -60,7 +61,7 @@ def process(reports_folder, occurrences_threshold=1):
for report_suffix in (".html", ".json"):
template_fn = "summary" + report_suffix + ".j2"
out_fn = os.path.join(reports_folder, 'summary' + report_suffix)
out_fn = os.path.join(reports_folder, "summary" + report_suffix)
template = env.get_template(template_fn)
template.stream(
num_reports=len(cer_list),
@ -73,14 +74,13 @@ def process(reports_folder, occurrences_threshold=1):
@click.command()
@click.argument("reports_folder",
type=click.Path(exists=True),
default="./reports"
)
@click.option("--occurrences-threshold",
@click.argument("reports_folder", type=click.Path(exists=True), default="./reports")
@click.option(
"--occurrences-threshold",
type=int,
default=1,
help="Only show differences that occur at least this many times.")
help="Only show differences that occur at least this many times.",
)
def main(reports_folder, occurrences_threshold):
"""
Summarize the results from multiple reports generated earlier by dinglehopper.

@ -16,10 +16,15 @@ def test_cli_directory(tmp_path):
"""
initLogging()
process_dir(os.path.join(data_dir, "directory-test", "gt"),
process_dir(
os.path.join(data_dir, "directory-test", "gt"),
os.path.join(data_dir, "directory-test", "ocr"),
"report", str(tmp_path / "reports"), False, True,
"line")
"report",
str(tmp_path / "reports"),
False,
True,
"line",
)
assert os.path.exists(tmp_path / "reports/1.xml-report.json")
assert os.path.exists(tmp_path / "reports/1.xml-report.html")
@ -35,9 +40,14 @@ def test_cli_fail_without_gt(tmp_path):
"""
initLogging()
process_dir(os.path.join(data_dir, "directory-test", "gt"),
process_dir(
os.path.join(data_dir, "directory-test", "gt"),
os.path.join(data_dir, "directory-test", "ocr"),
"report", str(tmp_path / "reports"), False, True,
"line")
"report",
str(tmp_path / "reports"),
False,
True,
"line",
)
assert len(os.listdir(tmp_path / "reports")) == 2 * 2

@ -15,15 +15,23 @@ def test_cli_differences(tmp_path):
the differences found between the GT and OCR text"""
initLogging()
process(os.path.join(data_dir, "test-gt.page2018.xml"),
process(
os.path.join(data_dir, "test-gt.page2018.xml"),
os.path.join(data_dir, "test-fake-ocr.page2018.xml"),
"report", tmp_path, differences=True)
"report",
tmp_path,
differences=True,
)
assert os.path.exists(tmp_path / "report.json")
with open(tmp_path / "report.json", "r") as jsonf:
j = json.load(jsonf)
assert j["differences"] == {"character_level": {'n :: m': 1, 'ſ :: f': 1},
"word_level": {'Augenblick :: Augemblick': 1,
'Verſprochene :: Verfprochene': 1}}
assert j["differences"] == {
"character_level": {"n :: m": 1, "ſ :: f": 1},
"word_level": {
"Augenblick :: Augemblick": 1,
"Verſprochene :: Verfprochene": 1,
},
}

@ -18,16 +18,22 @@ def create_summaries(tmp_path):
reports_dirname = tmp_path / "reports"
reports_dirname.mkdir()
report1 = {"cer": 0.05, "wer": 0.15,
report1 = {
"cer": 0.05,
"wer": 0.15,
"differences": {
"character_level": {"a": 10, "b": 20},
"word_level": {"c": 30, "d": 40}
}}
report2 = {"cer": 0.10, "wer": 0.20,
"word_level": {"c": 30, "d": 40},
},
}
report2 = {
"cer": 0.10,
"wer": 0.20,
"differences": {
"character_level": {"a": 20, "b": 30},
"word_level": {"c": 40, "d": 50}
}}
"word_level": {"c": 40, "d": 50},
},
}
with open(os.path.join(reports_dirname, "report1.json"), "w") as f:
json.dump(report1, f)
@ -47,7 +53,6 @@ def test_cli_summarize_json(tmp_path, create_summaries):
with open(os.path.join(reports_dirname, "summary.json"), "r") as f:
summary_data = json.load(f)
assert summary_data["num_reports"] == 2
assert summary_data["cer_avg"] == expected_cer_avg
assert summary_data["wer_avg"] == expected_wer_avg
@ -83,11 +88,13 @@ def test_cli_summarize_html_skip_invalid(tmp_path, create_summaries):
reports_dirname = create_summaries
# This third report has no WER value and should not be included in the summary
report3 = {"cer": 0.10,
report3 = {
"cer": 0.10,
"differences": {
"character_level": {"a": 20, "b": 30},
"word_level": {"c": 40, "d": 50}
}}
"word_level": {"c": 40, "d": 50},
},
}
with open(os.path.join(reports_dirname, "report3-missing-wer.json"), "w") as f:
json.dump(report3, f)

Loading…
Cancel
Save