From 381fe7cb6bb58ce6186b4e243ab386632355cdb8 Mon Sep 17 00:00:00 2001
From: Benjamin Rosemann
Date: Fri, 11 Jun 2021 10:21:23 +0200
Subject: [PATCH] Switch to result tuple instead of multiple return parameters

---
Review notes (free text in this section is ignored by git am):
 * bag_of_chars_accuracy now only accepts str; the ExtractedText overload is
   removed, while bag_of_words_accuracy still unwraps ExtractedText.
 * MetricResult.accuracy is computed as 1 - error_rate, so an empty reference
   now gives -inf; the removed *_n helpers returned +inf in that case.
 * The rewritten tests no longer assert on accuracy and drop every
   math.isinf call visible in this diff; the `import math` in the test
   module may now be unused - confirm against the full file.

 .../metrics/bag_of_chars_accuracy.py          | 30 ++-----
 .../metrics/bag_of_words_accuracy.py          | 21 ++---
 qurator/dinglehopper/metrics/utils.py         | 37 ++++++--
 .../tests/metrics/test_bag_accuracy.py        | 87 ++++++++++++-------
 4 files changed, 99 insertions(+), 76 deletions(-)

diff --git a/qurator/dinglehopper/metrics/bag_of_chars_accuracy.py b/qurator/dinglehopper/metrics/bag_of_chars_accuracy.py
index dd6a030..c9cd9f2 100644
--- a/qurator/dinglehopper/metrics/bag_of_chars_accuracy.py
+++ b/qurator/dinglehopper/metrics/bag_of_chars_accuracy.py
@@ -1,35 +1,17 @@
 from collections import Counter
-from typing import Tuple, Union
 from unicodedata import normalize
 
-from multimethod import multimethod
 from uniseg.graphemecluster import grapheme_clusters
 
-from .utils import bag_accuracy, Weights
-from .. import ExtractedText
+from .utils import bag_accuracy, MetricResult, Weights
 
 
 def bag_of_chars_accuracy(
-    reference: Union[str, ExtractedText],
-    compared: Union[str, ExtractedText],
-    weights: Weights,
-) -> float:
-    acc, _ = bag_of_chars_accuracy_n(reference, compared, weights)
-    return acc
-
-
-@multimethod
-def bag_of_chars_accuracy_n(
     reference: str, compared: str, weights: Weights
-) -> Tuple[float, int]:
+) -> MetricResult:
     reference_chars = Counter(grapheme_clusters(normalize("NFC", reference)))
     compared_chars = Counter(grapheme_clusters(normalize("NFC", compared)))
-    e, n = bag_accuracy(reference_chars, compared_chars, weights)
-    return (float("inf") if n == 0 else 1 - e / n), n
-
-
-@multimethod
-def bag_of_chars_accuracy_n(
-    reference: ExtractedText, compared: ExtractedText, weights: Weights
-) -> Tuple[float, int]:
-    return bag_of_chars_accuracy_n(reference.text, compared.text, weights)
+    result = bag_accuracy(reference_chars, compared_chars, weights)
+    return MetricResult(
+        **{**result._asdict(), "metric": bag_of_chars_accuracy.__name__}
+    )
diff --git a/qurator/dinglehopper/metrics/bag_of_words_accuracy.py b/qurator/dinglehopper/metrics/bag_of_words_accuracy.py
index 7e5f315..1b0e763 100644
--- a/qurator/dinglehopper/metrics/bag_of_words_accuracy.py
+++ b/qurator/dinglehopper/metrics/bag_of_words_accuracy.py
@@ -1,7 +1,7 @@
 from collections import Counter
-from typing import Tuple, Union
+from typing import Union
 
-from .utils import bag_accuracy, Weights
+from .utils import bag_accuracy, MetricResult, Weights
 from .. import ExtractedText
 from ..normalize import words_normalized
 
@@ -10,21 +10,14 @@ def bag_of_words_accuracy(
     reference: Union[str, ExtractedText],
     compared: Union[str, ExtractedText],
     weights: Weights,
-) -> float:
-    acc, _ = bag_of_words_accuracy_n(reference, compared, weights)
-    return acc
-
-
-def bag_of_words_accuracy_n(
-    reference: Union[str, ExtractedText],
-    compared: Union[str, ExtractedText],
-    weights: Weights,
-) -> Tuple[float, int]:
+) -> MetricResult:
     if isinstance(reference, ExtractedText):
         reference = reference.text
     if isinstance(compared, ExtractedText):
         compared = compared.text
     reference_words = Counter(words_normalized(reference))
     compared_words = Counter(words_normalized(compared))
-    e, n = bag_accuracy(reference_words, compared_words, weights)
-    return (float("inf") if n == 0 else 1 - e / n), n
+    result = bag_accuracy(reference_words, compared_words, weights)
+    return MetricResult(
+        **{**result._asdict(), "metric": bag_of_words_accuracy.__name__}
+    )
diff --git a/qurator/dinglehopper/metrics/utils.py b/qurator/dinglehopper/metrics/utils.py
index cfb764e..b3ca5bc 100644
--- a/qurator/dinglehopper/metrics/utils.py
+++ b/qurator/dinglehopper/metrics/utils.py
@@ -1,5 +1,5 @@
 from collections import Counter
-from typing import NamedTuple, Tuple
+from typing import NamedTuple
 
 
 class Weights(NamedTuple):
@@ -10,9 +10,29 @@ class Weights(NamedTuple):
     replacements: int = 1
 
 
+class MetricResult(NamedTuple):
+    """Represent a result from a metric calculation."""
+
+    metric: str
+    weights: Weights
+    weighted_errors: int
+    reference_elements: int
+    compared_elements: int
+
+    @property
+    def accuracy(self) -> float:
+        return 1 - self.error_rate
+
+    @property
+    def error_rate(self) -> float:
+        if self.reference_elements <= 0:
+            return float("inf")
+        return self.weighted_errors / self.reference_elements
+
+
 def bag_accuracy(
     reference: Counter, compared: Counter, weights: Weights
-) -> Tuple[int, int]:
+) -> MetricResult:
     """Calculates the the weighted errors for two bags (Counter).
 
     Basic algorithm idea:
@@ -24,9 +44,10 @@ def bag_accuracy(
     :param reference: Bag used as reference (ground truth).
     :param compared: Bag used to compare (ocr).
     :param weights: Weights/costs for editing operations.
-    :return: weighted errors and number of elements in reference.
+    :return: Tuple representing the results of this metric.
     """
-    n = sum(reference.values())
+    n_ref = sum(reference.values())
+    n_cmp = sum(compared.values())
     deletes = sum((reference - compared).values())
     inserts = sum((compared - reference).values())
     replacements = 0
@@ -38,4 +59,10 @@ def bag_accuracy(
         + weights.inserts * inserts
         + weights.replacements * replacements
     )
-    return weighted_errors, n
+    return MetricResult(
+        metric=bag_accuracy.__name__,
+        weights=weights,
+        weighted_errors=weighted_errors,
+        reference_elements=n_ref,
+        compared_elements=n_cmp,
+    )
diff --git a/qurator/dinglehopper/tests/metrics/test_bag_accuracy.py b/qurator/dinglehopper/tests/metrics/test_bag_accuracy.py
index 345e0bd..daa8721 100644
--- a/qurator/dinglehopper/tests/metrics/test_bag_accuracy.py
+++ b/qurator/dinglehopper/tests/metrics/test_bag_accuracy.py
@@ -4,7 +4,12 @@ from collections import Counter
 
 import pytest
 
-from ...metrics import bag_of_chars_accuracy_n, bag_of_words_accuracy_n, Weights
+from ...metrics import (
+    bag_of_chars_accuracy,
+    bag_of_words_accuracy,
+    MetricResult,
+    Weights,
+)
 from ...metrics.utils import bag_accuracy
 
 
@@ -18,75 +23,93 @@ def ex_weights():
     )
 
 
+def verify_metric_result(
+    result: MetricResult,
+    metric: str,
+    errors: int,
+    n_ref: int,
+    n_cmp: int,
+    weights: Weights,
+):
+    assert result.metric == metric
+    assert result.weights == weights
+    assert result.weighted_errors == errors
+    assert result.reference_elements == n_ref
+    assert result.compared_elements == n_cmp
+
+
+CASE_PARAMS = "s1,s2, s1_n, s2_n, ex_err"
+
 SIMPLE_CASES = (
-    ("", "", 0, (0, 0, 0)),
-    ("abc", "", 3, (3, 3, 3)),
-    ("", "abc", 0, (3, 0, 3)),
-    ("abc", "abc", 3, (0, 0, 0)),
-    ("abc", "ab", 3, (1, 1, 1)),
-    ("abc", "abcd", 3, (1, 0, 1)),
-    ("abc", "abd", 3, (1, 1, 2)),
+    ("", "", 0, 0, (0, 0, 0)),
+    ("abc", "", 3, 0, (3, 3, 3)),
+    ("", "abc", 0, 3, (3, 0, 3)),
+    ("abc", "abc", 3, 3, (0, 0, 0)),
+    ("abc", "ab", 3, 2, (1, 1, 1)),
+    ("abc", "abcd", 3, 4, (1, 0, 1)),
+    ("abc", "abd", 3, 3, (1, 1, 2)),
 )
 
 
 @pytest.mark.parametrize(
-    "s1,s2, ex_n, ex_err",
+    CASE_PARAMS,
     [
         *SIMPLE_CASES,
-        (("a", "b", "c", "d", "e"), ("a", "b", "c", "d", ("e", "´")), 5, (1, 1, 2)),
-        (range(5), range(6), 5, (1, 0, 1)),
+        (("a", "b", "c", "d", "e"), ("a", "b", "c", "d", ("e", "´")), 5, 5, (1, 1, 2)),
+        (range(5), range(6), 5, 6, (1, 0, 1)),
     ],
 )
-def test_bag_accuracy_algorithm(s1, s2, ex_n, ex_err, ex_weights):
+def test_bag_accuracy_algorithm(s1, s2, s1_n, s2_n, ex_err, ex_weights):
     """Test the main algorithm for calculating the bag accuracy."""
     for weights, expected_errors in zip(ex_weights, (0, *ex_err)):
-        e, n = bag_accuracy(Counter(s1), Counter(s2), weights=weights)
-        assert n == ex_n, f"{n} == {ex_n} for {weights}"
-        assert e == expected_errors, f"{e} == {expected_errors} for {weights}"
+        metric_result = bag_accuracy(Counter(s1), Counter(s2), weights=weights)
+        verify_metric_result(
+            metric_result, "bag_accuracy", expected_errors, s1_n, s2_n, weights
+        )
 
 
 @pytest.mark.parametrize(
-    "s1,s2, ex_n, ex_err",
+    CASE_PARAMS,
     [
         *SIMPLE_CASES,
-        ("Schlyñ", "Schlym̃", 6, (1, 1, 2)),
+        ("Schlyñ", "Schlym̃", 6, 6, (1, 1, 2)),
         (
             unicodedata.normalize("NFC", "Schlyñ lorem ipsum."),
             unicodedata.normalize("NFD", "Schlyñ lorem ipsum!"),
             19,
+            19,
             (1, 1, 2),
         ),
     ],
 )
-def test_bag_of_chars_accuracy_n(s1, s2, ex_n, ex_err, ex_weights):
+def test_bag_of_chars_accuracy(s1, s2, s1_n, s2_n, ex_err, ex_weights):
     """Test the special behaviour of the char differentiation.
 
     As the algorithm and the char normalization is implemented elsewhere
     we are currently only testing that the corresponding algorithms are called.
     """
     for weights, expected_errors in zip(ex_weights, (0, *ex_err)):
-        acc, n = bag_of_chars_accuracy_n(s1, s2, weights)
-        assert n == ex_n, f"{n} == {ex_n} for {weights}"
-        if ex_n == 0:
-            assert math.isinf(acc)
-        else:
-            assert acc == pytest.approx(1 - expected_errors / ex_n), f"w: {weights}"
+        result = bag_of_chars_accuracy(s1, s2, weights)
+        verify_metric_result(
+            result, "bag_of_chars_accuracy", expected_errors, s1_n, s2_n, weights
+        )
 
 
 @pytest.mark.parametrize(
-    "s1,s2, ex_n, ex_err",
+    CASE_PARAMS,
     [
         *SIMPLE_CASES,
-        ("Schlyñ", "Schlym̃", 6, (1, 1, 2)),
+        ("Schlyñ", "Schlym̃", 6, 6, (1, 1, 2)),
         (
             unicodedata.normalize("NFC", "Schlyñ lorem ipsum."),
             unicodedata.normalize("NFD", "Schlyñ lorem ipsum!"),
             3,
+            3,
             (0, 0, 0),
         ),
     ],
 )
-def test_bag_of_words_accuracy_n(s1, s2, ex_n, ex_err, ex_weights):
+def test_bag_of_words_accuracy(s1, s2, s1_n, s2_n, ex_err, ex_weights):
     """Test the special behaviour of the word differentiation.
 
     As the algorithm and the word splitting is implemented elsewhere
@@ -96,9 +119,7 @@ def test_bag_of_words_accuracy_n(s1, s2, ex_n, ex_err, ex_weights):
         s1 = " ".join(s1)
         s2 = " ".join(s2)
     for weights, expected_errors in zip(ex_weights, (0, *ex_err)):
-        acc, n = bag_of_words_accuracy_n(s1, s2, weights)
-        assert n == ex_n, f"{n} == {ex_n} for {weights}"
-        if ex_n == 0:
-            assert math.isinf(acc)
-        else:
-            assert acc == pytest.approx(1 - expected_errors / ex_n), f"w: {weights}"
+        result = bag_of_words_accuracy(s1, s2, weights)
+        verify_metric_result(
+            result, "bag_of_words_accuracy", expected_errors, s1_n, s2_n, weights
+        )