diff --git a/README.md b/README.md
index 6d82541..8a50cc3 100644
--- a/README.md
+++ b/README.md
@@ -34,20 +34,23 @@ Usage: dinglehopper [OPTIONS] GT OCR [REPORT_PREFIX]
   dinglehopper detects if GT/OCR are ALTO or PAGE XML documents to extract
   their text and falls back to plain text if no ALTO or PAGE is detected.
 
-  The files GT and OCR are usually a ground truth document and the result of
-  an OCR software, but you may use dinglehopper to compare two OCR results.
-  In that case, use --no-metrics to disable the then meaningless metrics and
-  also change the color scheme from green/red to blue.
+  The files GT and OCR are usually a ground truth document and the result of
+  an OCR software, but you may use dinglehopper to compare two OCR results. In
+  that case, use --metrics='' to disable the then meaningless metrics and also
+  change the color scheme from green/red to blue.
 
   The comparison report will be written to $REPORT_PREFIX.{html,json}, where
-  $REPORT_PREFIX defaults to "report". The reports include the character
-  error rate (CER) and the word error rate (WER).
+  $REPORT_PREFIX defaults to "report". Depending on your configuration, the
+  reports include the character accuracy or error rate (CA|CER), the word
+  accuracy or error rate (WA|WER), the bag of chars accuracy (BoC), and the
+  bag of words accuracy (BoW). The metrics can be chosen via a comma-separated
+  combination of their acronyms, like "--metrics=ca,wer,boc,bow".
 
   By default, the text of PAGE files is extracted on 'region' level. You may
   use "--textequiv-level line" to extract from the level of TextLine tags.
 
 Options:
-  --metrics / --no-metrics  Enable/disable metrics and green/red
+  --metrics                 Enable different metrics like ca|cer, wa|wer, boc and bow.
   --textequiv-level LEVEL   PAGE TextEquiv level to extract text from
   --progress                Show progress bar
   --help                    Show this message and exit.
@@ -80,12 +83,12 @@ The OCR-D processor has these parameters:
 
 | Parameter                 | Meaning                                                              |
 | ------------------------- | ------------------------------------------------------------------- |
-| `-P metrics false`        | Disable metrics and the green-red color scheme (default: enabled)   |
+| `-P metrics cer,wer`      | Enable character error rate and word error rate (default)           |
 | `-P textequiv_level line` | (PAGE) Extract text from TextLine level (default: TextRegion level) |
 
 For example:
 ~~~
-ocrd-dinglehopper -I ABBYY-FULLTEXT,OCR-D-OCR-CALAMARI -O OCR-D-OCR-COMPARE-ABBYY-CALAMARI -P metrics false
+ocrd-dinglehopper -I ABBYY-FULLTEXT,OCR-D-OCR-CALAMARI -O OCR-D-OCR-COMPARE-ABBYY-CALAMARI -P metrics cer,wer
 ~~~
 
 Developer information
diff --git a/qurator/dinglehopper/cli.py b/qurator/dinglehopper/cli.py
index 5e5b5a8..22ba9f5 100644
--- a/qurator/dinglehopper/cli.py
+++ b/qurator/dinglehopper/cli.py
@@ -10,9 +10,10 @@ from uniseg.graphemecluster import grapheme_clusters
 from .align import seq_align
 from .config import Config
 from .extracted_text import ExtractedText
-from .metrics.character_error_rate import character_error_rate_n
-from .metrics.word_error_rate import word_error_rate_n, words_normalized
-from .ocr_files import extract
+from .metrics import bag_of_chars_accuracy, bag_of_words_accuracy, character_accuracy, \
+    word_accuracy
+from .normalize import words_normalized
+from .ocr_files import text
 
 
 def gen_diff_report(gt_in, ocr_in, css_prefix, joiner, none):
@@ -85,9 +86,7 @@ def gen_diff_report(gt_in, ocr_in, css_prefix, joiner, none):
     )
 
 
-def generate_html_report(
-    gt, ocr, gt_text, ocr_text, report_prefix, metrics, cer, n_characters, wer, n_words
-):
+def generate_html_report(gt, ocr, gt_text, ocr_text, report_prefix, metrics_results):
     char_diff_report = gen_diff_report(
         gt_text, ocr_text, css_prefix="c", joiner="", none="·"
     )
@@ -112,57 +111,49 @@ def generate_html_report(
     template.stream(
         gt=gt,
         ocr=ocr,
-        cer=cer,
-        n_characters=n_characters,
-        wer=wer,
-        n_words=n_words,
         char_diff_report=char_diff_report,
         word_diff_report=word_diff_report,
-        metrics=metrics,
+        metrics_results=metrics_results,
     ).dump(out_fn)
 
 
-def generate_json_report(
-    gt, ocr, report_prefix, metrics, cer, n_characters, wer, n_words
-):
-    json_dict = {"gt": gt, "ocr": ocr, "n_characters": n_characters, "n_words": n_words}
-    if metrics:
-        json_dict = {**json_dict, "cer": cer, "wer": wer}
-    with open(f"{report_prefix}.json", 'w') as fp:
+def generate_json_report(gt, ocr, report_prefix, metrics_results):
+    json_dict = {"gt": gt, "ocr": ocr}
+    for result in metrics_results:
+        json_dict[result.metric] = {
+            key: value for key, value in result.get_dict().items() if key != "metric"
+        }
+    with open(f"{report_prefix}.json", "w") as fp:
         json.dump(json_dict, fp)
 
 
-def process(gt, ocr, report_prefix, *, metrics=True, textequiv_level="region"):
+def process(gt, ocr, report_prefix, *, metrics="cer,wer", textequiv_level="region"):
     """Check OCR result against GT.
 
     The @click decorators change the signature of the decorated functions, so we
     keep this undecorated version and use Click on a wrapper.
""" - gt_text = extract(gt, textequiv_level=textequiv_level) - ocr_text = extract(ocr, textequiv_level=textequiv_level) - - cer, n_characters = character_error_rate_n(gt_text, ocr_text) - wer, n_words = word_error_rate_n(gt_text, ocr_text) + gt_text = text(gt, textequiv_level=textequiv_level) + ocr_text = text(ocr, textequiv_level=textequiv_level) - generate_json_report( - gt, ocr, report_prefix, metrics, cer, n_characters, wer, n_words - ) + metrics_results = set() + if metrics: + metric_dict = { + "ca": character_accuracy, + "cer": character_accuracy, + "wa": word_accuracy, + "wer": word_accuracy, + "boc": bag_of_chars_accuracy, + "bow": bag_of_words_accuracy, + } + for metric in metrics.split(","): + metrics_results.add(metric_dict[metric.strip()](gt_text, ocr_text)) + generate_json_report(gt, ocr, report_prefix, metrics_results) html_report = True if html_report: - generate_html_report( - gt, - ocr, - gt_text, - ocr_text, - report_prefix, - metrics, - cer, - n_characters, - wer, - n_words, - ) + generate_html_report(gt, ocr, gt_text, ocr_text, report_prefix, metrics_results) @click.command() @@ -170,7 +162,9 @@ def process(gt, ocr, report_prefix, *, metrics=True, textequiv_level="region"): @click.argument("ocr", type=click.Path(exists=True)) @click.argument("report_prefix", type=click.Path(), default="report") @click.option( - "--metrics/--no-metrics", default=True, help="Enable/disable metrics and green/red" + "--metrics", + default="cer,wer", + help="Enable different metrics like cer, wer, boc and bow.", ) @click.option( "--textequiv-level", @@ -188,12 +182,15 @@ def main(gt, ocr, report_prefix, metrics, textequiv_level, progress): The files GT and OCR are usually a ground truth document and the result of an OCR software, but you may use dinglehopper to compare two OCR results. In - that case, use --no-metrics to disable the then meaningless metrics and also + that case, use --metrics='' to disable the then meaningless metrics and also change the color scheme from green/red to blue. The comparison report will be written to $REPORT_PREFIX.{html,json}, where - $REPORT_PREFIX defaults to "report". The reports include the character error - rate (CER) and the word error rate (WER). + $REPORT_PREFIX defaults to "report". Depending on your configuration the + reports include the character error rate (CA|CER), the word error rate (WA|WER), + the bag of chars accuracy (BoC), and the bag of words accuracy (BoW). + The metrics can be chosen via a comma separated combination of their acronyms + like "--metrics=ca,wer,boc,bow". By default, the text of PAGE files is extracted on 'region' level. You may use "--textequiv-level line" to extract from the level of TextLine tags. 
diff --git a/qurator/dinglehopper/edit_distance.py b/qurator/dinglehopper/edit_distance.py
index 6c459fa..798419c 100644
--- a/qurator/dinglehopper/edit_distance.py
+++ b/qurator/dinglehopper/edit_distance.py
@@ -4,11 +4,9 @@ from functools import lru_cache, partial
 from typing import Sequence, Tuple
 
 import numpy as np
-from multimethod import multimethod
 from tqdm import tqdm
 
 from .config import Config
-from .extracted_text import ExtractedText
 from .normalize import chars_normalized
 
 
@@ -74,7 +72,6 @@ def levenshtein_matrix_cache_clear():
     _levenshtein_matrix.cache_clear()
 
 
-@multimethod
 def distance(s1: str, s2: str):
     """Compute the Levenshtein edit distance between two Unicode strings
 
@@ -86,11 +83,6 @@ def distance(s1: str, s2: str):
     return levenshtein(seq1, seq2)
 
 
-@multimethod
-def distance(s1: ExtractedText, s2: ExtractedText):
-    return distance(s1.text, s2.text)
-
-
 def seq_editops(seq1, seq2):
     """
     Return sequence of edit operations transforming one sequence to another.
diff --git a/qurator/dinglehopper/metrics/__init__.py b/qurator/dinglehopper/metrics/__init__.py
index ba9d140..07bb077 100644
--- a/qurator/dinglehopper/metrics/__init__.py
+++ b/qurator/dinglehopper/metrics/__init__.py
@@ -1,5 +1,5 @@
 from .bag_of_chars_accuracy import *
 from .bag_of_words_accuracy import *
-from .character_error_rate import *
-from .utils import Weights
-from .word_error_rate import *
+from .character_accuracy import *
+from .utils import MetricResult, Weights
+from .word_accuracy import *
diff --git a/qurator/dinglehopper/metrics/bag_of_chars_accuracy.py b/qurator/dinglehopper/metrics/bag_of_chars_accuracy.py
index c9cd9f2..79ed34e 100644
--- a/qurator/dinglehopper/metrics/bag_of_chars_accuracy.py
+++ b/qurator/dinglehopper/metrics/bag_of_chars_accuracy.py
@@ -7,7 +7,7 @@ from .utils import bag_accuracy, MetricResult, Weights
 
 
 def bag_of_chars_accuracy(
-    reference: str, compared: str, weights: Weights
+    reference: str, compared: str, weights: Weights = Weights(1, 0, 1)
 ) -> MetricResult:
     reference_chars = Counter(grapheme_clusters(normalize("NFC", reference)))
     compared_chars = Counter(grapheme_clusters(normalize("NFC", compared)))
diff --git a/qurator/dinglehopper/metrics/bag_of_words_accuracy.py b/qurator/dinglehopper/metrics/bag_of_words_accuracy.py
index bef86c1..f2e0a88 100644
--- a/qurator/dinglehopper/metrics/bag_of_words_accuracy.py
+++ b/qurator/dinglehopper/metrics/bag_of_words_accuracy.py
@@ -5,7 +5,7 @@ from ..normalize import words_normalized
 
 
 def bag_of_words_accuracy(
-    reference: str, compared: str, weights: Weights
+    reference: str, compared: str, weights: Weights = Weights(1, 0, 1)
 ) -> MetricResult:
     reference_words = Counter(words_normalized(reference))
     compared_words = Counter(words_normalized(compared))
diff --git a/qurator/dinglehopper/metrics/character_accuracy.py b/qurator/dinglehopper/metrics/character_accuracy.py
new file mode 100644
index 0000000..1e89a76
--- /dev/null
+++ b/qurator/dinglehopper/metrics/character_accuracy.py
@@ -0,0 +1,28 @@
+from __future__ import division
+
+from .utils import MetricResult, Weights
+from .. import distance
+from ..normalize import chars_normalized
+
+
+def character_accuracy(
+    reference: str, compared: str, weights: Weights = Weights(1, 1, 1)
+) -> MetricResult:
+    """Compute character accuracy and error rate.
+
+    :return: NamedTuple representing the results of this metric.
+ """ + + weighted_errors = distance(reference, compared) + n_ref = len(chars_normalized(reference)) + n_cmp = len(chars_normalized(compared)) + + return MetricResult( + metric=character_accuracy.__name__, + weights=weights, + weighted_errors=int(weighted_errors), + reference_elements=n_ref, + compared_elements=n_cmp, + ) + + # XXX Should we really count newlines here? diff --git a/qurator/dinglehopper/metrics/character_error_rate.py b/qurator/dinglehopper/metrics/character_error_rate.py deleted file mode 100644 index 0e40c66..0000000 --- a/qurator/dinglehopper/metrics/character_error_rate.py +++ /dev/null @@ -1,46 +0,0 @@ -from __future__ import division - -from typing import Tuple - -from multimethod import multimethod - -from .. import distance -from ..extracted_text import ExtractedText -from ..normalize import chars_normalized - - -@multimethod -def character_error_rate_n(reference: str, compared: str) -> Tuple[float, int]: - """ - Compute character error rate. - - :return: character error rate and length of the reference - """ - - d = distance(reference, compared) - n = len(chars_normalized(reference)) - - if d == 0: - return 0, n - if n == 0: - return float("inf"), n - return d / n, n - - # XXX Should we really count newlines here? - - -@multimethod -def character_error_rate_n( - reference: ExtractedText, compared: ExtractedText -) -> Tuple[float, int]: - return character_error_rate_n(reference.text, compared.text) - - -def character_error_rate(reference, compared) -> float: - """ - Compute character error rate. - - :return: character error rate - """ - cer, _ = character_error_rate_n(reference, compared) - return cer diff --git a/qurator/dinglehopper/metrics/utils.py b/qurator/dinglehopper/metrics/utils.py index b3ca5bc..179f462 100644 --- a/qurator/dinglehopper/metrics/utils.py +++ b/qurator/dinglehopper/metrics/utils.py @@ -1,5 +1,5 @@ from collections import Counter -from typing import NamedTuple +from typing import Dict, NamedTuple class Weights(NamedTuple): @@ -25,10 +25,27 @@ class MetricResult(NamedTuple): @property def error_rate(self) -> float: - if self.reference_elements <= 0: + if self.reference_elements <= 0 and self.compared_elements <= 0: + return 0 + elif self.reference_elements <= 0: return float("inf") return self.weighted_errors / self.reference_elements + def get_dict(self) -> Dict: + """Combines the properties to a dictionary. + + We deviate from the builtin _asdict() function by including our properties. + """ + return { + **{ + key: value + for key, value in self._asdict().items() + }, + "accuracy": self.accuracy, + "error_rate": self.error_rate, + "weights": self.weights._asdict(), + } + def bag_accuracy( reference: Counter, compared: Counter, weights: Weights @@ -44,7 +61,7 @@ def bag_accuracy( :param reference: Bag used as reference (ground truth). :param compared: Bag used to compare (ocr). :param weights: Weights/costs for editing operations. - :return: Tuple representing the results of this metric. + :return: NamedTuple representing the results of this metric. 
""" n_ref = sum(reference.values()) n_cmp = sum(compared.values()) diff --git a/qurator/dinglehopper/metrics/word_accuracy.py b/qurator/dinglehopper/metrics/word_accuracy.py new file mode 100644 index 0000000..cfebe31 --- /dev/null +++ b/qurator/dinglehopper/metrics/word_accuracy.py @@ -0,0 +1,22 @@ +from .utils import MetricResult, Weights +from ..edit_distance import levenshtein +from ..normalize import words_normalized + + +def word_accuracy( + reference: str, compared: str, weights: Weights = Weights(1, 1, 1) +) -> MetricResult: + reference_seq = list(words_normalized(reference)) + compared_seq = list(words_normalized(compared)) + + weighted_errors = levenshtein(reference_seq, compared_seq) + n_ref = len(reference_seq) + n_cmp = len(compared_seq) + + return MetricResult( + metric=word_accuracy.__name__, + weights=weights, + weighted_errors=int(weighted_errors), + reference_elements=n_ref, + compared_elements=n_cmp, + ) diff --git a/qurator/dinglehopper/metrics/word_error_rate.py b/qurator/dinglehopper/metrics/word_error_rate.py deleted file mode 100644 index 14d3784..0000000 --- a/qurator/dinglehopper/metrics/word_error_rate.py +++ /dev/null @@ -1,43 +0,0 @@ -from __future__ import division - -from typing import Iterable, Tuple - -from multimethod import multimethod - -from ..edit_distance import levenshtein -from ..extracted_text import ExtractedText -from ..normalize import words_normalized - - -@multimethod -def word_error_rate_n(reference: str, compared: str) -> Tuple[float, int]: - reference_seq = list(words_normalized(reference)) - compared_seq = list(words_normalized(compared)) - return word_error_rate_n(reference_seq, compared_seq) - - -@multimethod -def word_error_rate_n( - reference: ExtractedText, compared: ExtractedText -) -> Tuple[float, int]: - return word_error_rate_n(reference.text, compared.text) - - -@multimethod -def word_error_rate_n(reference: Iterable, compared: Iterable) -> Tuple[float, int]: - reference_seq = list(reference) - compared_seq = list(compared) - - d = levenshtein(reference_seq, compared_seq) - n = len(reference_seq) - - if d == 0: - return 0, n - if n == 0: - return float("inf"), n - return d / n, n - - -def word_error_rate(reference, compared) -> float: - wer, _ = word_error_rate_n(reference, compared) - return wer diff --git a/qurator/dinglehopper/ocr_files.py b/qurator/dinglehopper/ocr_files.py index 5271727..1f319f5 100644 --- a/qurator/dinglehopper/ocr_files.py +++ b/qurator/dinglehopper/ocr_files.py @@ -163,8 +163,8 @@ def extract(filename, *, textequiv_level="region"): return alto_extract(tree) -def text(filename): - return extract(filename).text +def text(filename, *args, **kwargs): + return extract(filename, *args, **kwargs).text if __name__ == "__main__": diff --git a/qurator/dinglehopper/ocrd-tool.json b/qurator/dinglehopper/ocrd-tool.json index 1e2b9b0..0537db8 100644 --- a/qurator/dinglehopper/ocrd-tool.json +++ b/qurator/dinglehopper/ocrd-tool.json @@ -19,9 +19,10 @@ ], "parameters": { "metrics": { - "type": "boolean", - "default": true, - "description": "Enable/disable metrics and green/red" + "type": "string", + "enum": ["", "boc", "boc,bow", "bow", "ca", "ca,boc", "ca,boc,bow", "ca,bow", "ca,wa", "ca,wa,boc", "ca,wa,boc,bow", "ca,wa,bow", "ca,wer", "ca,wer,boc", "ca,wer,boc,bow", "ca,wer,bow", "cer", "cer,boc", "cer,boc,bow", "cer,bow", "cer,wa", "cer,wa,boc", "cer,wa,boc,bow", "cer,wa,bow", "cer,wer", "cer,wer,boc", "cer,wer,boc,bow", "cer,wer,bow", "wa", "wa,boc", "wa,boc,bow", "wa,bow", "wer", "wer,boc", "wer,boc,bow", 
"wer,bow"], + "default": "cer,wer", + "description": "Enable different metrics like ca|cer, wa|wer, boc and bow." }, "textequiv_level": { "type": "string", diff --git a/qurator/dinglehopper/templates/report.html.j2 b/qurator/dinglehopper/templates/report.html.j2 index 0c2f464..46a952f 100644 --- a/qurator/dinglehopper/templates/report.html.j2 +++ b/qurator/dinglehopper/templates/report.html.j2 @@ -6,7 +6,7 @@