Implemented new metrics behaviour

pull/60/head
Benjamin Rosemann 4 years ago
parent 9f5112f8f6
commit 06468a436e

@@ -34,20 +34,23 @@ Usage: dinglehopper [OPTIONS] GT OCR [REPORT_PREFIX]
   dinglehopper detects if GT/OCR are ALTO or PAGE XML documents to extract
   their text and falls back to plain text if no ALTO or PAGE is detected.
   The files GT and OCR are usually a ground truth document and the result of
-  an OCR software, but you may use dinglehopper to compare two OCR results.
-  In that case, use --no-metrics to disable the then meaningless metrics and
-  also change the color scheme from green/red to blue.
+  an OCR software, but you may use dinglehopper to compare two OCR results. In
+  that case, use --metrics='' to disable the then meaningless metrics and also
+  change the color scheme from green/red to blue.
   The comparison report will be written to $REPORT_PREFIX.{html,json}, where
-  $REPORT_PREFIX defaults to "report". The reports include the character
-  error rate (CER) and the word error rate (WER).
+  $REPORT_PREFIX defaults to "report". Depending on your configuration the
+  reports include the character error rate (CA|CER), the word error rate (WA|WER),
+  the bag of chars accuracy (BoC), and the bag of words accuracy (BoW).
+  The metrics can be chosen via a comma separated combination of their acronyms
+  like "--metrics=ca,wer,boc,bow".
   By default, the text of PAGE files is extracted on 'region' level. You may
   use "--textequiv-level line" to extract from the level of TextLine tags.
 Options:
-  --metrics / --no-metrics  Enable/disable metrics and green/red
+  --metrics                Enable different metrics like ca|cer, wa|wer, boc and bow.
   --textequiv-level LEVEL  PAGE TextEquiv level to extract text from
   --progress               Show progress bar
   --help                   Show this message and exit.
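As an illustration of the new option syntax (the file names below are placeholders, not part of this change), a direct dinglehopper call selecting all four metric families could look like this:

~~~
dinglehopper gt.page.xml ocr.page.xml report --metrics=cer,wer,boc,bow
~~~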
@@ -80,12 +83,12 @@ The OCR-D processor has these parameters:
 | Parameter                 | Meaning                                                              |
 | ------------------------- | -------------------------------------------------------------------- |
-| `-P metrics false`        | Disable metrics and the green-red color scheme (default: enabled)    |
+| `-P metrics cer,wer`      | Enable character error rate and word error rate (default)            |
 | `-P textequiv_level line` | (PAGE) Extract text from TextLine level (default: TextRegion level)  |
 For example:
 ~~~
-ocrd-dinglehopper -I ABBYY-FULLTEXT,OCR-D-OCR-CALAMARI -O OCR-D-OCR-COMPARE-ABBYY-CALAMARI -P metrics false
+ocrd-dinglehopper -I ABBYY-FULLTEXT,OCR-D-OCR-CALAMARI -O OCR-D-OCR-COMPARE-ABBYY-CALAMARI -P metrics cer,wer
 ~~~
 Developer information

@@ -10,9 +10,10 @@ from uniseg.graphemecluster import grapheme_clusters
 from .align import seq_align
 from .config import Config
 from .extracted_text import ExtractedText
-from .metrics.character_error_rate import character_error_rate_n
-from .metrics.word_error_rate import word_error_rate_n, words_normalized
-from .ocr_files import extract
+from .metrics import bag_of_chars_accuracy, bag_of_words_accuracy, character_accuracy, \
+    word_accuracy
+from .normalize import words_normalized
+from .ocr_files import text
 def gen_diff_report(gt_in, ocr_in, css_prefix, joiner, none):
@@ -85,9 +86,7 @@ def gen_diff_report(gt_in, ocr_in, css_prefix, joiner, none):
     )
-def generate_html_report(
-    gt, ocr, gt_text, ocr_text, report_prefix, metrics, cer, n_characters, wer, n_words
-):
+def generate_html_report(gt, ocr, gt_text, ocr_text, report_prefix, metrics_results):
     char_diff_report = gen_diff_report(
         gt_text, ocr_text, css_prefix="c", joiner="", none="·"
     )
@@ -112,57 +111,50 @@ def generate_html_report(
     template.stream(
         gt=gt,
         ocr=ocr,
-        cer=cer,
-        n_characters=n_characters,
-        wer=wer,
-        n_words=n_words,
         char_diff_report=char_diff_report,
         word_diff_report=word_diff_report,
-        metrics=metrics,
+        metrics_results=metrics_results,
     ).dump(out_fn)
-def generate_json_report(
-    gt, ocr, report_prefix, metrics, cer, n_characters, wer, n_words
-):
-    json_dict = {"gt": gt, "ocr": ocr, "n_characters": n_characters, "n_words": n_words}
-    if metrics:
-        json_dict = {**json_dict, "cer": cer, "wer": wer}
-    with open(f"{report_prefix}.json", 'w') as fp:
+def generate_json_report(gt, ocr, report_prefix, metrics_results):
+    json_dict = {"gt": gt, "ocr": ocr}
+    for result in metrics_results:
+        json_dict[result.metric] = {
+            key: value for key, value in result.get_dict().items() if key != "metric"
+        }
+    print(json_dict)
+    with open(f"{report_prefix}.json", "w") as fp:
         json.dump(json_dict, fp)
-def process(gt, ocr, report_prefix, *, metrics=True, textequiv_level="region"):
+def process(gt, ocr, report_prefix, *, metrics="cer,wer", textequiv_level="region"):
     """Check OCR result against GT.
     The @click decorators change the signature of the decorated functions,
     so we keep this undecorated version and use Click on a wrapper.
     """
-    gt_text = extract(gt, textequiv_level=textequiv_level)
-    ocr_text = extract(ocr, textequiv_level=textequiv_level)
-    cer, n_characters = character_error_rate_n(gt_text, ocr_text)
-    wer, n_words = word_error_rate_n(gt_text, ocr_text)
-    generate_json_report(
-        gt, ocr, report_prefix, metrics, cer, n_characters, wer, n_words
-    )
+    gt_text = text(gt, textequiv_level=textequiv_level)
+    ocr_text = text(ocr, textequiv_level=textequiv_level)
+    metrics_results = set()
+    if metrics:
+        metric_dict = {
+            "ca": character_accuracy,
+            "cer": character_accuracy,
+            "wa": word_accuracy,
+            "wer": word_accuracy,
+            "boc": bag_of_chars_accuracy,
+            "bow": bag_of_words_accuracy,
+        }
+        for metric in metrics.split(","):
+            metrics_results.add(metric_dict[metric.strip()](gt_text, ocr_text))
+    generate_json_report(gt, ocr, report_prefix, metrics_results)
     html_report = True
     if html_report:
-        generate_html_report(
-            gt,
-            ocr,
-            gt_text,
-            ocr_text,
-            report_prefix,
-            metrics,
-            cer,
-            n_characters,
-            wer,
-            n_words,
-        )
+        generate_html_report(gt, ocr, gt_text, ocr_text, report_prefix, metrics_results)
 @click.command()
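For reviewers, a minimal sketch of how the reworked process() is driven from Python and what lands in the JSON report; the import path and file names are assumptions, and the commented values are only indicative:

~~~
# Sketch only: package path and file names are assumptions, not part of this diff.
from qurator.dinglehopper.cli import process

process("gt.txt", "ocr.txt", "report", metrics="cer,wer,boc,bow")
# report.json then holds one object per selected metric, keyed by MetricResult.metric:
# {
#   "gt": "gt.txt",
#   "ocr": "ocr.txt",
#   "character_accuracy": {"accuracy": ..., "error_rate": ..., "weights": {...}, ...},
#   "word_accuracy": {...},
#   ...
# }
~~~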
@@ -170,7 +162,9 @@ def process(gt, ocr, report_prefix, *, metrics=True, textequiv_level="region"):
 @click.argument("ocr", type=click.Path(exists=True))
 @click.argument("report_prefix", type=click.Path(), default="report")
 @click.option(
-    "--metrics/--no-metrics", default=True, help="Enable/disable metrics and green/red"
+    "--metrics",
+    default="cer,wer",
+    help="Enable different metrics like cer, wer, boc and bow.",
 )
 @click.option(
     "--textequiv-level",
@@ -188,12 +182,15 @@ def main(gt, ocr, report_prefix, metrics, textequiv_level, progress):
     The files GT and OCR are usually a ground truth document and the result of
     an OCR software, but you may use dinglehopper to compare two OCR results. In
-    that case, use --no-metrics to disable the then meaningless metrics and also
+    that case, use --metrics='' to disable the then meaningless metrics and also
     change the color scheme from green/red to blue.
     The comparison report will be written to $REPORT_PREFIX.{html,json}, where
-    $REPORT_PREFIX defaults to "report". The reports include the character error
-    rate (CER) and the word error rate (WER).
+    $REPORT_PREFIX defaults to "report". Depending on your configuration the
+    reports include the character error rate (CA|CER), the word error rate (WA|WER),
+    the bag of chars accuracy (BoC), and the bag of words accuracy (BoW).
+    The metrics can be chosen via a comma separated combination of their acronyms
+    like "--metrics=ca,wer,boc,bow".
     By default, the text of PAGE files is extracted on 'region' level. You may
     use "--textequiv-level line" to extract from the level of TextLine tags.

@@ -4,11 +4,9 @@ from functools import lru_cache, partial
 from typing import Sequence, Tuple
 import numpy as np
-from multimethod import multimethod
 from tqdm import tqdm
 from .config import Config
-from .extracted_text import ExtractedText
 from .normalize import chars_normalized
@@ -74,7 +72,6 @@ def levenshtein_matrix_cache_clear():
     _levenshtein_matrix.cache_clear()
-@multimethod
 def distance(s1: str, s2: str):
     """Compute the Levenshtein edit distance between two Unicode strings
@@ -86,11 +83,6 @@ def distance(s1: str, s2: str):
     return levenshtein(seq1, seq2)
-@multimethod
-def distance(s1: ExtractedText, s2: ExtractedText):
-    return distance(s1.text, s2.text)
 def seq_editops(seq1, seq2):
     """
     Return sequence of edit operations transforming one sequence to another.
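With the ExtractedText overload removed, distance() only takes plain strings; callers holding ExtractedText objects unwrap them themselves. A small sketch (the top-level import path is an assumption):

~~~
from qurator.dinglehopper import distance  # assumed import path

assert distance("Fnord", "Food") == 2  # delete "n", substitute "r" -> "o"
# ExtractedText objects are now unwrapped by the caller: distance(gt.text, ocr.text)
~~~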

@@ -1,5 +1,5 @@
 from .bag_of_chars_accuracy import *
 from .bag_of_words_accuracy import *
-from .character_error_rate import *
-from .utils import Weights
-from .word_error_rate import *
+from .character_accuracy import *
+from .utils import MetricResult, Weights
+from .word_accuracy import *
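After the renames, the public surface of the metrics package reads as follows (a sketch; the qurator namespace prefix is an assumption about the installed package layout):

~~~
from qurator.dinglehopper.metrics import (  # assumed import path
    MetricResult,
    Weights,
    bag_of_chars_accuracy,
    bag_of_words_accuracy,
    character_accuracy,
    word_accuracy,
)
~~~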

@@ -7,7 +7,7 @@ from .utils import bag_accuracy, MetricResult, Weights
 def bag_of_chars_accuracy(
-    reference: str, compared: str, weights: Weights
+    reference: str, compared: str, weights: Weights = Weights(1, 0, 1)
 ) -> MetricResult:
     reference_chars = Counter(grapheme_clusters(normalize("NFC", reference)))
     compared_chars = Counter(grapheme_clusters(normalize("NFC", compared)))

@@ -5,7 +5,7 @@ from ..normalize import words_normalized
 def bag_of_words_accuracy(
-    reference: str, compared: str, weights: Weights
+    reference: str, compared: str, weights: Weights = Weights(1, 0, 1)
 ) -> MetricResult:
     reference_words = Counter(words_normalized(reference))
     compared_words = Counter(words_normalized(compared))
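A quick sanity check of the bag semantics (not part of the diff): the Counter-based comparison ignores order, so a permuted but otherwise identical text produces no weighted errors:

~~~
from qurator.dinglehopper.metrics import bag_of_words_accuracy  # assumed import path

assert bag_of_words_accuracy("ein zwei drei", "drei zwei ein").error_rate == 0
~~~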

@@ -0,0 +1,28 @@
from __future__ import division
from .utils import MetricResult, Weights
from .. import distance
from ..normalize import chars_normalized
def character_accuracy(
reference: str, compared: str, weights: Weights = Weights(1, 1, 1)
) -> MetricResult:
"""Compute character accuracy and error rate.
:return: NamedTuple representing the results of this metric.
"""
weighted_errors = distance(reference, compared)
n_ref = len(chars_normalized(reference))
n_cmp = len(chars_normalized(compared))
return MetricResult(
metric=character_accuracy.__name__,
weights=weights,
weighted_errors=int(weighted_errors),
reference_elements=n_ref,
compared_elements=n_cmp,
)
# XXX Should we really count newlines here?
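Usage sketch for the new result object, with values mirrored from the accompanying unit tests:

~~~
from qurator.dinglehopper.metrics import character_accuracy  # assumed import path

result = character_accuracy("Fnord", "Food")
assert result.metric == "character_accuracy"
assert result.error_rate == 2 / 5       # 2 weighted errors over 5 reference characters
assert result.reference_elements == 5
~~~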

@@ -1,46 +0,0 @@
from __future__ import division
from typing import Tuple
from multimethod import multimethod
from .. import distance
from ..extracted_text import ExtractedText
from ..normalize import chars_normalized
@multimethod
def character_error_rate_n(reference: str, compared: str) -> Tuple[float, int]:
"""
Compute character error rate.
:return: character error rate and length of the reference
"""
d = distance(reference, compared)
n = len(chars_normalized(reference))
if d == 0:
return 0, n
if n == 0:
return float("inf"), n
return d / n, n
# XXX Should we really count newlines here?
@multimethod
def character_error_rate_n(
reference: ExtractedText, compared: ExtractedText
) -> Tuple[float, int]:
return character_error_rate_n(reference.text, compared.text)
def character_error_rate(reference, compared) -> float:
"""
Compute character error rate.
:return: character error rate
"""
cer, _ = character_error_rate_n(reference, compared)
return cer

@@ -1,5 +1,5 @@
 from collections import Counter
-from typing import NamedTuple
+from typing import Dict, NamedTuple
 class Weights(NamedTuple):
@@ -25,10 +25,27 @@ class MetricResult(NamedTuple):
     @property
     def error_rate(self) -> float:
-        if self.reference_elements <= 0:
+        if self.reference_elements <= 0 and self.compared_elements <= 0:
+            return 0
+        elif self.reference_elements <= 0:
             return float("inf")
         return self.weighted_errors / self.reference_elements
+    def get_dict(self) -> Dict:
+        """Combines the properties to a dictionary.
+        We deviate from the builtin _asdict() function by including our properties.
+        """
+        return {
+            **{
+                key: value
+                for key, value in self._asdict().items()
+            },
+            "accuracy": self.accuracy,
+            "error_rate": self.error_rate,
+            "weights": self.weights._asdict(),
+        }
 def bag_accuracy(
     reference: Counter, compared: Counter, weights: Weights
@@ -44,7 +61,7 @@ def bag_accuracy(
     :param reference: Bag used as reference (ground truth).
     :param compared: Bag used to compare (ocr).
     :param weights: Weights/costs for editing operations.
-    :return: Tuple representing the results of this metric.
+    :return: NamedTuple representing the results of this metric.
     """
     n_ref = sum(reference.values())
     n_cmp = sum(compared.values())
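To make the get_dict() contract concrete, a sketch with made-up numbers (field names follow the code above; the value of the accuracy property is not asserted here):

~~~
from qurator.dinglehopper.metrics import MetricResult, Weights  # assumed import path

result = MetricResult(
    metric="character_accuracy",
    weights=Weights(deletes=1, inserts=1, replacements=1),
    weighted_errors=2,
    reference_elements=5,
    compared_elements=4,
)
d = result.get_dict()
assert d["error_rate"] == 2 / 5
assert d["weights"] == {"deletes": 1, "inserts": 1, "replacements": 1}
assert "accuracy" in d and "metric" in d
~~~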

@@ -0,0 +1,22 @@
from .utils import MetricResult, Weights
from ..edit_distance import levenshtein
from ..normalize import words_normalized
def word_accuracy(
reference: str, compared: str, weights: Weights = Weights(1, 1, 1)
) -> MetricResult:
reference_seq = list(words_normalized(reference))
compared_seq = list(words_normalized(compared))
weighted_errors = levenshtein(reference_seq, compared_seq)
n_ref = len(reference_seq)
n_cmp = len(compared_seq)
return MetricResult(
metric=word_accuracy.__name__,
weights=weights,
weighted_errors=int(weighted_errors),
reference_elements=n_ref,
compared_elements=n_cmp,
)
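And the corresponding word-level usage, again with values mirrored from the tests:

~~~
from qurator.dinglehopper.metrics import word_accuracy  # assumed import path

result = word_accuracy("Dies ist ein Beispielsatz!", "Dies ein ist Beispielsatz!")
assert result.error_rate == 2 / 4   # two swapped words count as two errors
assert result.reference_elements == 4
~~~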

@@ -1,43 +0,0 @@
from __future__ import division
from typing import Iterable, Tuple
from multimethod import multimethod
from ..edit_distance import levenshtein
from ..extracted_text import ExtractedText
from ..normalize import words_normalized
@multimethod
def word_error_rate_n(reference: str, compared: str) -> Tuple[float, int]:
reference_seq = list(words_normalized(reference))
compared_seq = list(words_normalized(compared))
return word_error_rate_n(reference_seq, compared_seq)
@multimethod
def word_error_rate_n(
reference: ExtractedText, compared: ExtractedText
) -> Tuple[float, int]:
return word_error_rate_n(reference.text, compared.text)
@multimethod
def word_error_rate_n(reference: Iterable, compared: Iterable) -> Tuple[float, int]:
reference_seq = list(reference)
compared_seq = list(compared)
d = levenshtein(reference_seq, compared_seq)
n = len(reference_seq)
if d == 0:
return 0, n
if n == 0:
return float("inf"), n
return d / n, n
def word_error_rate(reference, compared) -> float:
wer, _ = word_error_rate_n(reference, compared)
return wer

@@ -163,8 +163,8 @@ def extract(filename, *, textequiv_level="region"):
     return alto_extract(tree)
-def text(filename):
-    return extract(filename).text
+def text(filename, *args, **kwargs):
+    return extract(filename, *args, **kwargs).text
 if __name__ == "__main__":

@@ -19,9 +19,10 @@
     ],
     "parameters": {
       "metrics": {
-        "type": "boolean",
-        "default": true,
-        "description": "Enable/disable metrics and green/red"
+        "type": "string",
+        "enum": ["", "boc", "boc,bow", "bow", "ca", "ca,boc", "ca,boc,bow", "ca,bow", "ca,wa", "ca,wa,boc", "ca,wa,boc,bow", "ca,wa,bow", "ca,wer", "ca,wer,boc", "ca,wer,boc,bow", "ca,wer,bow", "cer", "cer,boc", "cer,boc,bow", "cer,bow", "cer,wa", "cer,wa,boc", "cer,wa,boc,bow", "cer,wa,bow", "cer,wer", "cer,wer,boc", "cer,wer,boc,bow", "cer,wer,bow", "wa", "wa,boc", "wa,boc,bow", "wa,bow", "wer", "wer,boc", "wer,boc,bow", "wer,bow"],
+        "default": "cer,wer",
+        "description": "Enable different metrics like ca|cer, wa|wer, boc and bow."
       },
       "textequiv_level": {
         "type": "string",

@@ -6,7 +6,7 @@
     <link rel="stylesheet" href="https://stackpath.bootstrapcdn.com/bootstrap/4.3.1/css/bootstrap.min.css" integrity="sha384-ggOyR0iXCbMQv3Xipma34MD+dH/1fQ784/j6cY/iJTQUOhcWr7x9JvoRxT2MZw1T" crossorigin="anonymous">
     <style type="text/css">
-        {% if metrics %}
+        {% if metrics_results %}
         .gt .diff {
             color: green;
         }
@@ -38,10 +38,44 @@
     {{ ocr }}
-    {% if metrics %}
+    {% if metrics_results %}
     <h2>Metrics</h2>
-    <p>CER: {{ cer|round(4) }}</p>
-    <p>WER: {{ wer|round(4) }}</p>
+    <table class="table">
+      <thead>
+        <tr>
+          <th scope="col" rowspan="2">#</th>
+          <th scope="col" colspan="2">Results</th>
+          <th scope="col" colspan="2">Elements</th>
+          <th scope="col" colspan="2">Calculation</th>
+        </tr>
+        <tr>
+          <th scope="col">Acc</th>
+          <th scope="col">ER</th>
+          <th scope="col">Ref</th>
+          <th scope="col">Cmp</th>
+          <th scope="col">Errors</th>
+          <th scope="col">Weights</th>
+        </tr>
+      </thead>
+      <tbody>
+        {% for result in metrics_results %}
+        <tr>
+          <th scope="row">{{ result.metric.replace("_", " ") }}</th>
+          <td>{{ result.accuracy }}</td>
+          <td>{{ result.error_rate }}</td>
+          <td>{{ result.reference_elements }}</td>
+          <td>{{ result.compared_elements }}</td>
+          <td>{{ result.weighted_errors }}</td>
+          <td>d={{ result.weights.deletes }},i={{ result.weights.inserts }},r={{ result.weights.replacements }}</td>
+        </tr>
+        {% endfor %}
+      </tbody>
+      <tfoot>
+        <tr>
+          <td colspan="7">Legend: Acc. = Accuracy, ER = Error Rate, Ref. = Reference, Cmp. = Compared</td>
+        </tr>
+      </tfoot>
+    </table>
     {% endif %}
     <h2>Character differences</h2>

@@ -0,0 +1,40 @@
import math
import unicodedata
from ...metrics import character_accuracy
def test_character_accuracy():
assert character_accuracy("a", "a").error_rate == 0
assert character_accuracy("a", "b").error_rate == 1 / 1
assert character_accuracy("Foo", "Bar").error_rate == 3 / 3
assert character_accuracy("Foo", "").error_rate == 3 / 3
assert character_accuracy("", "").error_rate == 0
assert math.isinf(character_accuracy("", "Foo").error_rate)
assert character_accuracy("Foo", "Food").error_rate == 1 / 3
assert character_accuracy("Fnord", "Food").error_rate == 2 / 5
assert character_accuracy("Müll", "Mull").error_rate == 1 / 4
assert character_accuracy("Abstand", "Sand").error_rate == 4 / 7
def test_character_accuracy_hard():
s1 = unicodedata.normalize("NFC", "Schlyñ lorem ipsum.")
s2 = unicodedata.normalize("NFD", "Schlyñ lorem ipsum!") # Different, decomposed!
assert character_accuracy(s1, s2).error_rate == 1 / 19
s1 = "Schlyñ"
assert (
len(s1) == 6
) # This ends with LATIN SMALL LETTER N WITH TILDE, so 6 code points
s2 = "Schlym̃"
assert (
len(s2) == 7
) # This, OTOH, ends with LATIN SMALL LETTER M + COMBINING TILDE, 7 code points
# Both strings have the same length in terms of grapheme clusters.
# So the CER should be symmetrical.
assert character_accuracy(s2, s1).error_rate == 1 / 6
assert character_accuracy(s1, s2).error_rate == 1 / 6

@@ -1,41 +0,0 @@
from __future__ import division, print_function
import math
import unicodedata
from ...metrics import character_error_rate
def test_character_error_rate():
assert character_error_rate("a", "a") == 0
assert character_error_rate("a", "b") == 1 / 1
assert character_error_rate("Foo", "Bar") == 3 / 3
assert character_error_rate("Foo", "") == 3 / 3
assert character_error_rate("", "") == 0
assert math.isinf(character_error_rate("", "Foo"))
assert character_error_rate("Foo", "Food") == 1 / 3
assert character_error_rate("Fnord", "Food") == 2 / 5
assert character_error_rate("Müll", "Mull") == 1 / 4
assert character_error_rate("Abstand", "Sand") == 4 / 7
def test_character_error_rate_hard():
s1 = unicodedata.normalize("NFC", "Schlyñ lorem ipsum.")
s2 = unicodedata.normalize("NFD", "Schlyñ lorem ipsum!") # Different, decomposed!
assert character_error_rate(s1, s2) == 1 / 19
s1 = "Schlyñ"
assert (
len(s1) == 6
) # This ends with LATIN SMALL LETTER N WITH TILDE, so 6 code points
s2 = "Schlym̃"
assert (
len(s2) == 7
) # This, OTOH, ends with LATIN SMALL LETTER M + COMBINING TILDE, 7 code points
# Both strings have the same length in terms of grapheme clusters. So the CER should be symmetrical.
assert character_error_rate(s2, s1) == 1 / 6
assert character_error_rate(s1, s2) == 1 / 6

@@ -7,7 +7,7 @@ from lxml import etree as ET
 from uniseg.graphemecluster import grapheme_clusters
 from ... import page_text, alto_text
-from ...metrics import character_error_rate
+from ...metrics import character_accuracy
 data_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "../", "data")
@@ -22,7 +22,7 @@ def test_character_error_rate_between_page_files():
     gt_len = len(list(grapheme_clusters(gt)))
     expected_cer = 2 / gt_len
-    assert character_error_rate(gt, ocr) == expected_cer
+    assert character_accuracy(gt, ocr).error_rate == expected_cer
 @pytest.mark.integration
@@ -39,7 +39,7 @@ def test_character_error_rate_between_page_alto():
     )
     assert gt == ocr
-    assert character_error_rate(gt, ocr) == 0
+    assert character_accuracy(gt, ocr).error_rate == 0
 @pytest.mark.integration
@@ -57,4 +57,4 @@ def test_character_error_rate_between_page_alto_2():
         )
     )
-    assert character_error_rate(gt, ocr) == 8 / 591  # Manually verified
+    assert character_accuracy(gt, ocr).error_rate == 8 / 591  # Manually verified

@@ -6,7 +6,7 @@ import pytest
 from lxml import etree as ET
 from ... import alto_text, page_text
-from ...metrics import word_error_rate
+from ...metrics import word_accuracy
 from ...normalize import words
 data_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "../", "data")
@@ -24,7 +24,7 @@ def test_word_error_rate_between_page_files():
     assert len(list(words(gt))) == gt_word_count
     ocr = page_text(ET.parse(os.path.join(data_dir, "test-fake-ocr.page2018.xml")))
-    assert word_error_rate(gt, ocr) == 2 / gt_word_count
+    assert word_accuracy(gt, ocr).error_rate == 2 / gt_word_count
 @pytest.mark.integration
@@ -41,7 +41,7 @@ def test_word_error_rate_between_page_alto():
     )
     assert gt == ocr
-    assert word_error_rate(gt, ocr) == 0
+    assert word_accuracy(gt, ocr).error_rate == 0
 @pytest.mark.integration
@@ -66,5 +66,5 @@ def test_word_error_rate_between_page_alto_2():
     )
     assert (
-        word_error_rate(gt, ocr) == 7 / gt_word_count
+        word_accuracy(gt, ocr).error_rate == 7 / gt_word_count
     )  # Manually verified, 6 words are wrong, 1 got split (=2 errors)

@@ -1,8 +1,6 @@
-from __future__ import division, print_function
 import math
-from ...metrics import word_error_rate
+from ...metrics import word_accuracy
 from ...normalize import words
@@ -55,33 +53,44 @@ def test_words_private_use_area():
 def test_word_error_rate():
     assert (
-        word_error_rate("Dies ist ein Beispielsatz!", "Dies ist ein Beispielsatz!") == 0
+        word_accuracy(
+            "Dies ist ein Beispielsatz!", "Dies ist ein Beispielsatz!"
+        ).error_rate
+        == 0
     )
     assert (
-        word_error_rate("Dies. ist ein Beispielsatz!", "Dies ist ein Beispielsatz!")
+        word_accuracy(
+            "Dies. ist ein Beispielsatz!", "Dies ist ein Beispielsatz!"
+        ).error_rate
         == 0
     )
     assert (
-        word_error_rate("Dies. ist ein Beispielsatz!", "Dies ist ein Beispielsatz.")
+        word_accuracy(
+            "Dies. ist ein Beispielsatz!", "Dies ist ein Beispielsatz."
+        ).error_rate
         == 0
     )
     assert (
-        word_error_rate("Dies ist ein Beispielsatz!", "Dies ist ein Beispielsarz:")
+        word_accuracy(
+            "Dies ist ein Beispielsatz!", "Dies ist ein Beispielsarz:"
+        ).error_rate
        == 1 / 4
     )
     assert (
-        word_error_rate("Dies ist ein Beispielsatz!", "Dies ein ist Beispielsatz!")
+        word_accuracy(
+            "Dies ist ein Beispielsatz!", "Dies ein ist Beispielsatz!"
+        ).error_rate
        == 2 / 4
     )
-    assert word_error_rate("Dies ist ein Beispielsatz!", "") == 4 / 4
-    assert math.isinf(word_error_rate("", "Dies ist ein Beispielsatz!"))
-    assert word_error_rate("", "") == 0
+    assert word_accuracy("Dies ist ein Beispielsatz!", "").error_rate == 4 / 4
+    assert math.isinf(word_accuracy("", "Dies ist ein Beispielsatz!").error_rate)
+    assert word_accuracy("", "").error_rate == 0
     assert (
-        word_error_rate(
+        word_accuracy(
             "Schlyñ lorem ipsum dolor sit amet,", "Schlym̃ lorem ipsum dolor sit amet."
-        )
+        ).error_rate
         == 1 / 6
     )

@@ -1,14 +1,43 @@
 import json
+from itertools import combinations
 import pytest
 from .util import working_directory
 from ..cli import process
+METRIC_DICT = {
+    "": "",
+    "ca": "character_accuracy",
+    "cer": "character_accuracy",
+    "wa": "word_accuracy",
+    "wer": "word_accuracy",
+    "boc": "bag_of_chars_accuracy",
+    "bow": "bag_of_words_accuracy",
+}
 @pytest.mark.integration
-def test_cli_json(tmp_path):
-    """Test that the cli/process() yields a loadable JSON report"""
+@pytest.mark.parametrize(
+    "metrics",
+    [
+        *(("",), ("cer",), ("wer",), ("ca",), ("wa",), ("boc",), ("bow",)),
+        *combinations(("ca", "wa", "boc", "bow"), 2),
+        *combinations(("cer", "wer", "boc", "bow"), 2),
+        *combinations(("ca", "wa", "boc", "bow"), 3),
+        *combinations(("cer", "wer", "boc", "bow"), 3),
+        ("ca", "wa", "boc", "bow"),
+        ("cer", "wer", "boc", "bow"),
+    ],
+)
+def test_cli_json(metrics, tmp_path):
+    """Test that the cli/process() yields a loadable JSON report."""
+    expected_values = {
+        "character_accuracy": 0.2,
+        "word_accuracy": 1.0,
+        "bag_of_chars_accuracy": 0.2,
+        "bag_of_words_accuracy": 1.0,
+    }
     with working_directory(str(tmp_path)):
         with open("gt.txt", "w") as gtf:
@@ -18,25 +47,38 @@ def test_cli_json(tmp_path):
         with open("gt.txt", "r") as gtf:
             print(gtf.read())
-        process("gt.txt", "ocr.txt", "report")
+        process("gt.txt", "ocr.txt", "report", metrics=",".join(metrics))
         with open("report.json", "r") as jsonf:
             print(jsonf.read())
         with open("report.json", "r") as jsonf:
             j = json.load(jsonf)
-            assert j["cer"] == pytest.approx(0.2)
+            metrics_translated = {METRIC_DICT[metric] for metric in metrics}
+            for metric, expected_value in expected_values.items():
+                if metric in metrics_translated:
+                    assert j[metric]["error_rate"] == pytest.approx(expected_value)
+                else:
+                    assert metric not in j.keys()
 @pytest.mark.integration
-def test_cli_json_cer_is_infinity(tmp_path):
-    """Test that the cli/process() yields a loadable JSON report when CER == inf"""
+@pytest.mark.parametrize(
+    "gt,ocr,err",
+    [("", "Not important", float("inf")), ("Lorem Ipsum", "Lorem Ipsum", 0)],
+)
+def test_cli_json_extremes(gt, ocr, err, tmp_path):
+    """Test that the cli/process() yields a loadable JSON reports."""
     with working_directory(str(tmp_path)):
         with open("gt.txt", "w") as gtf:
-            gtf.write("")  # Empty to yield CER == inf
+            gtf.write(gt)
         with open("ocr.txt", "w") as ocrf:
-            ocrf.write("Not important")
-        process("gt.txt", "ocr.txt", "report")
+            ocrf.write(ocr)
+        process("gt.txt", "ocr.txt", "report", metrics="ca,wa,boc,bow")
         with open("report.json", "r") as jsonf:
             j = json.load(jsonf)
-            assert j["cer"] == pytest.approx(float("inf"))
+            for metric in set(METRIC_DICT.values()):
+                if not metric:
+                    continue
+                assert j[metric]["error_rate"] == pytest.approx(err)

@@ -7,5 +7,4 @@ colorama
 MarkupSafe
 ocrd >= 2.20.1
 attrs
-multimethod == 1.3  # latest version to officially support Python 3.5
 tqdm
