Implemented new metrics behaviour

pull/60/head
Benjamin Rosemann 4 years ago
parent 9f5112f8f6
commit 06468a436e

@ -35,19 +35,22 @@ Usage: dinglehopper [OPTIONS] GT OCR [REPORT_PREFIX]
their text and falls back to plain text if no ALTO or PAGE is detected.
The files GT and OCR are usually a ground truth document and the result of
an OCR software, but you may use dinglehopper to compare two OCR results.
In that case, use --no-metrics to disable the then meaningless metrics and
also change the color scheme from green/red to blue.
an OCR software, but you may use dinglehopper to compare two OCR results. In
that case, use --metrics='' to disable the then meaningless metrics and also
change the color scheme from green/red to blue.
The comparison report will be written to $REPORT_PREFIX.{html,json}, where
$REPORT_PREFIX defaults to "report". The reports include the character
error rate (CER) and the word error rate (WER).
$REPORT_PREFIX defaults to "report". Depending on your configuration, the
reports include the character accuracy and error rate (CA|CER), the word
accuracy and error rate (WA|WER), the bag of characters accuracy (BoC), and
the bag of words accuracy (BoW).
The metrics can be chosen via a comma-separated combination of their acronyms,
e.g. "--metrics=ca,wer,boc,bow".
By default, the text of PAGE files is extracted on 'region' level. You may
use "--textequiv-level line" to extract from the level of TextLine tags.
Options:
--metrics / --no-metrics Enable/disable metrics and green/red
--metrics Enable different metrics like ca|cer, wa|wer, boc and bow.
--textequiv-level LEVEL PAGE TextEquiv level to extract text from
--progress Show progress bar
--help Show this message and exit.
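For illustration, the same metric selection can also be made from Python via the undecorated process() function that backs the CLI. This is only a sketch: the package path (qurator.dinglehopper) and the input file names are assumptions, not part of this change.
~~~
# Hedged sketch: selecting metrics programmatically via process().
from qurator.dinglehopper.cli import process  # assumed package path

process(
    "gt.page.xml",            # ground truth document (ALTO, PAGE or plain text)
    "ocr.page.xml",           # OCR result to compare against
    "report",                 # writes report.html and report.json
    metrics="ca,boc",         # comma-separated acronyms; "" disables metrics
    textequiv_level="line",   # extract text from TextLine instead of TextRegion
)
~~~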
@ -80,12 +83,12 @@ The OCR-D processor has these parameters:
| Parameter | Meaning |
| ------------------------- | ------------------------------------------------------------------- |
| `-P metrics false` | Disable metrics and the green-red color scheme (default: enabled) |
| `-P metrics cer,wer` | Enable character error rate and word error rate (default) |
| `-P textequiv_level line` | (PAGE) Extract text from TextLine level (default: TextRegion level) |
For example:
~~~
ocrd-dinglehopper -I ABBYY-FULLTEXT,OCR-D-OCR-CALAMARI -O OCR-D-OCR-COMPARE-ABBYY-CALAMARI -P metrics false
ocrd-dinglehopper -I ABBYY-FULLTEXT,OCR-D-OCR-CALAMARI -O OCR-D-OCR-COMPARE-ABBYY-CALAMARI -P metrics cer,wer
~~~
Developer information

@ -10,9 +10,10 @@ from uniseg.graphemecluster import grapheme_clusters
from .align import seq_align
from .config import Config
from .extracted_text import ExtractedText
from .metrics.character_error_rate import character_error_rate_n
from .metrics.word_error_rate import word_error_rate_n, words_normalized
from .ocr_files import extract
from .metrics import bag_of_chars_accuracy, bag_of_words_accuracy, character_accuracy, \
word_accuracy
from .normalize import words_normalized
from .ocr_files import text
def gen_diff_report(gt_in, ocr_in, css_prefix, joiner, none):
@ -85,9 +86,7 @@ def gen_diff_report(gt_in, ocr_in, css_prefix, joiner, none):
)
def generate_html_report(
gt, ocr, gt_text, ocr_text, report_prefix, metrics, cer, n_characters, wer, n_words
):
def generate_html_report(gt, ocr, gt_text, ocr_text, report_prefix, metrics_results):
char_diff_report = gen_diff_report(
gt_text, ocr_text, css_prefix="c", joiner="", none="·"
)
@ -112,57 +111,50 @@ def generate_html_report(
template.stream(
gt=gt,
ocr=ocr,
cer=cer,
n_characters=n_characters,
wer=wer,
n_words=n_words,
char_diff_report=char_diff_report,
word_diff_report=word_diff_report,
metrics=metrics,
metrics_results=metrics_results,
).dump(out_fn)
def generate_json_report(
gt, ocr, report_prefix, metrics, cer, n_characters, wer, n_words
):
json_dict = {"gt": gt, "ocr": ocr, "n_characters": n_characters, "n_words": n_words}
if metrics:
json_dict = {**json_dict, "cer": cer, "wer": wer}
with open(f"{report_prefix}.json", 'w') as fp:
def generate_json_report(gt, ocr, report_prefix, metrics_results):
json_dict = {"gt": gt, "ocr": ocr}
for result in metrics_results:
json_dict[result.metric] = {
key: value for key, value in result.get_dict().items() if key != "metric"
}
with open(f"{report_prefix}.json", "w") as fp:
json.dump(json_dict, fp)
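To make the new report layout concrete, here is a sketch of how the JSON report could be consumed. The numeric values are invented for illustration; besides "gt" and "ocr", the report is expected to contain one entry per selected metric, keyed by the metric name.
~~~
# Hedged sketch: reading a report.json produced by generate_json_report().
import json

with open("report.json") as fp:
    report = json.load(fp)

print(report["gt"], report["ocr"])
# report["character_accuracy"] is expected to look roughly like:
# {"weights": {"deletes": 1, "inserts": 1, "replacements": 1},
#  "weighted_errors": 2, "reference_elements": 10, "compared_elements": 10,
#  "accuracy": ..., "error_rate": 0.2}
for name, values in report.items():
    if isinstance(values, dict):
        print(name, values["error_rate"])
~~~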
def process(gt, ocr, report_prefix, *, metrics=True, textequiv_level="region"):
def process(gt, ocr, report_prefix, *, metrics="cer,wer", textequiv_level="region"):
"""Check OCR result against GT.
The @click decorators change the signature of the decorated functions,
so we keep this undecorated version and use Click on a wrapper.
"""
gt_text = extract(gt, textequiv_level=textequiv_level)
ocr_text = extract(ocr, textequiv_level=textequiv_level)
cer, n_characters = character_error_rate_n(gt_text, ocr_text)
wer, n_words = word_error_rate_n(gt_text, ocr_text)
gt_text = text(gt, textequiv_level=textequiv_level)
ocr_text = text(ocr, textequiv_level=textequiv_level)
generate_json_report(
gt, ocr, report_prefix, metrics, cer, n_characters, wer, n_words
)
metrics_results = set()
if metrics:
metric_dict = {
"ca": character_accuracy,
"cer": character_accuracy,
"wa": word_accuracy,
"wer": word_accuracy,
"boc": bag_of_chars_accuracy,
"bow": bag_of_words_accuracy,
}
for metric in metrics.split(","):
metrics_results.add(metric_dict[metric.strip()](gt_text, ocr_text))
generate_json_report(gt, ocr, report_prefix, metrics_results)
html_report = True
if html_report:
generate_html_report(
gt,
ocr,
gt_text,
ocr_text,
report_prefix,
metrics,
cer,
n_characters,
wer,
n_words,
)
generate_html_report(gt, ocr, gt_text, ocr_text, report_prefix, metrics_results)
@click.command()
@ -170,7 +162,9 @@ def process(gt, ocr, report_prefix, *, metrics=True, textequiv_level="region"):
@click.argument("ocr", type=click.Path(exists=True))
@click.argument("report_prefix", type=click.Path(), default="report")
@click.option(
"--metrics/--no-metrics", default=True, help="Enable/disable metrics and green/red"
"--metrics",
default="cer,wer",
help="Enable different metrics like cer, wer, boc and bow.",
)
@click.option(
"--textequiv-level",
@ -188,12 +182,15 @@ def main(gt, ocr, report_prefix, metrics, textequiv_level, progress):
The files GT and OCR are usually a ground truth document and the result of
an OCR software, but you may use dinglehopper to compare two OCR results. In
that case, use --no-metrics to disable the then meaningless metrics and also
that case, use --metrics='' to disable the then meaningless metrics and also
change the color scheme from green/red to blue.
The comparison report will be written to $REPORT_PREFIX.{html,json}, where
$REPORT_PREFIX defaults to "report". The reports include the character error
rate (CER) and the word error rate (WER).
$REPORT_PREFIX defaults to "report". Depending on your configuration, the
reports include the character accuracy and error rate (CA|CER), the word
accuracy and error rate (WA|WER), the bag of characters accuracy (BoC), and
the bag of words accuracy (BoW).
The metrics can be chosen via a comma-separated combination of their acronyms,
e.g. "--metrics=ca,wer,boc,bow".
By default, the text of PAGE files is extracted on 'region' level. You may
use "--textequiv-level line" to extract from the level of TextLine tags.

@ -4,11 +4,9 @@ from functools import lru_cache, partial
from typing import Sequence, Tuple
import numpy as np
from multimethod import multimethod
from tqdm import tqdm
from .config import Config
from .extracted_text import ExtractedText
from .normalize import chars_normalized
@ -74,7 +72,6 @@ def levenshtein_matrix_cache_clear():
_levenshtein_matrix.cache_clear()
@multimethod
def distance(s1: str, s2: str):
"""Compute the Levenshtein edit distance between two Unicode strings
@ -86,11 +83,6 @@ def distance(s1: str, s2: str):
return levenshtein(seq1, seq2)
@multimethod
def distance(s1: ExtractedText, s2: ExtractedText):
return distance(s1.text, s2.text)
def seq_editops(seq1, seq2):
"""
Return sequence of edit operations transforming one sequence to another.

@ -1,5 +1,5 @@
from .bag_of_chars_accuracy import *
from .bag_of_words_accuracy import *
from .character_error_rate import *
from .utils import Weights
from .word_error_rate import *
from .character_accuracy import *
from .utils import MetricResult, Weights
from .word_accuracy import *

@ -7,7 +7,7 @@ from .utils import bag_accuracy, MetricResult, Weights
def bag_of_chars_accuracy(
reference: str, compared: str, weights: Weights
reference: str, compared: str, weights: Weights = Weights(1, 0, 1)
) -> MetricResult:
reference_chars = Counter(grapheme_clusters(normalize("NFC", reference)))
compared_chars = Counter(grapheme_clusters(normalize("NFC", compared)))

@ -5,7 +5,7 @@ from ..normalize import words_normalized
def bag_of_words_accuracy(
reference: str, compared: str, weights: Weights
reference: str, compared: str, weights: Weights = Weights(1, 0, 1)
) -> MetricResult:
reference_words = Counter(words_normalized(reference))
compared_words = Counter(words_normalized(compared))
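The new default of Weights(1, 0, 1), read as deletes=1, inserts=0, replacements=1 going by the field names used in the report template, would mean that surplus elements in the compared text are not counted as errors. A hedged sketch under that assumption, with an assumed package path:
~~~
# Hedged sketch of bag_of_words_accuracy with its default weights: the extra
# "Test" in the compared text should be a free insert, so no errors are counted.
from qurator.dinglehopper.metrics import bag_of_words_accuracy  # assumed path

result = bag_of_words_accuracy("ein kleiner Test", "ein kleiner Test Test")
print(result.weighted_errors)  # expected: 0
print(result.error_rate)       # expected: 0.0
~~~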

@ -0,0 +1,28 @@
from __future__ import division
from .utils import MetricResult, Weights
from .. import distance
from ..normalize import chars_normalized
def character_accuracy(
reference: str, compared: str, weights: Weights = Weights(1, 1, 1)
) -> MetricResult:
"""Compute character accuracy and error rate.
:return: NamedTuple representing the results of this metric.
"""
weighted_errors = distance(reference, compared)
n_ref = len(chars_normalized(reference))
n_cmp = len(chars_normalized(compared))
return MetricResult(
metric=character_accuracy.__name__,
weights=weights,
weighted_errors=int(weighted_errors),
reference_elements=n_ref,
compared_elements=n_cmp,
)
# XXX Should we really count newlines here?
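A minimal usage sketch, with values taken from the accompanying tests (the package path is an assumption):
~~~
# One substitution in a four-grapheme reference gives an error rate of 1/4.
from qurator.dinglehopper.metrics import character_accuracy  # assumed path

result = character_accuracy("Müll", "Mull")
print(result.reference_elements)  # 4
print(result.error_rate)          # 0.25
~~~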

@ -1,46 +0,0 @@
from __future__ import division
from typing import Tuple
from multimethod import multimethod
from .. import distance
from ..extracted_text import ExtractedText
from ..normalize import chars_normalized
@multimethod
def character_error_rate_n(reference: str, compared: str) -> Tuple[float, int]:
"""
Compute character error rate.
:return: character error rate and length of the reference
"""
d = distance(reference, compared)
n = len(chars_normalized(reference))
if d == 0:
return 0, n
if n == 0:
return float("inf"), n
return d / n, n
# XXX Should we really count newlines here?
@multimethod
def character_error_rate_n(
reference: ExtractedText, compared: ExtractedText
) -> Tuple[float, int]:
return character_error_rate_n(reference.text, compared.text)
def character_error_rate(reference, compared) -> float:
"""
Compute character error rate.
:return: character error rate
"""
cer, _ = character_error_rate_n(reference, compared)
return cer

@ -1,5 +1,5 @@
from collections import Counter
from typing import NamedTuple
from typing import Dict, NamedTuple
class Weights(NamedTuple):
@ -25,10 +25,27 @@ class MetricResult(NamedTuple):
@property
def error_rate(self) -> float:
if self.reference_elements <= 0:
if self.reference_elements <= 0 and self.compared_elements <= 0:
return 0
elif self.reference_elements <= 0:
return float("inf")
return self.weighted_errors / self.reference_elements
def get_dict(self) -> Dict:
"""Combines the properties to a dictionary.
We deviate from the builtin _asdict() function by including our properties.
"""
return {
**{
key: value
for key, value in self._asdict().items()
},
"accuracy": self.accuracy,
"error_rate": self.error_rate,
"weights": self.weights._asdict(),
}
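A sketch of what get_dict() returns for a hypothetical result. Note that generate_json_report() later drops the "metric" key, and the accuracy value shown here assumes accuracy is the complement of the error rate, which is not part of this diff; the package path is also assumed.
~~~
from qurator.dinglehopper.metrics import MetricResult, Weights  # assumed path

result = MetricResult(
    metric="character_accuracy",
    weights=Weights(1, 1, 1),
    weighted_errors=2,
    reference_elements=10,
    compared_elements=10,
)
print(result.get_dict())
# Expected (roughly):
# {"metric": "character_accuracy", "weighted_errors": 2,
#  "reference_elements": 10, "compared_elements": 10,
#  "accuracy": 0.8, "error_rate": 0.2,
#  "weights": {"deletes": 1, "inserts": 1, "replacements": 1}}
~~~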
def bag_accuracy(
reference: Counter, compared: Counter, weights: Weights
@ -44,7 +61,7 @@ def bag_accuracy(
:param reference: Bag used as reference (ground truth).
:param compared: Bag used to compare (ocr).
:param weights: Weights/costs for editing operations.
:return: Tuple representing the results of this metric.
:return: NamedTuple representing the results of this metric.
"""
n_ref = sum(reference.values())
n_cmp = sum(compared.values())
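bag_accuracy can also be called directly with Counters. A hedged sketch (module path assumed), using identical bags so the expected result is unambiguous:
~~~
from collections import Counter

from qurator.dinglehopper.metrics.utils import bag_accuracy, Weights  # assumed path

bag = Counter("abracadabra")
result = bag_accuracy(bag, bag, Weights(1, 0, 1))
print(result.weighted_errors)  # expected: 0
print(result.error_rate)       # expected: 0.0
~~~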

@ -0,0 +1,22 @@
from .utils import MetricResult, Weights
from ..edit_distance import levenshtein
from ..normalize import words_normalized
def word_accuracy(
reference: str, compared: str, weights: Weights = Weights(1, 1, 1)
) -> MetricResult:
reference_seq = list(words_normalized(reference))
compared_seq = list(words_normalized(compared))
weighted_errors = levenshtein(reference_seq, compared_seq)
n_ref = len(reference_seq)
n_cmp = len(compared_seq)
return MetricResult(
metric=word_accuracy.__name__,
weights=weights,
weighted_errors=int(weighted_errors),
reference_elements=n_ref,
compared_elements=n_cmp,
)
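A minimal usage sketch mirroring the accompanying tests (package path assumed):
~~~
# One of four normalized words differs, so the error rate is 1/4.
from qurator.dinglehopper.metrics import word_accuracy  # assumed path

result = word_accuracy("Dies ist ein Beispielsatz!", "Dies ist ein Beispielsarz:")
print(result.reference_elements)  # 4
print(result.error_rate)          # 0.25
~~~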

@ -1,43 +0,0 @@
from __future__ import division
from typing import Iterable, Tuple
from multimethod import multimethod
from ..edit_distance import levenshtein
from ..extracted_text import ExtractedText
from ..normalize import words_normalized
@multimethod
def word_error_rate_n(reference: str, compared: str) -> Tuple[float, int]:
reference_seq = list(words_normalized(reference))
compared_seq = list(words_normalized(compared))
return word_error_rate_n(reference_seq, compared_seq)
@multimethod
def word_error_rate_n(
reference: ExtractedText, compared: ExtractedText
) -> Tuple[float, int]:
return word_error_rate_n(reference.text, compared.text)
@multimethod
def word_error_rate_n(reference: Iterable, compared: Iterable) -> Tuple[float, int]:
reference_seq = list(reference)
compared_seq = list(compared)
d = levenshtein(reference_seq, compared_seq)
n = len(reference_seq)
if d == 0:
return 0, n
if n == 0:
return float("inf"), n
return d / n, n
def word_error_rate(reference, compared) -> float:
wer, _ = word_error_rate_n(reference, compared)
return wer

@ -163,8 +163,8 @@ def extract(filename, *, textequiv_level="region"):
return alto_extract(tree)
def text(filename):
return extract(filename).text
def text(filename, *args, **kwargs):
return extract(filename, *args, **kwargs).text
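With the new signature, text() forwards its keyword arguments to extract(), so the TextEquiv level can be chosen directly. A sketch with an assumed module path and file name:
~~~
from qurator.dinglehopper.ocr_files import text  # assumed path

# Extract plain text from a PAGE file at TextLine level instead of the
# default TextRegion level.
line_text = text("ocr.page.xml", textequiv_level="line")
print(line_text[:100])
~~~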
if __name__ == "__main__":

@ -19,9 +19,10 @@
],
"parameters": {
"metrics": {
"type": "boolean",
"default": true,
"description": "Enable/disable metrics and green/red"
"type": "string",
"enum": ["", "boc", "boc,bow", "bow", "ca", "ca,boc", "ca,boc,bow", "ca,bow", "ca,wa", "ca,wa,boc", "ca,wa,boc,bow", "ca,wa,bow", "ca,wer", "ca,wer,boc", "ca,wer,boc,bow", "ca,wer,bow", "cer", "cer,boc", "cer,boc,bow", "cer,bow", "cer,wa", "cer,wa,boc", "cer,wa,boc,bow", "cer,wa,bow", "cer,wer", "cer,wer,boc", "cer,wer,boc,bow", "cer,wer,bow", "wa", "wa,boc", "wa,boc,bow", "wa,bow", "wer", "wer,boc", "wer,boc,bow", "wer,bow"],
"default": "cer,wer",
"description": "Enable different metrics like ca|cer, wa|wer, boc and bow."
},
"textequiv_level": {
"type": "string",

@ -6,7 +6,7 @@
<link rel="stylesheet" href="https://stackpath.bootstrapcdn.com/bootstrap/4.3.1/css/bootstrap.min.css" integrity="sha384-ggOyR0iXCbMQv3Xipma34MD+dH/1fQ784/j6cY/iJTQUOhcWr7x9JvoRxT2MZw1T" crossorigin="anonymous">
<style type="text/css">
{% if metrics %}
{% if metrics_results %}
.gt .diff {
color: green;
}
@ -38,10 +38,44 @@
{{ ocr }}
{% if metrics %}
{% if metrics_results %}
<h2>Metrics</h2>
<p>CER: {{ cer|round(4) }}</p>
<p>WER: {{ wer|round(4) }}</p>
<table class="table">
<thead>
<tr>
<th scope="col" rowspan="2">#</th>
<th scope="col" colspan="2">Results</th>
<th scope="col" colspan="2">Elements</th>
<th scope="col" colspan="2">Calculation</th>
</tr>
<tr>
<th scope="col">Acc</th>
<th scope="col">ER</th>
<th scope="col">Ref</th>
<th scope="col">Cmp</th>
<th scope="col">Errors</th>
<th scope="col">Weights</th>
</tr>
</thead>
<tbody>
{% for result in metrics_results %}
<tr>
<th scope="row">{{ result.metric.replace("_", " ") }}</th>
<td>{{ result.accuracy }}</td>
<td>{{ result.error_rate }}</td>
<td>{{ result.reference_elements }}</td>
<td>{{ result.compared_elements }}</td>
<td>{{ result.weighted_errors }}</td>
<td>d={{ result.weights.deletes }},i={{ result.weights.inserts }},r={{ result.weights.replacements }}</td>
</tr>
{% endfor %}
</tbody>
<tfoot>
<tr>
<td colspan="7">Legend: Acc. = Accuracy, ER = Error Rate, Ref. = Reference, Cmp. = Compared</td>
</tr>
</tfoot>
</table>
{% endif %}
<h2>Character differences</h2>

@ -0,0 +1,40 @@
import math
import unicodedata
from ...metrics import character_accuracy
def test_character_accuracy():
assert character_accuracy("a", "a").error_rate == 0
assert character_accuracy("a", "b").error_rate == 1 / 1
assert character_accuracy("Foo", "Bar").error_rate == 3 / 3
assert character_accuracy("Foo", "").error_rate == 3 / 3
assert character_accuracy("", "").error_rate == 0
assert math.isinf(character_accuracy("", "Foo").error_rate)
assert character_accuracy("Foo", "Food").error_rate == 1 / 3
assert character_accuracy("Fnord", "Food").error_rate == 2 / 5
assert character_accuracy("Müll", "Mull").error_rate == 1 / 4
assert character_accuracy("Abstand", "Sand").error_rate == 4 / 7
def test_character_accuracy_hard():
s1 = unicodedata.normalize("NFC", "Schlyñ lorem ipsum.")
s2 = unicodedata.normalize("NFD", "Schlyñ lorem ipsum!") # Different, decomposed!
assert character_accuracy(s1, s2).error_rate == 1 / 19
s1 = "Schlyñ"
assert (
len(s1) == 6
) # This ends with LATIN SMALL LETTER N WITH TILDE, so 6 code points
s2 = "Schlym̃"
assert (
len(s2) == 7
) # This, OTOH, ends with LATIN SMALL LETTER M + COMBINING TILDE, 7 code points
# Both strings have the same length in terms of grapheme clusters.
# So the CER should be symmetrical.
assert character_accuracy(s2, s1).error_rate == 1 / 6
assert character_accuracy(s1, s2).error_rate == 1 / 6

@ -1,41 +0,0 @@
from __future__ import division, print_function
import math
import unicodedata
from ...metrics import character_error_rate
def test_character_error_rate():
assert character_error_rate("a", "a") == 0
assert character_error_rate("a", "b") == 1 / 1
assert character_error_rate("Foo", "Bar") == 3 / 3
assert character_error_rate("Foo", "") == 3 / 3
assert character_error_rate("", "") == 0
assert math.isinf(character_error_rate("", "Foo"))
assert character_error_rate("Foo", "Food") == 1 / 3
assert character_error_rate("Fnord", "Food") == 2 / 5
assert character_error_rate("Müll", "Mull") == 1 / 4
assert character_error_rate("Abstand", "Sand") == 4 / 7
def test_character_error_rate_hard():
s1 = unicodedata.normalize("NFC", "Schlyñ lorem ipsum.")
s2 = unicodedata.normalize("NFD", "Schlyñ lorem ipsum!") # Different, decomposed!
assert character_error_rate(s1, s2) == 1 / 19
s1 = "Schlyñ"
assert (
len(s1) == 6
) # This ends with LATIN SMALL LETTER N WITH TILDE, so 6 code points
s2 = "Schlym̃"
assert (
len(s2) == 7
) # This, OTOH, ends with LATIN SMALL LETTER M + COMBINING TILDE, 7 code points
# Both strings have the same length in terms of grapheme clusters. So the CER should be symmetrical.
assert character_error_rate(s2, s1) == 1 / 6
assert character_error_rate(s1, s2) == 1 / 6

@ -7,7 +7,7 @@ from lxml import etree as ET
from uniseg.graphemecluster import grapheme_clusters
from ... import page_text, alto_text
from ...metrics import character_error_rate
from ...metrics import character_accuracy
data_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "../", "data")
@ -22,7 +22,7 @@ def test_character_error_rate_between_page_files():
gt_len = len(list(grapheme_clusters(gt)))
expected_cer = 2 / gt_len
assert character_error_rate(gt, ocr) == expected_cer
assert character_accuracy(gt, ocr).error_rate == expected_cer
@pytest.mark.integration
@ -39,7 +39,7 @@ def test_character_error_rate_between_page_alto():
)
assert gt == ocr
assert character_error_rate(gt, ocr) == 0
assert character_accuracy(gt, ocr).error_rate == 0
@pytest.mark.integration
@ -57,4 +57,4 @@ def test_character_error_rate_between_page_alto_2():
)
)
assert character_error_rate(gt, ocr) == 8 / 591 # Manually verified
assert character_accuracy(gt, ocr).error_rate == 8 / 591 # Manually verified

@ -6,7 +6,7 @@ import pytest
from lxml import etree as ET
from ... import alto_text, page_text
from ...metrics import word_error_rate
from ...metrics import word_accuracy
from ...normalize import words
data_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "../", "data")
@ -24,7 +24,7 @@ def test_word_error_rate_between_page_files():
assert len(list(words(gt))) == gt_word_count
ocr = page_text(ET.parse(os.path.join(data_dir, "test-fake-ocr.page2018.xml")))
assert word_error_rate(gt, ocr) == 2 / gt_word_count
assert word_accuracy(gt, ocr).error_rate == 2 / gt_word_count
@pytest.mark.integration
@ -41,7 +41,7 @@ def test_word_error_rate_between_page_alto():
)
assert gt == ocr
assert word_error_rate(gt, ocr) == 0
assert word_accuracy(gt, ocr).error_rate == 0
@pytest.mark.integration
@ -66,5 +66,5 @@ def test_word_error_rate_between_page_alto_2():
)
assert (
word_error_rate(gt, ocr) == 7 / gt_word_count
word_accuracy(gt, ocr).error_rate == 7 / gt_word_count
) # Manually verified, 6 words are wrong, 1 got split (=2 errors)

@ -1,8 +1,6 @@
from __future__ import division, print_function
import math
from ...metrics import word_error_rate
from ...metrics import word_accuracy
from ...normalize import words
@ -55,33 +53,44 @@ def test_words_private_use_area():
def test_word_error_rate():
assert (
word_error_rate("Dies ist ein Beispielsatz!", "Dies ist ein Beispielsatz!") == 0
word_accuracy(
"Dies ist ein Beispielsatz!", "Dies ist ein Beispielsatz!"
).error_rate
== 0
)
assert (
word_error_rate("Dies. ist ein Beispielsatz!", "Dies ist ein Beispielsatz!")
word_accuracy(
"Dies. ist ein Beispielsatz!", "Dies ist ein Beispielsatz!"
).error_rate
== 0
)
assert (
word_error_rate("Dies. ist ein Beispielsatz!", "Dies ist ein Beispielsatz.")
word_accuracy(
"Dies. ist ein Beispielsatz!", "Dies ist ein Beispielsatz."
).error_rate
== 0
)
assert (
word_error_rate("Dies ist ein Beispielsatz!", "Dies ist ein Beispielsarz:")
word_accuracy(
"Dies ist ein Beispielsatz!", "Dies ist ein Beispielsarz:"
).error_rate
== 1 / 4
)
assert (
word_error_rate("Dies ist ein Beispielsatz!", "Dies ein ist Beispielsatz!")
word_accuracy(
"Dies ist ein Beispielsatz!", "Dies ein ist Beispielsatz!"
).error_rate
== 2 / 4
)
assert word_error_rate("Dies ist ein Beispielsatz!", "") == 4 / 4
assert math.isinf(word_error_rate("", "Dies ist ein Beispielsatz!"))
assert word_error_rate("", "") == 0
assert word_accuracy("Dies ist ein Beispielsatz!", "").error_rate == 4 / 4
assert math.isinf(word_accuracy("", "Dies ist ein Beispielsatz!").error_rate)
assert word_accuracy("", "").error_rate == 0
assert (
word_error_rate(
word_accuracy(
"Schlyñ lorem ipsum dolor sit amet,", "Schlym̃ lorem ipsum dolor sit amet."
)
).error_rate
== 1 / 6
)

@ -1,14 +1,43 @@
import json
from itertools import combinations
import pytest
from .util import working_directory
from .util import working_directory
from ..cli import process
METRIC_DICT = {
"": "",
"ca": "character_accuracy",
"cer": "character_accuracy",
"wa": "word_accuracy",
"wer": "word_accuracy",
"boc": "bag_of_chars_accuracy",
"bow": "bag_of_words_accuracy",
}
@pytest.mark.integration
def test_cli_json(tmp_path):
"""Test that the cli/process() yields a loadable JSON report"""
@pytest.mark.parametrize(
"metrics",
[
*(("",), ("cer",), ("wer",), ("ca",), ("wa",), ("boc",), ("bow",)),
*combinations(("ca", "wa", "boc", "bow"), 2),
*combinations(("cer", "wer", "boc", "bow"), 2),
*combinations(("ca", "wa", "boc", "bow"), 3),
*combinations(("cer", "wer", "boc", "bow"), 3),
("ca", "wa", "boc", "bow"),
("cer", "wer", "boc", "bow"),
],
)
def test_cli_json(metrics, tmp_path):
"""Test that the cli/process() yields a loadable JSON report."""
expected_values = {
"character_accuracy": 0.2,
"word_accuracy": 1.0,
"bag_of_chars_accuracy": 0.2,
"bag_of_words_accuracy": 1.0,
}
with working_directory(str(tmp_path)):
with open("gt.txt", "w") as gtf:
@ -18,25 +47,38 @@ def test_cli_json(tmp_path):
with open("gt.txt", "r") as gtf:
print(gtf.read())
process("gt.txt", "ocr.txt", "report")
process("gt.txt", "ocr.txt", "report", metrics=",".join(metrics))
with open("report.json", "r") as jsonf:
print(jsonf.read())
with open("report.json", "r") as jsonf:
j = json.load(jsonf)
assert j["cer"] == pytest.approx(0.2)
metrics_translated = {METRIC_DICT[metric] for metric in metrics}
for metric, expected_value in expected_values.items():
if metric in metrics_translated:
assert j[metric]["error_rate"] == pytest.approx(expected_value)
else:
assert metric not in j.keys()
@pytest.mark.integration
def test_cli_json_cer_is_infinity(tmp_path):
"""Test that the cli/process() yields a loadable JSON report when CER == inf"""
@pytest.mark.parametrize(
"gt,ocr,err",
[("", "Not important", float("inf")), ("Lorem Ipsum", "Lorem Ipsum", 0)],
)
def test_cli_json_extremes(gt, ocr, err, tmp_path):
"""Test that the cli/process() yields a loadable JSON reports."""
with working_directory(str(tmp_path)):
with open("gt.txt", "w") as gtf:
gtf.write("") # Empty to yield CER == inf
gtf.write(gt)
with open("ocr.txt", "w") as ocrf:
ocrf.write("Not important")
ocrf.write(ocr)
process("gt.txt", "ocr.txt", "report")
process("gt.txt", "ocr.txt", "report", metrics="ca,wa,boc,bow")
with open("report.json", "r") as jsonf:
j = json.load(jsonf)
assert j["cer"] == pytest.approx(float("inf"))
for metric in set(METRIC_DICT.values()):
if not metric:
continue
assert j[metric]["error_rate"] == pytest.approx(err)

@ -7,5 +7,4 @@ colorama
MarkupSafe
ocrd >= 2.20.1
attrs
multimethod == 1.3 # latest version to officially support Python 3.5
tqdm
