Switch to result tuple instead of multiple return parameters

pull/60/head
Benjamin Rosemann 4 years ago
parent 974ca3e5c0
commit 381fe7cb6b

@ -1,35 +1,17 @@
from collections import Counter
from unicodedata import normalize

from multimethod import multimethod
from uniseg.graphemecluster import grapheme_clusters

from .utils import bag_accuracy, MetricResult, Weights
@multimethod
def bag_of_chars_accuracy(
    reference: str, compared: str, weights: Weights
) -> MetricResult:
    """Compute the bag-of-characters accuracy between two strings.

    Both strings are NFC-normalized and split into grapheme clusters,
    then compared as bags (multisets) via :func:`bag_accuracy`.

    :param reference: Reference text (ground truth).
    :param compared: Text to compare against the reference (e.g. OCR output).
    :param weights: Weights/costs for the editing operations.
    :return: MetricResult carrying this metric's name and the bag statistics.
    """
    reference_chars = Counter(grapheme_clusters(normalize("NFC", reference)))
    compared_chars = Counter(grapheme_clusters(normalize("NFC", compared)))
    result = bag_accuracy(reference_chars, compared_chars, weights)
    # Re-label the generic bag_accuracy result with this metric's own name.
    return MetricResult(
        **{**result._asdict(), "metric": bag_of_chars_accuracy.__name__}
    )

from collections import Counter
from typing import Union

from .utils import bag_accuracy, MetricResult, Weights
from .. import ExtractedText
from ..normalize import words_normalized
def bag_of_words_accuracy(
    reference: Union[str, ExtractedText],
    compared: Union[str, ExtractedText],
    weights: Weights,
) -> MetricResult:
    """Compute the bag-of-words accuracy between two texts.

    ExtractedText arguments are unwrapped to their plain-text form; both
    texts are then split into normalized words and compared as bags
    (multisets) via :func:`bag_accuracy`.

    :param reference: Reference text (ground truth).
    :param compared: Text to compare against the reference (e.g. OCR output).
    :param weights: Weights/costs for the editing operations.
    :return: MetricResult carrying this metric's name and the bag statistics.
    """
    if isinstance(reference, ExtractedText):
        reference = reference.text
    if isinstance(compared, ExtractedText):
        compared = compared.text
    reference_words = Counter(words_normalized(reference))
    compared_words = Counter(words_normalized(compared))
    result = bag_accuracy(reference_words, compared_words, weights)
    # Re-label the generic bag_accuracy result with this metric's own name.
    return MetricResult(
        **{**result._asdict(), "metric": bag_of_words_accuracy.__name__}
    )

@ -1,5 +1,5 @@
from collections import Counter
from typing import NamedTuple
class Weights(NamedTuple): class Weights(NamedTuple):
@ -10,9 +10,29 @@ class Weights(NamedTuple):
replacements: int = 1 replacements: int = 1
class MetricResult(NamedTuple):
    """Represent a result from a metric calculation."""

    # Name of the metric that produced this result (e.g. "bag_accuracy").
    metric: str
    # Weights/costs for editing operations used in the calculation.
    weights: Weights
    # Sum of the weighted editing errors.
    weighted_errors: int
    # Number of elements in the reference bag.
    reference_elements: int
    # Number of elements in the compared bag.
    compared_elements: int

    @property
    def accuracy(self) -> float:
        """Accuracy derived from the error rate as ``1 - error_rate``."""
        # NOTE(review): with an empty reference, error_rate is inf, so
        # accuracy becomes -inf (the pre-refactor code returned +inf in that
        # case) — confirm this change is intended.
        return 1 - self.error_rate

    @property
    def error_rate(self) -> float:
        """Weighted errors per reference element; inf for an empty reference."""
        if self.reference_elements <= 0:
            return float("inf")
        return self.weighted_errors / self.reference_elements
def bag_accuracy( def bag_accuracy(
reference: Counter, compared: Counter, weights: Weights reference: Counter, compared: Counter, weights: Weights
) -> Tuple[int, int]: ) -> MetricResult:
"""Calculates the weighted errors for two bags (Counter). """Calculates the weighted errors for two bags (Counter).
Basic algorithm idea: Basic algorithm idea:
@ -24,9 +44,10 @@ def bag_accuracy(
:param reference: Bag used as reference (ground truth). :param reference: Bag used as reference (ground truth).
:param compared: Bag used to compare (ocr). :param compared: Bag used to compare (ocr).
:param weights: Weights/costs for editing operations. :param weights: Weights/costs for editing operations.
:return: weighted errors and number of elements in reference. :return: Tuple representing the results of this metric.
""" """
n = sum(reference.values()) n_ref = sum(reference.values())
n_cmp = sum(compared.values())
deletes = sum((reference - compared).values()) deletes = sum((reference - compared).values())
inserts = sum((compared - reference).values()) inserts = sum((compared - reference).values())
replacements = 0 replacements = 0
@ -38,4 +59,10 @@ def bag_accuracy(
+ weights.inserts * inserts + weights.inserts * inserts
+ weights.replacements * replacements + weights.replacements * replacements
) )
return weighted_errors, n return MetricResult(
metric=bag_accuracy.__name__,
weights=weights,
weighted_errors=weighted_errors,
reference_elements=n_ref,
compared_elements=n_cmp,
)

import pytest

from ...metrics import (
    bag_of_chars_accuracy,
    bag_of_words_accuracy,
    MetricResult,
    Weights,
)
from ...metrics.utils import bag_accuracy
@ -18,75 +23,93 @@ def ex_weights():
) )
def verify_metric_result(
    result: MetricResult,
    metric: str,
    errors: int,
    n_ref: int,
    n_cmp: int,
    weights: Weights,
):
    """Assert that a MetricResult carries exactly the expected field values.

    :param result: MetricResult under inspection.
    :param metric: Expected metric name.
    :param errors: Expected weighted error count.
    :param n_ref: Expected number of reference elements.
    :param n_cmp: Expected number of compared elements.
    :param weights: Expected weights used for the calculation.
    """
    assert result.metric == metric
    assert result.weights == weights
    assert result.weighted_errors == errors
    assert result.reference_elements == n_ref
    assert result.compared_elements == n_cmp
# Shared parametrize signature: inputs, their element counts, and the
# expected error tuple for the three non-trivial weight presets.
CASE_PARAMS = "s1,s2, s1_n, s2_n, ex_err"

SIMPLE_CASES = (
    ("", "", 0, 0, (0, 0, 0)),
    ("abc", "", 3, 0, (3, 3, 3)),
    ("", "abc", 0, 3, (3, 0, 3)),
    ("abc", "abc", 3, 3, (0, 0, 0)),
    ("abc", "ab", 3, 2, (1, 1, 1)),
    ("abc", "abcd", 3, 4, (1, 0, 1)),
    ("abc", "abd", 3, 3, (1, 1, 2)),
)
@pytest.mark.parametrize(
    CASE_PARAMS,
    [
        *SIMPLE_CASES,
        (("a", "b", "c", "d", "e"), ("a", "b", "c", "d", ("e", "´")), 5, 5, (1, 1, 2)),
        (range(5), range(6), 5, 6, (1, 0, 1)),
    ],
)
def test_bag_accuracy_algorithm(s1, s2, s1_n, s2_n, ex_err, ex_weights):
    """Test the main algorithm for calculating the bag accuracy."""
    # The first weights preset is paired with 0 expected errors; the
    # remaining presets are paired with the entries of ex_err.
    for weights, expected_errors in zip(ex_weights, (0, *ex_err)):
        metric_result = bag_accuracy(Counter(s1), Counter(s2), weights=weights)
        verify_metric_result(
            metric_result, "bag_accuracy", expected_errors, s1_n, s2_n, weights
        )
@pytest.mark.parametrize(
    CASE_PARAMS,
    [
        *SIMPLE_CASES,
        ("Schlyñ", "Schlym̃", 6, 6, (1, 1, 2)),
        (
            unicodedata.normalize("NFC", "Schlyñ lorem ipsum."),
            unicodedata.normalize("NFD", "Schlyñ lorem ipsum!"),
            19,
            19,
            (1, 1, 2),
        ),
    ],
)
def test_bag_of_chars_accuracy(s1, s2, s1_n, s2_n, ex_err, ex_weights):
    """Test the special behaviour of the char differentiation.

    As the algorithm and the char normalization is implemented elsewhere
    we are currently only testing that the corresponding algorithms are called.
    """
    # The first weights preset is paired with 0 expected errors; the
    # remaining presets are paired with the entries of ex_err.
    for weights, expected_errors in zip(ex_weights, (0, *ex_err)):
        result = bag_of_chars_accuracy(s1, s2, weights)
        verify_metric_result(
            result, "bag_of_chars_accuracy", expected_errors, s1_n, s2_n, weights
        )
@pytest.mark.parametrize( @pytest.mark.parametrize(
"s1,s2, ex_n, ex_err", CASE_PARAMS,
[ [
*SIMPLE_CASES, *SIMPLE_CASES,
("Schlyñ", "Schlym̃", 6, (1, 1, 2)), ("Schlyñ", "Schlym̃", 6, 6, (1, 1, 2)),
( (
unicodedata.normalize("NFC", "Schlyñ lorem ipsum."), unicodedata.normalize("NFC", "Schlyñ lorem ipsum."),
unicodedata.normalize("NFD", "Schlyñ lorem ipsum!"), unicodedata.normalize("NFD", "Schlyñ lorem ipsum!"),
3, 3,
3,
(0, 0, 0), (0, 0, 0),
), ),
], ],
) )
def test_bag_of_words_accuracy_n(s1, s2, ex_n, ex_err, ex_weights): def test_bag_of_words_accuracy(s1, s2, s1_n, s2_n, ex_err, ex_weights):
"""Test the special behaviour of the word differentiation. """Test the special behaviour of the word differentiation.
As the algorithm and the word splitting is implemented elsewhere As the algorithm and the word splitting is implemented elsewhere
@ -96,9 +119,7 @@ def test_bag_of_words_accuracy_n(s1, s2, ex_n, ex_err, ex_weights):
s1 = " ".join(s1) s1 = " ".join(s1)
s2 = " ".join(s2) s2 = " ".join(s2)
for weights, expected_errors in zip(ex_weights, (0, *ex_err)): for weights, expected_errors in zip(ex_weights, (0, *ex_err)):
acc, n = bag_of_words_accuracy_n(s1, s2, weights) result = bag_of_words_accuracy(s1, s2, weights)
assert n == ex_n, f"{n} == {ex_n} for {weights}" verify_metric_result(
if ex_n == 0: result, "bag_of_words_accuracy", expected_errors, s1_n, s2_n, weights
assert math.isinf(acc) )
else:
assert acc == pytest.approx(1 - expected_errors / ex_n), f"w: {weights}"

Loading…
Cancel
Save