Mirror of https://github.com/qurator-spk/dinglehopper.git

Commit 06468a436e (parent 9f5112f8f6): Implemented new metrics behaviour

21 changed files with 296 additions and 242 deletions
README.md
@@ -35,19 +35,22 @@ Usage: dinglehopper [OPTIONS] GT OCR [REPORT_PREFIX]

   their text and falls back to plain text if no ALTO or PAGE is detected.

   The files GT and OCR are usually a ground truth document and the result of
-  an OCR software, but you may use dinglehopper to compare two OCR results.
-  In that case, use --no-metrics to disable the then meaningless metrics and
-  also change the color scheme from green/red to blue.
+  an OCR software, but you may use dinglehopper to compare two OCR results. In
+  that case, use --metrics='' to disable the then meaningless metrics and also
+  change the color scheme from green/red to blue.

   The comparison report will be written to $REPORT_PREFIX.{html,json}, where
-  $REPORT_PREFIX defaults to "report". The reports include the character
-  error rate (CER) and the word error rate (WER).
+  $REPORT_PREFIX defaults to "report". Depending on your configuration the
+  reports include the character error rate (CA|CER), the word error rate (WA|WER),
+  the bag of chars accuracy (BoC), and the bag of words accuracy (BoW).
+  The metrics can be chosen via a comma separated combination of their acronyms
+  like "--metrics=ca,wer,boc,bow".

   By default, the text of PAGE files is extracted on 'region' level. You may
   use "--textequiv-level line" to extract from the level of TextLine tags.

 Options:
-  --metrics / --no-metrics  Enable/disable metrics and green/red
+  --metrics                 Enable different metrics like ca|cer, wa|wer, boc and bow.
   --textequiv-level LEVEL   PAGE TextEquiv level to extract text from
   --progress                Show progress bar
   --help                    Show this message and exit.
@@ -80,12 +83,12 @@ The OCR-D processor has these parameters:

 | Parameter                 | Meaning                                                              |
 | ------------------------- | ------------------------------------------------------------------- |
-| `-P metrics false`        | Disable metrics and the green-red color scheme (default: enabled)    |
+| `-P metrics cer,wer`      | Enable character error rate and word error rate (default)           |
 | `-P textequiv_level line` | (PAGE) Extract text from TextLine level (default: TextRegion level) |

 For example:
 ~~~
-ocrd-dinglehopper -I ABBYY-FULLTEXT,OCR-D-OCR-CALAMARI -O OCR-D-OCR-COMPARE-ABBYY-CALAMARI -P metrics false
+ocrd-dinglehopper -I ABBYY-FULLTEXT,OCR-D-OCR-CALAMARI -O OCR-D-OCR-COMPARE-ABBYY-CALAMARI -P metrics cer,wer
 ~~~

 Developer information
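To make the new selection mechanism concrete, here is a minimal sketch of driving it from Python rather than the shell. The input file names and the chosen metric combination are example assumptions; only the `process()` signature and the per-metric JSON keys come from this commit.

~~~
import json

from qurator.dinglehopper.cli import process

# Any comma-separated combination of ca|cer, wa|wer, boc and bow works;
# an empty string disables metrics entirely.
process("gt.txt", "ocr.txt", "report", metrics="ca,boc")

with open("report.json") as fp:
    report = json.load(fp)

# Each selected metric is stored under its full name, e.g. "character_accuracy".
print(report["character_accuracy"]["error_rate"])
~~~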
@@ -10,9 +10,10 @@ from uniseg.graphemecluster import grapheme_clusters
 from .align import seq_align
 from .config import Config
 from .extracted_text import ExtractedText
-from .metrics.character_error_rate import character_error_rate_n
-from .metrics.word_error_rate import word_error_rate_n, words_normalized
-from .ocr_files import extract
+from .metrics import bag_of_chars_accuracy, bag_of_words_accuracy, character_accuracy, \
+    word_accuracy
+from .normalize import words_normalized
+from .ocr_files import text


 def gen_diff_report(gt_in, ocr_in, css_prefix, joiner, none):
@@ -85,9 +86,7 @@ def gen_diff_report(gt_in, ocr_in, css_prefix, joiner, none):
     )


-def generate_html_report(
-    gt, ocr, gt_text, ocr_text, report_prefix, metrics, cer, n_characters, wer, n_words
-):
+def generate_html_report(gt, ocr, gt_text, ocr_text, report_prefix, metrics_results):
     char_diff_report = gen_diff_report(
         gt_text, ocr_text, css_prefix="c", joiner="", none="·"
     )
@@ -112,57 +111,50 @@ def generate_html_report(
     template.stream(
         gt=gt,
         ocr=ocr,
-        cer=cer,
-        n_characters=n_characters,
-        wer=wer,
-        n_words=n_words,
         char_diff_report=char_diff_report,
         word_diff_report=word_diff_report,
-        metrics=metrics,
+        metrics_results=metrics_results,
     ).dump(out_fn)


-def generate_json_report(
-    gt, ocr, report_prefix, metrics, cer, n_characters, wer, n_words
-):
-    json_dict = {"gt": gt, "ocr": ocr, "n_characters": n_characters, "n_words": n_words}
-    if metrics:
-        json_dict = {**json_dict, "cer": cer, "wer": wer}
-    with open(f"{report_prefix}.json", 'w') as fp:
+def generate_json_report(gt, ocr, report_prefix, metrics_results):
+    json_dict = {"gt": gt, "ocr": ocr}
+    for result in metrics_results:
+        json_dict[result.metric] = {
+            key: value for key, value in result.get_dict().items() if key != "metric"
+        }
+    print(json_dict)
+    with open(f"{report_prefix}.json", "w") as fp:
         json.dump(json_dict, fp)


-def process(gt, ocr, report_prefix, *, metrics=True, textequiv_level="region"):
+def process(gt, ocr, report_prefix, *, metrics="cer,wer", textequiv_level="region"):
     """Check OCR result against GT.

     The @click decorators change the signature of the decorated functions,
     so we keep this undecorated version and use Click on a wrapper.
     """

-    gt_text = extract(gt, textequiv_level=textequiv_level)
-    ocr_text = extract(ocr, textequiv_level=textequiv_level)
+    gt_text = text(gt, textequiv_level=textequiv_level)
+    ocr_text = text(ocr, textequiv_level=textequiv_level)

-    cer, n_characters = character_error_rate_n(gt_text, ocr_text)
-    wer, n_words = word_error_rate_n(gt_text, ocr_text)

-    generate_json_report(
-        gt, ocr, report_prefix, metrics, cer, n_characters, wer, n_words
-    )
+    metrics_results = set()
+    if metrics:
+        metric_dict = {
+            "ca": character_accuracy,
+            "cer": character_accuracy,
+            "wa": word_accuracy,
+            "wer": word_accuracy,
+            "boc": bag_of_chars_accuracy,
+            "bow": bag_of_words_accuracy,
+        }
+        for metric in metrics.split(","):
+            metrics_results.add(metric_dict[metric.strip()](gt_text, ocr_text))
+    generate_json_report(gt, ocr, report_prefix, metrics_results)

     html_report = True
     if html_report:
-        generate_html_report(
-            gt,
-            ocr,
-            gt_text,
-            ocr_text,
-            report_prefix,
-            metrics,
-            cer,
-            n_characters,
-            wer,
-            n_words,
-        )
+        generate_html_report(gt, ocr, gt_text, ocr_text, report_prefix, metrics_results)


 @click.command()
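For orientation, the reworked generate_json_report() above implies a report.json of roughly the following shape. This is a hedged illustration; the numbers and file names are invented example values, only the key names follow from MetricResult.get_dict() and the weight fields used in the report template further down.

~~~
# Illustrative only: shape of report.json after this change (example values).
example_report = {
    "gt": "gt.txt",
    "ocr": "ocr.txt",
    "character_accuracy": {
        "weights": {"deletes": 1, "inserts": 1, "replacements": 1},
        "weighted_errors": 2,
        "reference_elements": 10,
        "compared_elements": 10,
        "accuracy": 0.8,
        "error_rate": 0.2,
    },
    # ... one block like this per selected metric, keyed by MetricResult.metric
}
~~~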
@@ -170,7 +162,9 @@ def process(gt, ocr, report_prefix, *, metrics=True, textequiv_level="region"):
 @click.argument("ocr", type=click.Path(exists=True))
 @click.argument("report_prefix", type=click.Path(), default="report")
 @click.option(
-    "--metrics/--no-metrics", default=True, help="Enable/disable metrics and green/red"
+    "--metrics",
+    default="cer,wer",
+    help="Enable different metrics like cer, wer, boc and bow.",
 )
 @click.option(
     "--textequiv-level",
@@ -188,12 +182,15 @@ def main(gt, ocr, report_prefix, metrics, textequiv_level, progress):

     The files GT and OCR are usually a ground truth document and the result of
     an OCR software, but you may use dinglehopper to compare two OCR results. In
-    that case, use --no-metrics to disable the then meaningless metrics and also
+    that case, use --metrics='' to disable the then meaningless metrics and also
     change the color scheme from green/red to blue.

     The comparison report will be written to $REPORT_PREFIX.{html,json}, where
-    $REPORT_PREFIX defaults to "report". The reports include the character error
-    rate (CER) and the word error rate (WER).
+    $REPORT_PREFIX defaults to "report". Depending on your configuration the
+    reports include the character error rate (CA|CER), the word error rate (WA|WER),
+    the bag of chars accuracy (BoC), and the bag of words accuracy (BoW).
+    The metrics can be chosen via a comma separated combination of their acronyms
+    like "--metrics=ca,wer,boc,bow".

     By default, the text of PAGE files is extracted on 'region' level. You may
     use "--textequiv-level line" to extract from the level of TextLine tags.
@@ -4,11 +4,9 @@ from functools import lru_cache, partial
 from typing import Sequence, Tuple

-import numpy as np
 from multimethod import multimethod
-from tqdm import tqdm

 from .config import Config
 from .extracted_text import ExtractedText
 from .normalize import chars_normalized


@@ -74,7 +72,6 @@ def levenshtein_matrix_cache_clear():
     _levenshtein_matrix.cache_clear()


-@multimethod
 def distance(s1: str, s2: str):
     """Compute the Levenshtein edit distance between two Unicode strings


@@ -86,11 +83,6 @@ def distance(s1: str, s2: str):
     return levenshtein(seq1, seq2)


-@multimethod
-def distance(s1: ExtractedText, s2: ExtractedText):
-    return distance(s1.text, s2.text)


 def seq_editops(seq1, seq2):
     """
     Return sequence of edit operations transforming one sequence to another.
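These hunks appear to touch the edit-distance module: the ExtractedText overload of distance() is dropped together with the numpy and tqdm imports, leaving the plain-string version. A small usage sketch; the expected values simply mirror the test expectations elsewhere in this commit.

~~~
from qurator.dinglehopper import distance

assert distance("Fnord", "Food") == 2     # one deletion + one substitution
assert distance("Müll", "Mull") == 1      # ü vs u
assert distance("Schlyñ", "Schlym̃") == 1  # compared per grapheme cluster
~~~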
@@ -1,5 +1,5 @@
 from .bag_of_chars_accuracy import *
 from .bag_of_words_accuracy import *
-from .character_error_rate import *
-from .utils import Weights
-from .word_error_rate import *
+from .character_accuracy import *
+from .utils import MetricResult, Weights
+from .word_accuracy import *
@@ -7,7 +7,7 @@ from .utils import bag_accuracy, MetricResult, Weights


 def bag_of_chars_accuracy(
-    reference: str, compared: str, weights: Weights
+    reference: str, compared: str, weights: Weights = Weights(1, 0, 1)
 ) -> MetricResult:
     reference_chars = Counter(grapheme_clusters(normalize("NFC", reference)))
     compared_chars = Counter(grapheme_clusters(normalize("NFC", compared)))
@@ -5,7 +5,7 @@ from ..normalize import words_normalized


 def bag_of_words_accuracy(
-    reference: str, compared: str, weights: Weights
+    reference: str, compared: str, weights: Weights = Weights(1, 0, 1)
 ) -> MetricResult:
     reference_words = Counter(words_normalized(reference))
     compared_words = Counter(words_normalized(compared))
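Both bag metrics now default to Weights(1, 0, 1). The report template later in this commit labels the weight fields deletes, inserts and replacements, so spelling the same default out by keyword would look like the sketch below; the exact field order of Weights is not shown in the diff, so treat that mapping as an assumption.

~~~
from qurator.dinglehopper.metrics import Weights, bag_of_words_accuracy

# Assumed to be equivalent to the new default Weights(1, 0, 1).
weights = Weights(deletes=1, inserts=0, replacements=1)

# Identical texts, so the error rate is 0 regardless of what the weights mean.
result = bag_of_words_accuracy("ein kleiner Test", "ein kleiner Test", weights=weights)
print(result.get_dict())
~~~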
qurator/dinglehopper/metrics/character_accuracy.py (new file, 28 lines)

@@ -0,0 +1,28 @@
+from __future__ import division
+
+from .utils import MetricResult, Weights
+from .. import distance
+from ..normalize import chars_normalized
+
+
+def character_accuracy(
+    reference: str, compared: str, weights: Weights = Weights(1, 1, 1)
+) -> MetricResult:
+    """Compute character accuracy and error rate.
+
+    :return: NamedTuple representing the results of this metric.
+    """
+
+    weighted_errors = distance(reference, compared)
+    n_ref = len(chars_normalized(reference))
+    n_cmp = len(chars_normalized(compared))
+
+    return MetricResult(
+        metric=character_accuracy.__name__,
+        weights=weights,
+        weighted_errors=int(weighted_errors),
+        reference_elements=n_ref,
+        compared_elements=n_cmp,
+    )
+
+# XXX Should we really count newlines here?
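A quick usage sketch for the new function; the expected numbers mirror the new unit tests further down in this commit.

~~~
from qurator.dinglehopper.metrics import character_accuracy

result = character_accuracy("Fnord", "Food")
assert result.error_rate == 2 / 5       # as asserted in the new tests
assert result.reference_elements == 5   # grapheme clusters in "Fnord"
~~~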
@@ -1,46 +0,0 @@
-from __future__ import division

-from typing import Tuple

-from multimethod import multimethod

-from .. import distance
-from ..extracted_text import ExtractedText
-from ..normalize import chars_normalized


-@multimethod
-def character_error_rate_n(reference: str, compared: str) -> Tuple[float, int]:
-    """
-    Compute character error rate.

-    :return: character error rate and length of the reference
-    """

-    d = distance(reference, compared)
-    n = len(chars_normalized(reference))

-    if d == 0:
-        return 0, n
-    if n == 0:
-        return float("inf"), n
-    return d / n, n

-    # XXX Should we really count newlines here?


-@multimethod
-def character_error_rate_n(
-    reference: ExtractedText, compared: ExtractedText
-) -> Tuple[float, int]:
-    return character_error_rate_n(reference.text, compared.text)


-def character_error_rate(reference, compared) -> float:
-    """
-    Compute character error rate.

-    :return: character error rate
-    """
-    cer, _ = character_error_rate_n(reference, compared)
-    return cer
@@ -1,5 +1,5 @@
 from collections import Counter
-from typing import NamedTuple
+from typing import Dict, NamedTuple


 class Weights(NamedTuple):

@@ -25,10 +25,27 @@ class MetricResult(NamedTuple):

     @property
     def error_rate(self) -> float:
-        if self.reference_elements <= 0:
+        if self.reference_elements <= 0 and self.compared_elements <= 0:
+            return 0
+        elif self.reference_elements <= 0:
             return float("inf")
         return self.weighted_errors / self.reference_elements

+    def get_dict(self) -> Dict:
+        """Combines the properties to a dictionary.
+
+        We deviate from the builtin _asdict() function by including our properties.
+        """
+        return {
+            **{
+                key: value
+                for key, value in self._asdict().items()
+            },
+            "accuracy": self.accuracy,
+            "error_rate": self.error_rate,
+            "weights": self.weights._asdict(),
+        }


 def bag_accuracy(
     reference: Counter, compared: Counter, weights: Weights

@@ -44,7 +61,7 @@ def bag_accuracy(
     :param reference: Bag used as reference (ground truth).
     :param compared: Bag used to compare (ocr).
     :param weights: Weights/costs for editing operations.
-    :return: Tuple representing the results of this metric.
+    :return: NamedTuple representing the results of this metric.
     """
     n_ref = sum(reference.values())
     n_cmp = sum(compared.values())
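Only fragments of the metrics utils module are visible above, so here is a minimal, self-contained sketch of how MetricResult appears to fit together. The field list is reconstructed from the constructor calls elsewhere in this commit, the field types are guesses, and the accuracy property is not part of the diff at all, so the 1 - error_rate formula is purely an assumption.

~~~
from typing import Dict, NamedTuple


class Weights(NamedTuple):
    deletes: int
    inserts: int
    replacements: int


class MetricResult(NamedTuple):
    metric: str
    weights: Weights
    weighted_errors: int
    reference_elements: int
    compared_elements: int

    @property
    def error_rate(self) -> float:
        if self.reference_elements <= 0 and self.compared_elements <= 0:
            return 0
        elif self.reference_elements <= 0:
            return float("inf")
        return self.weighted_errors / self.reference_elements

    @property
    def accuracy(self) -> float:
        # Assumption: accuracy is simply the complement of the error rate.
        return 1 - self.error_rate

    def get_dict(self) -> Dict:
        # Same idea as the diff above: _asdict() plus the computed properties.
        return {
            **self._asdict(),
            "accuracy": self.accuracy,
            "error_rate": self.error_rate,
            "weights": self.weights._asdict(),
        }
~~~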
qurator/dinglehopper/metrics/word_accuracy.py (new file, 22 lines)

@@ -0,0 +1,22 @@
+from .utils import MetricResult, Weights
+from ..edit_distance import levenshtein
+from ..normalize import words_normalized
+
+
+def word_accuracy(
+    reference: str, compared: str, weights: Weights = Weights(1, 1, 1)
+) -> MetricResult:
+    reference_seq = list(words_normalized(reference))
+    compared_seq = list(words_normalized(compared))
+
+    weighted_errors = levenshtein(reference_seq, compared_seq)
+    n_ref = len(reference_seq)
+    n_cmp = len(compared_seq)
+
+    return MetricResult(
+        metric=word_accuracy.__name__,
+        weights=weights,
+        weighted_errors=int(weighted_errors),
+        reference_elements=n_ref,
+        compared_elements=n_cmp,
+    )
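As with the character metric, a short usage sketch whose numbers mirror the updated word tests in this commit.

~~~
from qurator.dinglehopper.metrics import word_accuracy

result = word_accuracy("Dies ist ein Beispielsatz!", "Dies ein ist Beispielsatz!")
assert result.error_rate == 2 / 4        # two words swapped
assert result.reference_elements == 4    # normalized word count of the reference
~~~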
@@ -1,43 +0,0 @@
-from __future__ import division

-from typing import Iterable, Tuple

-from multimethod import multimethod

-from ..edit_distance import levenshtein
-from ..extracted_text import ExtractedText
-from ..normalize import words_normalized


-@multimethod
-def word_error_rate_n(reference: str, compared: str) -> Tuple[float, int]:
-    reference_seq = list(words_normalized(reference))
-    compared_seq = list(words_normalized(compared))
-    return word_error_rate_n(reference_seq, compared_seq)


-@multimethod
-def word_error_rate_n(
-    reference: ExtractedText, compared: ExtractedText
-) -> Tuple[float, int]:
-    return word_error_rate_n(reference.text, compared.text)


-@multimethod
-def word_error_rate_n(reference: Iterable, compared: Iterable) -> Tuple[float, int]:
-    reference_seq = list(reference)
-    compared_seq = list(compared)

-    d = levenshtein(reference_seq, compared_seq)
-    n = len(reference_seq)

-    if d == 0:
-        return 0, n
-    if n == 0:
-        return float("inf"), n
-    return d / n, n


-def word_error_rate(reference, compared) -> float:
-    wer, _ = word_error_rate_n(reference, compared)
-    return wer
@@ -163,8 +163,8 @@ def extract(filename, *, textequiv_level="region"):
         return alto_extract(tree)


-def text(filename):
-    return extract(filename).text
+def text(filename, *args, **kwargs):
+    return extract(filename, *args, **kwargs).text


 if __name__ == "__main__":
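text() now simply forwards extra arguments to extract(), so the TextEquiv level can be passed through. A hedged one-liner; the file name is just an example.

~~~
from qurator.dinglehopper.ocr_files import text

plain_text = text("ocr.page.xml", textequiv_level="line")
~~~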
@@ -19,9 +19,10 @@
     ],
     "parameters": {
       "metrics": {
-        "type": "boolean",
-        "default": true,
-        "description": "Enable/disable metrics and green/red"
+        "type": "string",
+        "enum": ["", "boc", "boc,bow", "bow", "ca", "ca,boc", "ca,boc,bow", "ca,bow", "ca,wa", "ca,wa,boc", "ca,wa,boc,bow", "ca,wa,bow", "ca,wer", "ca,wer,boc", "ca,wer,boc,bow", "ca,wer,bow", "cer", "cer,boc", "cer,boc,bow", "cer,bow", "cer,wa", "cer,wa,boc", "cer,wa,boc,bow", "cer,wa,bow", "cer,wer", "cer,wer,boc", "cer,wer,boc,bow", "cer,wer,bow", "wa", "wa,boc", "wa,boc,bow", "wa,bow", "wer", "wer,boc", "wer,boc,bow", "wer,bow"],
+        "default": "cer,wer",
+        "description": "Enable different metrics like ca|cer, wa|wer, boc and bow."
       },
       "textequiv_level": {
         "type": "string",
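The enum spells out every supported combination of the acronyms, with ca/cer and wa/wer as interchangeable spellings. The commit does not show how the list was produced; a rough sketch of generating such combinations, much like the new CLI test parametrization further down, could look like this.

~~~
from itertools import combinations


def metric_combinations(acronyms=("ca", "wa", "boc", "bow")):
    """Yield comma-separated combinations of metric acronyms, '' included."""
    for r in range(len(acronyms) + 1):
        for combo in combinations(acronyms, r):
            yield ",".join(combo)


print(list(metric_combinations()))
# ['', 'ca', 'wa', 'boc', 'bow', 'ca,wa', ..., 'ca,wa,boc,bow']
~~~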
@@ -6,7 +6,7 @@
     <link rel="stylesheet" href="https://stackpath.bootstrapcdn.com/bootstrap/4.3.1/css/bootstrap.min.css" integrity="sha384-ggOyR0iXCbMQv3Xipma34MD+dH/1fQ784/j6cY/iJTQUOhcWr7x9JvoRxT2MZw1T" crossorigin="anonymous">
     <style type="text/css">
-        {% if metrics %}
+        {% if metrics_results %}
         .gt .diff {
             color: green;
         }
@@ -38,10 +38,44 @@
 {{ ocr }}


-{% if metrics %}
+{% if metrics_results %}
 <h2>Metrics</h2>
-<p>CER: {{ cer|round(4) }}</p>
-<p>WER: {{ wer|round(4) }}</p>
+<table class="table">
+    <thead>
+        <tr>
+            <th scope="col" rowspan="2">#</th>
+            <th scope="col" colspan="2">Results</th>
+            <th scope="col" colspan="2">Elements</th>
+            <th scope="col" colspan="2">Calculation</th>
+        </tr>
+        <tr>
+            <th scope="col">Acc</th>
+            <th scope="col">ER</th>
+            <th scope="col">Ref</th>
+            <th scope="col">Cmp</th>
+            <th scope="col">Errors</th>
+            <th scope="col">Weights</th>
+        </tr>
+    </thead>
+    <tbody>
+        {% for result in metrics_results %}
+        <tr>
+            <th scope="row">{{ result.metric.replace("_", " ") }}</th>
+            <td>{{ result.accuracy }}</td>
+            <td>{{ result.error_rate }}</td>
+            <td>{{ result.reference_elements }}</td>
+            <td>{{ result.compared_elements }}</td>
+            <td>{{ result.weighted_errors }}</td>
+            <td>d={{ result.weights.deletes }},i={{ result.weights.inserts }},r={{ result.weights.replacements }}</td>
+        </tr>
+        {% endfor %}
+    </tbody>
+    <tfoot>
+        <tr>
+            <td colspan="7">Legend: Acc. = Accuracy, ER = Error Rate, Ref. = Reference, Cmp. = Compared</td>
+        </tr>
+    </tfoot>
+</table>
 {% endif %}

 <h2>Character differences</h2>
@@ -0,0 +1,40 @@
+import math
+import unicodedata
+
+from ...metrics import character_accuracy
+
+
+def test_character_accuracy():
+    assert character_accuracy("a", "a").error_rate == 0
+    assert character_accuracy("a", "b").error_rate == 1 / 1
+    assert character_accuracy("Foo", "Bar").error_rate == 3 / 3
+
+    assert character_accuracy("Foo", "").error_rate == 3 / 3
+
+    assert character_accuracy("", "").error_rate == 0
+    assert math.isinf(character_accuracy("", "Foo").error_rate)
+
+    assert character_accuracy("Foo", "Food").error_rate == 1 / 3
+    assert character_accuracy("Fnord", "Food").error_rate == 2 / 5
+    assert character_accuracy("Müll", "Mull").error_rate == 1 / 4
+    assert character_accuracy("Abstand", "Sand").error_rate == 4 / 7
+
+
+def test_character_accuracy_hard():
+    s1 = unicodedata.normalize("NFC", "Schlyñ lorem ipsum.")
+    s2 = unicodedata.normalize("NFD", "Schlyñ lorem ipsum!") # Different, decomposed!
+    assert character_accuracy(s1, s2).error_rate == 1 / 19
+
+    s1 = "Schlyñ"
+    assert (
+        len(s1) == 6
+    ) # This ends with LATIN SMALL LETTER N WITH TILDE, so 6 code points
+    s2 = "Schlym̃"
+    assert (
+        len(s2) == 7
+    ) # This, OTOH, ends with LATIN SMALL LETTER M + COMBINING TILDE, 7 code points
+
+    # Both strings have the same length in terms of grapheme clusters.
+    # So the CER should be symmetrical.
+    assert character_accuracy(s2, s1).error_rate == 1 / 6
+    assert character_accuracy(s1, s2).error_rate == 1 / 6
@@ -1,41 +0,0 @@
-from __future__ import division, print_function

-import math
-import unicodedata

-from ...metrics import character_error_rate


-def test_character_error_rate():
-    assert character_error_rate("a", "a") == 0
-    assert character_error_rate("a", "b") == 1 / 1
-    assert character_error_rate("Foo", "Bar") == 3 / 3

-    assert character_error_rate("Foo", "") == 3 / 3

-    assert character_error_rate("", "") == 0
-    assert math.isinf(character_error_rate("", "Foo"))

-    assert character_error_rate("Foo", "Food") == 1 / 3
-    assert character_error_rate("Fnord", "Food") == 2 / 5
-    assert character_error_rate("Müll", "Mull") == 1 / 4
-    assert character_error_rate("Abstand", "Sand") == 4 / 7


-def test_character_error_rate_hard():
-    s1 = unicodedata.normalize("NFC", "Schlyñ lorem ipsum.")
-    s2 = unicodedata.normalize("NFD", "Schlyñ lorem ipsum!") # Different, decomposed!
-    assert character_error_rate(s1, s2) == 1 / 19

-    s1 = "Schlyñ"
-    assert (
-        len(s1) == 6
-    ) # This ends with LATIN SMALL LETTER N WITH TILDE, so 6 code points
-    s2 = "Schlym̃"
-    assert (
-        len(s2) == 7
-    ) # This, OTOH, ends with LATIN SMALL LETTER M + COMBINING TILDE, 7 code points

-    # Both strings have the same length in terms of grapheme clusters. So the CER should be symmetrical.
-    assert character_error_rate(s2, s1) == 1 / 6
-    assert character_error_rate(s1, s2) == 1 / 6
@@ -7,7 +7,7 @@ from lxml import etree as ET
 from uniseg.graphemecluster import grapheme_clusters

 from ... import page_text, alto_text
-from ...metrics import character_error_rate
+from ...metrics import character_accuracy

 data_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "../", "data")


@@ -22,7 +22,7 @@ def test_character_error_rate_between_page_files():
     gt_len = len(list(grapheme_clusters(gt)))
     expected_cer = 2 / gt_len

-    assert character_error_rate(gt, ocr) == expected_cer
+    assert character_accuracy(gt, ocr).error_rate == expected_cer


 @pytest.mark.integration

@@ -39,7 +39,7 @@ def test_character_error_rate_between_page_alto():
     )

     assert gt == ocr
-    assert character_error_rate(gt, ocr) == 0
+    assert character_accuracy(gt, ocr).error_rate == 0


 @pytest.mark.integration

@@ -57,4 +57,4 @@ def test_character_error_rate_between_page_alto_2():
         )
     )

-    assert character_error_rate(gt, ocr) == 8 / 591 # Manually verified
+    assert character_accuracy(gt, ocr).error_rate == 8 / 591 # Manually verified
@@ -6,7 +6,7 @@ import pytest
 from lxml import etree as ET

 from ... import alto_text, page_text
-from ...metrics import word_error_rate
+from ...metrics import word_accuracy
 from ...normalize import words

 data_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "../", "data")

@@ -24,7 +24,7 @@ def test_word_error_rate_between_page_files():
     assert len(list(words(gt))) == gt_word_count

     ocr = page_text(ET.parse(os.path.join(data_dir, "test-fake-ocr.page2018.xml")))
-    assert word_error_rate(gt, ocr) == 2 / gt_word_count
+    assert word_accuracy(gt, ocr).error_rate == 2 / gt_word_count


 @pytest.mark.integration

@@ -41,7 +41,7 @@ def test_word_error_rate_between_page_alto():
     )

     assert gt == ocr
-    assert word_error_rate(gt, ocr) == 0
+    assert word_accuracy(gt, ocr).error_rate == 0


 @pytest.mark.integration

@@ -66,5 +66,5 @@ def test_word_error_rate_between_page_alto_2():
     )

     assert (
-        word_error_rate(gt, ocr) == 7 / gt_word_count
+        word_accuracy(gt, ocr).error_rate == 7 / gt_word_count
     ) # Manually verified, 6 words are wrong, 1 got split (=2 errors)
@@ -1,8 +1,6 @@
-from __future__ import division, print_function

 import math

-from ...metrics import word_error_rate
+from ...metrics import word_accuracy
 from ...normalize import words


@@ -55,33 +53,44 @@ def test_words_private_use_area():

 def test_word_error_rate():
     assert (
-        word_error_rate("Dies ist ein Beispielsatz!", "Dies ist ein Beispielsatz!") == 0
-    )
-    assert (
-        word_error_rate("Dies. ist ein Beispielsatz!", "Dies ist ein Beispielsatz!")
+        word_accuracy(
+            "Dies ist ein Beispielsatz!", "Dies ist ein Beispielsatz!"
+        ).error_rate
         == 0
     )
     assert (
-        word_error_rate("Dies. ist ein Beispielsatz!", "Dies ist ein Beispielsatz.")
+        word_accuracy(
+            "Dies. ist ein Beispielsatz!", "Dies ist ein Beispielsatz!"
+        ).error_rate
         == 0
     )
+    assert (
+        word_accuracy(
+            "Dies. ist ein Beispielsatz!", "Dies ist ein Beispielsatz."
+        ).error_rate
+        == 0
+    )

     assert (
-        word_error_rate("Dies ist ein Beispielsatz!", "Dies ist ein Beispielsarz:")
+        word_accuracy(
+            "Dies ist ein Beispielsatz!", "Dies ist ein Beispielsarz:"
+        ).error_rate
         == 1 / 4
     )
     assert (
-        word_error_rate("Dies ist ein Beispielsatz!", "Dies ein ist Beispielsatz!")
+        word_accuracy(
+            "Dies ist ein Beispielsatz!", "Dies ein ist Beispielsatz!"
+        ).error_rate
         == 2 / 4
     )

-    assert word_error_rate("Dies ist ein Beispielsatz!", "") == 4 / 4
-    assert math.isinf(word_error_rate("", "Dies ist ein Beispielsatz!"))
-    assert word_error_rate("", "") == 0
+    assert word_accuracy("Dies ist ein Beispielsatz!", "").error_rate == 4 / 4
+    assert math.isinf(word_accuracy("", "Dies ist ein Beispielsatz!").error_rate)
+    assert word_accuracy("", "").error_rate == 0

     assert (
-        word_error_rate(
+        word_accuracy(
             "Schlyñ lorem ipsum dolor sit amet,", "Schlym̃ lorem ipsum dolor sit amet."
-        )
+        ).error_rate
         == 1 / 6
     )
@@ -1,14 +1,43 @@
 import json
+from itertools import combinations

 import pytest
-from .util import working_directory

+from .util import working_directory
+from ..cli import process

+METRIC_DICT = {
+    "": "",
+    "ca": "character_accuracy",
+    "cer": "character_accuracy",
+    "wa": "word_accuracy",
+    "wer": "word_accuracy",
+    "boc": "bag_of_chars_accuracy",
+    "bow": "bag_of_words_accuracy",
+}


 @pytest.mark.integration
-def test_cli_json(tmp_path):
-    """Test that the cli/process() yields a loadable JSON report"""
+@pytest.mark.parametrize(
+    "metrics",
+    [
+        *(("",), ("cer",), ("wer",), ("ca",), ("wa",), ("boc",), ("bow",)),
+        *combinations(("ca", "wa", "boc", "bow"), 2),
+        *combinations(("cer", "wer", "boc", "bow"), 2),
+        *combinations(("ca", "wa", "boc", "bow"), 3),
+        *combinations(("cer", "wer", "boc", "bow"), 3),
+        ("ca", "wa", "boc", "bow"),
+        ("cer", "wer", "boc", "bow"),
+    ],
+)
+def test_cli_json(metrics, tmp_path):
+    """Test that the cli/process() yields a loadable JSON report."""
+    expected_values = {
+        "character_accuracy": 0.2,
+        "word_accuracy": 1.0,
+        "bag_of_chars_accuracy": 0.2,
+        "bag_of_words_accuracy": 1.0,
+    }

     with working_directory(str(tmp_path)):
         with open("gt.txt", "w") as gtf:

@@ -18,25 +47,38 @@ def test_cli_json(tmp_path):

         with open("gt.txt", "r") as gtf:
             print(gtf.read())
-        process("gt.txt", "ocr.txt", "report")
+        process("gt.txt", "ocr.txt", "report", metrics=",".join(metrics))
         with open("report.json", "r") as jsonf:
             print(jsonf.read())

         with open("report.json", "r") as jsonf:
             j = json.load(jsonf)
-            assert j["cer"] == pytest.approx(0.2)
+            metrics_translated = {METRIC_DICT[metric] for metric in metrics}
+            for metric, expected_value in expected_values.items():
+                if metric in metrics_translated:
+                    assert j[metric]["error_rate"] == pytest.approx(expected_value)
+                else:
+                    assert metric not in j.keys()


 @pytest.mark.integration
-def test_cli_json_cer_is_infinity(tmp_path):
-    """Test that the cli/process() yields a loadable JSON report when CER == inf"""
+@pytest.mark.parametrize(
+    "gt,ocr,err",
+    [("", "Not important", float("inf")), ("Lorem Ipsum", "Lorem Ipsum", 0)],
+)
+def test_cli_json_extremes(gt, ocr, err, tmp_path):
+    """Test that the cli/process() yields a loadable JSON reports."""

     with working_directory(str(tmp_path)):
         with open("gt.txt", "w") as gtf:
-            gtf.write("") # Empty to yield CER == inf
+            gtf.write(gt)
         with open("ocr.txt", "w") as ocrf:
-            ocrf.write("Not important")
+            ocrf.write(ocr)

-        process("gt.txt", "ocr.txt", "report")
+        process("gt.txt", "ocr.txt", "report", metrics="ca,wa,boc,bow")
         with open("report.json", "r") as jsonf:
             j = json.load(jsonf)
-            assert j["cer"] == pytest.approx(float("inf"))
+            for metric in set(METRIC_DICT.values()):
+                if not metric:
+                    continue
+                assert j[metric]["error_rate"] == pytest.approx(err)
@@ -7,5 +7,4 @@ colorama
 MarkupSafe
 ocrd >= 2.20.1
 attrs
-multimethod == 1.3 # latest version to officially support Python 3.5
 tqdm