Error handling

pull/60/head
Benjamin Rosemann 4 years ago
parent 06468a436e
commit a44a3d4bf2

@ -123,7 +123,6 @@ def generate_json_report(gt, ocr, report_prefix, metrics_results):
json_dict[result.metric] = {
key: value for key, value in result.get_dict().items() if key != "metric"
}
print(json_dict)
with open(f"{report_prefix}.json", "w") as fp:
json.dump(json_dict, fp)
@ -149,12 +148,13 @@ def process(gt, ocr, report_prefix, *, metrics="cer,wer", textequiv_level="regio
"bow": bag_of_words_accuracy,
}
for metric in metrics.split(","):
metrics_results.add(metric_dict[metric.strip()](gt_text, ocr_text))
generate_json_report(gt, ocr, report_prefix, metrics_results)
metric = metric.strip()
if metric not in metric_dict.keys():
raise ValueError(f"Unknown metric '{metric}'.")
metrics_results.add(metric_dict[metric](gt_text, ocr_text))
html_report = True
if html_report:
generate_html_report(gt, ocr, gt_text, ocr_text, report_prefix, metrics_results)
generate_json_report(gt, ocr, report_prefix, metrics_results)
generate_html_report(gt, ocr, gt_text, ocr_text, report_prefix, metrics_results)
@click.command()

@ -0,0 +1,18 @@
import pytest
from .util import working_directory
from ..cli import process
@pytest.mark.integration
def test_cli_unknown_metric(tmp_path):
"""Test that unknown metrics are handled appropriately."""
with working_directory(str(tmp_path)):
with open("gt.txt", "w") as gtf:
gtf.write("")
with open("ocr.txt", "w") as ocrf:
ocrf.write("")
with pytest.raises(ValueError, match="Unknown metric 'unknown'."):
process("gt.txt", "ocr.txt", "report", metrics="cer,unknown")
Loading…
Cancel
Save