From a44a3d4bf2cdd53b9505756fe57b3cb3228113b0 Mon Sep 17 00:00:00 2001 From: Benjamin Rosemann Date: Fri, 11 Jun 2021 15:33:13 +0200 Subject: [PATCH] Error handling --- qurator/dinglehopper/cli.py | 12 ++++++------ qurator/dinglehopper/tests/test_integ_cli.py | 18 ++++++++++++++++++ 2 files changed, 24 insertions(+), 6 deletions(-) create mode 100644 qurator/dinglehopper/tests/test_integ_cli.py diff --git a/qurator/dinglehopper/cli.py b/qurator/dinglehopper/cli.py index 22ba9f5..8452fa8 100644 --- a/qurator/dinglehopper/cli.py +++ b/qurator/dinglehopper/cli.py @@ -123,7 +123,6 @@ def generate_json_report(gt, ocr, report_prefix, metrics_results): json_dict[result.metric] = { key: value for key, value in result.get_dict().items() if key != "metric" } - print(json_dict) with open(f"{report_prefix}.json", "w") as fp: json.dump(json_dict, fp) @@ -149,12 +148,13 @@ def process(gt, ocr, report_prefix, *, metrics="cer,wer", textequiv_level="regio "bow": bag_of_words_accuracy, } for metric in metrics.split(","): - metrics_results.add(metric_dict[metric.strip()](gt_text, ocr_text)) - generate_json_report(gt, ocr, report_prefix, metrics_results) + metric = metric.strip() + if metric not in metric_dict.keys(): + raise ValueError(f"Unknown metric '{metric}'.") + metrics_results.add(metric_dict[metric](gt_text, ocr_text)) - html_report = True - if html_report: - generate_html_report(gt, ocr, gt_text, ocr_text, report_prefix, metrics_results) + generate_json_report(gt, ocr, report_prefix, metrics_results) + generate_html_report(gt, ocr, gt_text, ocr_text, report_prefix, metrics_results) @click.command() diff --git a/qurator/dinglehopper/tests/test_integ_cli.py b/qurator/dinglehopper/tests/test_integ_cli.py new file mode 100644 index 0000000..1769736 --- /dev/null +++ b/qurator/dinglehopper/tests/test_integ_cli.py @@ -0,0 +1,18 @@ +import pytest + +from .util import working_directory +from ..cli import process + + +@pytest.mark.integration +def test_cli_unknown_metric(tmp_path): + """Test that unknown metrics are handled appropriately.""" + + with working_directory(str(tmp_path)): + with open("gt.txt", "w") as gtf: + gtf.write("") + with open("ocr.txt", "w") as ocrf: + ocrf.write("") + + with pytest.raises(ValueError, match="Unknown metric 'unknown'."): + process("gt.txt", "ocr.txt", "report", metrics="cer,unknown")