Implemented new metrics behaviour

pull/60/head
Benjamin Rosemann 4 years ago
parent 9f5112f8f6
commit 06468a436e

@@ -34,20 +34,23 @@ Usage: dinglehopper [OPTIONS] GT OCR [REPORT_PREFIX]
dinglehopper detects if GT/OCR are ALTO or PAGE XML documents to extract
their text and falls back to plain text if no ALTO or PAGE is detected.
The files GT and OCR are usually a ground truth document and the result of
an OCR software, but you may use dinglehopper to compare two OCR results.
In that case, use --no-metrics to disable the then meaningless metrics and
also change the color scheme from green/red to blue.
The files GT and OCR are usually a ground truth document and the result of
an OCR software, but you may use dinglehopper to compare two OCR results. In
that case, use --metrics='' to disable the metrics, which are then meaningless, and also
change the color scheme from green/red to blue.
The comparison report will be written to $REPORT_PREFIX.{html,json}, where
$REPORT_PREFIX defaults to "report". The reports include the character
error rate (CER) and the word error rate (WER).
$REPORT_PREFIX defaults to "report". Depending on your configuration, the
reports include the character accuracy or error rate (CA|CER), the word
accuracy or error rate (WA|WER), the bag of chars accuracy (BoC), and the
bag of words accuracy (BoW).
The metrics can be chosen via a comma-separated combination of their
acronyms, e.g. "--metrics=ca,wer,boc,bow".
By default, the text of PAGE files is extracted on 'region' level. You may
use "--textequiv-level line" to extract from the level of TextLine tags.
Options:
--metrics / --no-metrics Enable/disable metrics and green/red
--metrics Enable different metrics like ca|cer, wa|wer, boc and bow.
--textequiv-level LEVEL PAGE TextEquiv level to extract text from
--progress Show progress bar
--help Show this message and exit.
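
For reference, the same metric selection can also be driven from Python through
the undecorated process() function changed later in this diff. A minimal sketch;
the import path and file names are assumptions, not part of this commit:

~~~
# The package may be installed under a namespace (e.g. qurator.dinglehopper).
from dinglehopper.cli import process

# Writes report.html and report.json; `metrics` takes the same comma-separated
# acronyms as the --metrics option above.
process("gt.txt", "ocr.txt", "report", metrics="ca,wa,boc,bow")
~~~
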
@@ -80,12 +83,12 @@ The OCR-D processor has these parameters:
| Parameter | Meaning |
| ------------------------- | ------------------------------------------------------------------- |
| `-P metrics false` | Disable metrics and the green-red color scheme (default: enabled) |
| `-P metrics cer,wer` | Enable character error rate and word error rate (default) |
| `-P textequiv_level line` | (PAGE) Extract text from TextLine level (default: TextRegion level) |
For example:
~~~
ocrd-dinglehopper -I ABBYY-FULLTEXT,OCR-D-OCR-CALAMARI -O OCR-D-OCR-COMPARE-ABBYY-CALAMARI -P metrics false
ocrd-dinglehopper -I ABBYY-FULLTEXT,OCR-D-OCR-CALAMARI -O OCR-D-OCR-COMPARE-ABBYY-CALAMARI -P metrics cer,wer
~~~
Developer information

@@ -10,9 +10,10 @@ from uniseg.graphemecluster import grapheme_clusters
from .align import seq_align
from .config import Config
from .extracted_text import ExtractedText
from .metrics.character_error_rate import character_error_rate_n
from .metrics.word_error_rate import word_error_rate_n, words_normalized
from .ocr_files import extract
from .metrics import bag_of_chars_accuracy, bag_of_words_accuracy, character_accuracy, \
word_accuracy
from .normalize import words_normalized
from .ocr_files import text
def gen_diff_report(gt_in, ocr_in, css_prefix, joiner, none):
@@ -85,9 +86,7 @@ def gen_diff_report(gt_in, ocr_in, css_prefix, joiner, none):
)
def generate_html_report(
gt, ocr, gt_text, ocr_text, report_prefix, metrics, cer, n_characters, wer, n_words
):
def generate_html_report(gt, ocr, gt_text, ocr_text, report_prefix, metrics_results):
char_diff_report = gen_diff_report(
gt_text, ocr_text, css_prefix="c", joiner="", none="·"
)
@@ -112,57 +111,50 @@ def generate_html_report(
template.stream(
gt=gt,
ocr=ocr,
cer=cer,
n_characters=n_characters,
wer=wer,
n_words=n_words,
char_diff_report=char_diff_report,
word_diff_report=word_diff_report,
metrics=metrics,
metrics_results=metrics_results,
).dump(out_fn)
def generate_json_report(
gt, ocr, report_prefix, metrics, cer, n_characters, wer, n_words
):
json_dict = {"gt": gt, "ocr": ocr, "n_characters": n_characters, "n_words": n_words}
if metrics:
json_dict = {**json_dict, "cer": cer, "wer": wer}
with open(f"{report_prefix}.json", 'w') as fp:
def generate_json_report(gt, ocr, report_prefix, metrics_results):
json_dict = {"gt": gt, "ocr": ocr}
for result in metrics_results:
json_dict[result.metric] = {
key: value for key, value in result.get_dict().items() if key != "metric"
}
print(json_dict)
with open(f"{report_prefix}.json", "w") as fp:
json.dump(json_dict, fp)
def process(gt, ocr, report_prefix, *, metrics=True, textequiv_level="region"):
def process(gt, ocr, report_prefix, *, metrics="cer,wer", textequiv_level="region"):
"""Check OCR result against GT.
The @click decorators change the signature of the decorated functions,
so we keep this undecorated version and use Click on a wrapper.
"""
gt_text = extract(gt, textequiv_level=textequiv_level)
ocr_text = extract(ocr, textequiv_level=textequiv_level)
cer, n_characters = character_error_rate_n(gt_text, ocr_text)
wer, n_words = word_error_rate_n(gt_text, ocr_text)
gt_text = text(gt, textequiv_level=textequiv_level)
ocr_text = text(ocr, textequiv_level=textequiv_level)
generate_json_report(
gt, ocr, report_prefix, metrics, cer, n_characters, wer, n_words
)
metrics_results = set()
if metrics:
metric_dict = {
"ca": character_accuracy,
"cer": character_accuracy,
"wa": word_accuracy,
"wer": word_accuracy,
"boc": bag_of_chars_accuracy,
"bow": bag_of_words_accuracy,
}
for metric in metrics.split(","):
metrics_results.add(metric_dict[metric.strip()](gt_text, ocr_text))
generate_json_report(gt, ocr, report_prefix, metrics_results)
html_report = True
if html_report:
generate_html_report(
gt,
ocr,
gt_text,
ocr_text,
report_prefix,
metrics,
cer,
n_characters,
wer,
n_words,
)
generate_html_report(gt, ocr, gt_text, ocr_text, report_prefix, metrics_results)
@click.command()
@@ -170,7 +162,9 @@ def process(gt, ocr, report_prefix, *, metrics=True, textequiv_level="region"):
@click.argument("ocr", type=click.Path(exists=True))
@click.argument("report_prefix", type=click.Path(), default="report")
@click.option(
"--metrics/--no-metrics", default=True, help="Enable/disable metrics and green/red"
"--metrics",
default="cer,wer",
help="Enable different metrics like cer, wer, boc and bow.",
)
@click.option(
"--textequiv-level",
@@ -188,12 +182,15 @@ def main(gt, ocr, report_prefix, metrics, textequiv_level, progress):
The files GT and OCR are usually a ground truth document and the result of
an OCR software, but you may use dinglehopper to compare two OCR results. In
that case, use --no-metrics to disable the then meaningless metrics and also
that case, use --metrics='' to disable the metrics, which are then meaningless, and also
change the color scheme from green/red to blue.
The comparison report will be written to $REPORT_PREFIX.{html,json}, where
$REPORT_PREFIX defaults to "report". The reports include the character error
rate (CER) and the word error rate (WER).
$REPORT_PREFIX defaults to "report". Depending on your configuration, the
reports include the character accuracy or error rate (CA|CER), the word
accuracy or error rate (WA|WER), the bag of chars accuracy (BoC), and the
bag of words accuracy (BoW).
The metrics can be chosen via a comma-separated combination of their
acronyms, e.g. "--metrics=ca,wer,boc,bow".
By default, the text of PAGE files is extracted on 'region' level. You may
use "--textequiv-level line" to extract from the level of TextLine tags.

@@ -4,11 +4,9 @@ from functools import lru_cache, partial
from typing import Sequence, Tuple
import numpy as np
from multimethod import multimethod
from tqdm import tqdm
from .config import Config
from .extracted_text import ExtractedText
from .normalize import chars_normalized
@@ -74,7 +72,6 @@ def levenshtein_matrix_cache_clear():
_levenshtein_matrix.cache_clear()
@multimethod
def distance(s1: str, s2: str):
"""Compute the Levenshtein edit distance between two Unicode strings
@@ -86,11 +83,6 @@ def distance(s1: str, s2: str):
return levenshtein(seq1, seq2)
@multimethod
def distance(s1: ExtractedText, s2: ExtractedText):
return distance(s1.text, s2.text)
def seq_editops(seq1, seq2):
"""
Return sequence of edit operations transforming one sequence to another.

@@ -1,5 +1,5 @@
from .bag_of_chars_accuracy import *
from .bag_of_words_accuracy import *
from .character_error_rate import *
from .utils import Weights
from .word_error_rate import *
from .character_accuracy import *
from .utils import MetricResult, Weights
from .word_accuracy import *

@@ -7,7 +7,7 @@ from .utils import bag_accuracy, MetricResult, Weights
def bag_of_chars_accuracy(
reference: str, compared: str, weights: Weights
reference: str, compared: str, weights: Weights = Weights(1, 0, 1)
) -> MetricResult:
reference_chars = Counter(grapheme_clusters(normalize("NFC", reference)))
compared_chars = Counter(grapheme_clusters(normalize("NFC", compared)))

@@ -5,7 +5,7 @@ from ..normalize import words_normalized
def bag_of_words_accuracy(
reference: str, compared: str, weights: Weights
reference: str, compared: str, weights: Weights = Weights(1, 0, 1)
) -> MetricResult:
reference_words = Counter(words_normalized(reference))
compared_words = Counter(words_normalized(compared))
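
Both bag-of-* metrics now default to Weights(1, 0, 1), so they can be called
with just the two texts. A minimal sketch (import path assumed as before):

~~~
from dinglehopper.metrics import bag_of_chars_accuracy, bag_of_words_accuracy

# No Weights object needs to be passed any more.
boc = bag_of_chars_accuracy("Müll", "Mull")
bow = bag_of_words_accuracy("ein kurzer Beispielsatz", "ein kurzer Beispielsatz")
print(boc.error_rate, bow.error_rate)
~~~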

@@ -0,0 +1,28 @@
from __future__ import division
from .utils import MetricResult, Weights
from .. import distance
from ..normalize import chars_normalized
def character_accuracy(
reference: str, compared: str, weights: Weights = Weights(1, 1, 1)
) -> MetricResult:
"""Compute character accuracy and error rate.
:return: NamedTuple representing the results of this metric.
"""
weighted_errors = distance(reference, compared)
n_ref = len(chars_normalized(reference))
n_cmp = len(chars_normalized(compared))
return MetricResult(
metric=character_accuracy.__name__,
weights=weights,
weighted_errors=int(weighted_errors),
reference_elements=n_ref,
compared_elements=n_cmp,
)
# XXX Should we really count newlines here?
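
The new character_accuracy() replaces character_error_rate_n(); the rate is now
reached via the returned MetricResult. A minimal sketch, with the value taken
from the updated test further down (import path assumed as before):

~~~
from dinglehopper.metrics import character_accuracy

result = character_accuracy("Fnord", "Food")  # weights default to Weights(1, 1, 1)
assert result.metric == "character_accuracy"
assert result.error_rate == 2 / 5  # same value as in the updated test below
~~~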

@@ -1,46 +0,0 @@
from __future__ import division
from typing import Tuple
from multimethod import multimethod
from .. import distance
from ..extracted_text import ExtractedText
from ..normalize import chars_normalized
@multimethod
def character_error_rate_n(reference: str, compared: str) -> Tuple[float, int]:
"""
Compute character error rate.
:return: character error rate and length of the reference
"""
d = distance(reference, compared)
n = len(chars_normalized(reference))
if d == 0:
return 0, n
if n == 0:
return float("inf"), n
return d / n, n
# XXX Should we really count newlines here?
@multimethod
def character_error_rate_n(
reference: ExtractedText, compared: ExtractedText
) -> Tuple[float, int]:
return character_error_rate_n(reference.text, compared.text)
def character_error_rate(reference, compared) -> float:
"""
Compute character error rate.
:return: character error rate
"""
cer, _ = character_error_rate_n(reference, compared)
return cer

@@ -1,5 +1,5 @@
from collections import Counter
from typing import NamedTuple
from typing import Dict, NamedTuple
class Weights(NamedTuple):
@@ -25,10 +25,27 @@ class MetricResult(NamedTuple):
@property
def error_rate(self) -> float:
if self.reference_elements <= 0:
if self.reference_elements <= 0 and self.compared_elements <= 0:
return 0
elif self.reference_elements <= 0:
return float("inf")
return self.weighted_errors / self.reference_elements
def get_dict(self) -> Dict:
"""Combines the properties to a dictionary.
We deviate from the builtin _asdict() function by including our properties.
"""
return {
**{
key: value
for key, value in self._asdict().items()
},
"accuracy": self.accuracy,
"error_rate": self.error_rate,
"weights": self.weights._asdict(),
}
def bag_accuracy(
reference: Counter, compared: Counter, weights: Weights
@@ -44,7 +61,7 @@ def bag_accuracy(
:param reference: Bag used as reference (ground truth).
:param compared: Bag used to compare (ocr).
:param weights: Weights/costs for editing operations.
:return: Tuple representing the results of this metric.
:return: NamedTuple representing the results of this metric.
"""
n_ref = sum(reference.values())
n_cmp = sum(compared.values())
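
The reworked error_rate property distinguishes two empty inputs from an empty
reference only. A sketch constructing MetricResult directly, using the field
names from this diff (import path assumed as before):

~~~
from math import isinf

from dinglehopper.metrics import MetricResult, Weights

result = MetricResult(
    metric="character_accuracy",
    weights=Weights(1, 1, 1),
    weighted_errors=2,
    reference_elements=5,
    compared_elements=4,
)
assert result.error_rate == 2 / 5
print(result.get_dict())  # _asdict() fields plus accuracy, error_rate, weights dict

# Edge cases of the new error_rate property:
assert result._replace(weighted_errors=0, reference_elements=0,
                       compared_elements=0).error_rate == 0
assert isinf(result._replace(reference_elements=0).error_rate)
~~~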

@@ -0,0 +1,22 @@
from .utils import MetricResult, Weights
from ..edit_distance import levenshtein
from ..normalize import words_normalized
def word_accuracy(
reference: str, compared: str, weights: Weights = Weights(1, 1, 1)
) -> MetricResult:
reference_seq = list(words_normalized(reference))
compared_seq = list(words_normalized(compared))
weighted_errors = levenshtein(reference_seq, compared_seq)
n_ref = len(reference_seq)
n_cmp = len(compared_seq)
return MetricResult(
metric=word_accuracy.__name__,
weights=weights,
weighted_errors=int(weighted_errors),
reference_elements=n_ref,
compared_elements=n_cmp,
)
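
word_accuracy() mirrors character_accuracy() on the level of normalized words.
A minimal sketch, with the value taken from the updated word accuracy test
further down (import path assumed as before):

~~~
from dinglehopper.metrics import word_accuracy

# Two swapped words count as two errors out of four reference words.
result = word_accuracy("Dies ist ein Beispielsatz!", "Dies ein ist Beispielsatz!")
assert result.error_rate == 2 / 4
~~~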

@@ -1,43 +0,0 @@
from __future__ import division
from typing import Iterable, Tuple
from multimethod import multimethod
from ..edit_distance import levenshtein
from ..extracted_text import ExtractedText
from ..normalize import words_normalized
@multimethod
def word_error_rate_n(reference: str, compared: str) -> Tuple[float, int]:
reference_seq = list(words_normalized(reference))
compared_seq = list(words_normalized(compared))
return word_error_rate_n(reference_seq, compared_seq)
@multimethod
def word_error_rate_n(
reference: ExtractedText, compared: ExtractedText
) -> Tuple[float, int]:
return word_error_rate_n(reference.text, compared.text)
@multimethod
def word_error_rate_n(reference: Iterable, compared: Iterable) -> Tuple[float, int]:
reference_seq = list(reference)
compared_seq = list(compared)
d = levenshtein(reference_seq, compared_seq)
n = len(reference_seq)
if d == 0:
return 0, n
if n == 0:
return float("inf"), n
return d / n, n
def word_error_rate(reference, compared) -> float:
wer, _ = word_error_rate_n(reference, compared)
return wer

@@ -163,8 +163,8 @@ def extract(filename, *, textequiv_level="region"):
return alto_extract(tree)
def text(filename):
return extract(filename).text
def text(filename, *args, **kwargs):
return extract(filename, *args, **kwargs).text
if __name__ == "__main__":

@@ -19,9 +19,10 @@
],
"parameters": {
"metrics": {
"type": "boolean",
"default": true,
"description": "Enable/disable metrics and green/red"
"type": "string",
"enum": ["", "boc", "boc,bow", "bow", "ca", "ca,boc", "ca,boc,bow", "ca,bow", "ca,wa", "ca,wa,boc", "ca,wa,boc,bow", "ca,wa,bow", "ca,wer", "ca,wer,boc", "ca,wer,boc,bow", "ca,wer,bow", "cer", "cer,boc", "cer,boc,bow", "cer,bow", "cer,wa", "cer,wa,boc", "cer,wa,boc,bow", "cer,wa,bow", "cer,wer", "cer,wer,boc", "cer,wer,boc,bow", "cer,wer,bow", "wa", "wa,boc", "wa,boc,bow", "wa,bow", "wer", "wer,boc", "wer,boc,bow", "wer,bow"],
"default": "cer,wer",
"description": "Enable different metrics like ca|cer, wa|wer, boc and bow."
},
"textequiv_level": {
"type": "string",

@@ -6,7 +6,7 @@
<link rel="stylesheet" href="https://stackpath.bootstrapcdn.com/bootstrap/4.3.1/css/bootstrap.min.css" integrity="sha384-ggOyR0iXCbMQv3Xipma34MD+dH/1fQ784/j6cY/iJTQUOhcWr7x9JvoRxT2MZw1T" crossorigin="anonymous">
<style type="text/css">
{% if metrics %}
{% if metrics_results %}
.gt .diff {
color: green;
}
@@ -38,10 +38,44 @@
{{ ocr }}
{% if metrics %}
{% if metrics_results %}
<h2>Metrics</h2>
<p>CER: {{ cer|round(4) }}</p>
<p>WER: {{ wer|round(4) }}</p>
<table class="table">
<thead>
<tr>
<th scope="col" rowspan="2">#</th>
<th scope="col" colspan="2">Results</th>
<th scope="col" colspan="2">Elements</th>
<th scope="col" colspan="2">Calculation</th>
</tr>
<tr>
<th scope="col">Acc</th>
<th scope="col">ER</th>
<th scope="col">Ref</th>
<th scope="col">Cmp</th>
<th scope="col">Errors</th>
<th scope="col">Weights</th>
</tr>
</thead>
<tbody>
{% for result in metrics_results %}
<tr>
<th scope="row">{{ result.metric.replace("_", " ") }}</th>
<td>{{ result.accuracy }}</td>
<td>{{ result.error_rate }}</td>
<td>{{ result.reference_elements }}</td>
<td>{{ result.compared_elements }}</td>
<td>{{ result.weighted_errors }}</td>
<td>d={{ result.weights.deletes }},i={{ result.weights.inserts }},r={{ result.weights.replacements }}</td>
</tr>
{% endfor %}
</tbody>
<tfoot>
<tr>
<td colspan="7">Legend: Acc. = Accuracy, ER = Error Rate, Ref. = Reference, Cmp. = Compared</td>
</tr>
</tfoot>
</table>
{% endif %}
<h2>Character differences</h2>

@@ -0,0 +1,40 @@
import math
import unicodedata
from ...metrics import character_accuracy
def test_character_accuracy():
assert character_accuracy("a", "a").error_rate == 0
assert character_accuracy("a", "b").error_rate == 1 / 1
assert character_accuracy("Foo", "Bar").error_rate == 3 / 3
assert character_accuracy("Foo", "").error_rate == 3 / 3
assert character_accuracy("", "").error_rate == 0
assert math.isinf(character_accuracy("", "Foo").error_rate)
assert character_accuracy("Foo", "Food").error_rate == 1 / 3
assert character_accuracy("Fnord", "Food").error_rate == 2 / 5
assert character_accuracy("Müll", "Mull").error_rate == 1 / 4
assert character_accuracy("Abstand", "Sand").error_rate == 4 / 7
def test_character_accuracy_hard():
s1 = unicodedata.normalize("NFC", "Schlyñ lorem ipsum.")
s2 = unicodedata.normalize("NFD", "Schlyñ lorem ipsum!") # Different, decomposed!
assert character_accuracy(s1, s2).error_rate == 1 / 19
s1 = "Schlyñ"
assert (
len(s1) == 6
) # This ends with LATIN SMALL LETTER N WITH TILDE, so 6 code points
s2 = "Schlym̃"
assert (
len(s2) == 7
) # This, OTOH, ends with LATIN SMALL LETTER M + COMBINING TILDE, 7 code points
# Both strings have the same length in terms of grapheme clusters.
# So the CER should be symmetrical.
assert character_accuracy(s2, s1).error_rate == 1 / 6
assert character_accuracy(s1, s2).error_rate == 1 / 6

@@ -1,41 +0,0 @@
from __future__ import division, print_function
import math
import unicodedata
from ...metrics import character_error_rate
def test_character_error_rate():
assert character_error_rate("a", "a") == 0
assert character_error_rate("a", "b") == 1 / 1
assert character_error_rate("Foo", "Bar") == 3 / 3
assert character_error_rate("Foo", "") == 3 / 3
assert character_error_rate("", "") == 0
assert math.isinf(character_error_rate("", "Foo"))
assert character_error_rate("Foo", "Food") == 1 / 3
assert character_error_rate("Fnord", "Food") == 2 / 5
assert character_error_rate("Müll", "Mull") == 1 / 4
assert character_error_rate("Abstand", "Sand") == 4 / 7
def test_character_error_rate_hard():
s1 = unicodedata.normalize("NFC", "Schlyñ lorem ipsum.")
s2 = unicodedata.normalize("NFD", "Schlyñ lorem ipsum!") # Different, decomposed!
assert character_error_rate(s1, s2) == 1 / 19
s1 = "Schlyñ"
assert (
len(s1) == 6
) # This ends with LATIN SMALL LETTER N WITH TILDE, so 6 code points
s2 = "Schlym̃"
assert (
len(s2) == 7
) # This, OTOH, ends with LATIN SMALL LETTER M + COMBINING TILDE, 7 code points
# Both strings have the same length in terms of grapheme clusters. So the CER should be symmetrical.
assert character_error_rate(s2, s1) == 1 / 6
assert character_error_rate(s1, s2) == 1 / 6

@@ -7,7 +7,7 @@ from lxml import etree as ET
from uniseg.graphemecluster import grapheme_clusters
from ... import page_text, alto_text
from ...metrics import character_error_rate
from ...metrics import character_accuracy
data_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "../", "data")
@@ -22,7 +22,7 @@ def test_character_error_rate_between_page_files():
gt_len = len(list(grapheme_clusters(gt)))
expected_cer = 2 / gt_len
assert character_error_rate(gt, ocr) == expected_cer
assert character_accuracy(gt, ocr).error_rate == expected_cer
@pytest.mark.integration
@@ -39,7 +39,7 @@ def test_character_error_rate_between_page_alto():
)
assert gt == ocr
assert character_error_rate(gt, ocr) == 0
assert character_accuracy(gt, ocr).error_rate == 0
@pytest.mark.integration
@@ -57,4 +57,4 @@ def test_character_error_rate_between_page_alto_2():
)
)
assert character_error_rate(gt, ocr) == 8 / 591 # Manually verified
assert character_accuracy(gt, ocr).error_rate == 8 / 591 # Manually verified

@@ -6,7 +6,7 @@ import pytest
from lxml import etree as ET
from ... import alto_text, page_text
from ...metrics import word_error_rate
from ...metrics import word_accuracy
from ...normalize import words
data_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "../", "data")
@@ -24,7 +24,7 @@ def test_word_error_rate_between_page_files():
assert len(list(words(gt))) == gt_word_count
ocr = page_text(ET.parse(os.path.join(data_dir, "test-fake-ocr.page2018.xml")))
assert word_error_rate(gt, ocr) == 2 / gt_word_count
assert word_accuracy(gt, ocr).error_rate == 2 / gt_word_count
@pytest.mark.integration
@@ -41,7 +41,7 @@ def test_word_error_rate_between_page_alto():
)
assert gt == ocr
assert word_error_rate(gt, ocr) == 0
assert word_accuracy(gt, ocr).error_rate == 0
@pytest.mark.integration
@@ -66,5 +66,5 @@ def test_word_error_rate_between_page_alto_2():
)
assert (
word_error_rate(gt, ocr) == 7 / gt_word_count
word_accuracy(gt, ocr).error_rate == 7 / gt_word_count
) # Manually verified, 6 words are wrong, 1 got split (=2 errors)

@@ -1,8 +1,6 @@
from __future__ import division, print_function
import math
from ...metrics import word_error_rate
from ...metrics import word_accuracy
from ...normalize import words
@@ -55,33 +53,44 @@ def test_words_private_use_area():
def test_word_error_rate():
assert (
word_error_rate("Dies ist ein Beispielsatz!", "Dies ist ein Beispielsatz!") == 0
word_accuracy(
"Dies ist ein Beispielsatz!", "Dies ist ein Beispielsatz!"
).error_rate
== 0
)
assert (
word_error_rate("Dies. ist ein Beispielsatz!", "Dies ist ein Beispielsatz!")
word_accuracy(
"Dies. ist ein Beispielsatz!", "Dies ist ein Beispielsatz!"
).error_rate
== 0
)
assert (
word_error_rate("Dies. ist ein Beispielsatz!", "Dies ist ein Beispielsatz.")
word_accuracy(
"Dies. ist ein Beispielsatz!", "Dies ist ein Beispielsatz."
).error_rate
== 0
)
assert (
word_error_rate("Dies ist ein Beispielsatz!", "Dies ist ein Beispielsarz:")
word_accuracy(
"Dies ist ein Beispielsatz!", "Dies ist ein Beispielsarz:"
).error_rate
== 1 / 4
)
assert (
word_error_rate("Dies ist ein Beispielsatz!", "Dies ein ist Beispielsatz!")
word_accuracy(
"Dies ist ein Beispielsatz!", "Dies ein ist Beispielsatz!"
).error_rate
== 2 / 4
)
assert word_error_rate("Dies ist ein Beispielsatz!", "") == 4 / 4
assert math.isinf(word_error_rate("", "Dies ist ein Beispielsatz!"))
assert word_error_rate("", "") == 0
assert word_accuracy("Dies ist ein Beispielsatz!", "").error_rate == 4 / 4
assert math.isinf(word_accuracy("", "Dies ist ein Beispielsatz!").error_rate)
assert word_accuracy("", "").error_rate == 0
assert (
word_error_rate(
word_accuracy(
"Schlyñ lorem ipsum dolor sit amet,", "Schlym̃ lorem ipsum dolor sit amet."
)
).error_rate
== 1 / 6
)

@@ -1,14 +1,43 @@
import json
from itertools import combinations
import pytest
from .util import working_directory
from .util import working_directory
from ..cli import process
METRIC_DICT = {
"": "",
"ca": "character_accuracy",
"cer": "character_accuracy",
"wa": "word_accuracy",
"wer": "word_accuracy",
"boc": "bag_of_chars_accuracy",
"bow": "bag_of_words_accuracy",
}
@pytest.mark.integration
def test_cli_json(tmp_path):
"""Test that the cli/process() yields a loadable JSON report"""
@pytest.mark.parametrize(
"metrics",
[
*(("",), ("cer",), ("wer",), ("ca",), ("wa",), ("boc",), ("bow",)),
*combinations(("ca", "wa", "boc", "bow"), 2),
*combinations(("cer", "wer", "boc", "bow"), 2),
*combinations(("ca", "wa", "boc", "bow"), 3),
*combinations(("cer", "wer", "boc", "bow"), 3),
("ca", "wa", "boc", "bow"),
("cer", "wer", "boc", "bow"),
],
)
def test_cli_json(metrics, tmp_path):
"""Test that the cli/process() yields a loadable JSON report."""
expected_values = {
"character_accuracy": 0.2,
"word_accuracy": 1.0,
"bag_of_chars_accuracy": 0.2,
"bag_of_words_accuracy": 1.0,
}
with working_directory(str(tmp_path)):
with open("gt.txt", "w") as gtf:
@@ -18,25 +47,38 @@ def test_cli_json(tmp_path):
with open("gt.txt", "r") as gtf:
print(gtf.read())
process("gt.txt", "ocr.txt", "report")
process("gt.txt", "ocr.txt", "report", metrics=",".join(metrics))
with open("report.json", "r") as jsonf:
print(jsonf.read())
with open("report.json", "r") as jsonf:
j = json.load(jsonf)
assert j["cer"] == pytest.approx(0.2)
metrics_translated = {METRIC_DICT[metric] for metric in metrics}
for metric, expected_value in expected_values.items():
if metric in metrics_translated:
assert j[metric]["error_rate"] == pytest.approx(expected_value)
else:
assert metric not in j.keys()
@pytest.mark.integration
def test_cli_json_cer_is_infinity(tmp_path):
"""Test that the cli/process() yields a loadable JSON report when CER == inf"""
@pytest.mark.parametrize(
"gt,ocr,err",
[("", "Not important", float("inf")), ("Lorem Ipsum", "Lorem Ipsum", 0)],
)
def test_cli_json_extremes(gt, ocr, err, tmp_path):
"""Test that the cli/process() yields a loadable JSON reports."""
with working_directory(str(tmp_path)):
with open("gt.txt", "w") as gtf:
gtf.write("") # Empty to yield CER == inf
gtf.write(gt)
with open("ocr.txt", "w") as ocrf:
ocrf.write("Not important")
ocrf.write(ocr)
process("gt.txt", "ocr.txt", "report")
process("gt.txt", "ocr.txt", "report", metrics="ca,wa,boc,bow")
with open("report.json", "r") as jsonf:
j = json.load(jsonf)
assert j["cer"] == pytest.approx(float("inf"))
for metric in set(METRIC_DICT.values()):
if not metric:
continue
assert j[metric]["error_rate"] == pytest.approx(err)

@@ -7,5 +7,4 @@ colorama
MarkupSafe
ocrd >= 2.20.1
attrs
multimethod == 1.3 # latest version to officially support Python 3.5
tqdm
