Mirror of https://github.com/qurator-spk/dinglehopper.git, synced 2025-07-05 16:39:59 +02:00
Switch to result tuple instead of multiple return parameters
parent 974ca3e5c0
commit 381fe7cb6b
4 changed files with 99 additions and 76 deletions
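The four file diffs below cover, in order, the bag-of-chars metric, the bag-of-words metric, the shared metrics utilities, and the corresponding tests. For callers, the change replaces the former pair of return values (accuracy, number of reference elements) with a single MetricResult named tuple. A minimal usage sketch, assembled from the signatures in the diffs below and assuming the package path qurator.dinglehopper.metrics (the sample strings are made up):

from qurator.dinglehopper.metrics import bag_of_chars_accuracy, Weights

weights = Weights(deletes=1, inserts=1, replacements=1)

# Before this commit: the *_n variants returned a plain (accuracy, n) tuple, e.g.
# acc, n = bag_of_chars_accuracy_n("Schlyñ", "Schlym̃", weights)

# After this commit: one MetricResult carries everything.
result = bag_of_chars_accuracy("Schlyñ", "Schlym̃", weights)
result.metric               # "bag_of_chars_accuracy"
result.weighted_errors      # weighted edit operations against the reference
result.reference_elements   # number of grapheme clusters in the reference
result.compared_elements    # number of grapheme clusters in the OCR text
result.accuracy             # derived property: 1 - error_rate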
@@ -1,35 +1,17 @@
 from collections import Counter
-from typing import Tuple, Union
 from unicodedata import normalize
 
-from multimethod import multimethod
 from uniseg.graphemecluster import grapheme_clusters
 
-from .utils import bag_accuracy, Weights
-from .. import ExtractedText
+from .utils import bag_accuracy, MetricResult, Weights
 
 
 def bag_of_chars_accuracy(
-    reference: Union[str, ExtractedText],
-    compared: Union[str, ExtractedText],
-    weights: Weights,
-) -> float:
-    acc, _ = bag_of_chars_accuracy_n(reference, compared, weights)
-    return acc
-
-
-@multimethod
-def bag_of_chars_accuracy_n(
     reference: str, compared: str, weights: Weights
-) -> Tuple[float, int]:
+) -> MetricResult:
     reference_chars = Counter(grapheme_clusters(normalize("NFC", reference)))
     compared_chars = Counter(grapheme_clusters(normalize("NFC", compared)))
-    e, n = bag_accuracy(reference_chars, compared_chars, weights)
-    return (float("inf") if n == 0 else 1 - e / n), n
-
-
-@multimethod
-def bag_of_chars_accuracy_n(
-    reference: ExtractedText, compared: ExtractedText, weights: Weights
-) -> Tuple[float, int]:
-    return bag_of_chars_accuracy_n(reference.text, compared.text, weights)
+    result = bag_accuracy(reference_chars, compared_chars, weights)
+    return MetricResult(
+        **{**result._asdict(), "metric": bag_of_chars_accuracy.__name__}
+    )
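The MetricResult(**{**result._asdict(), "metric": ...}) construction in the hunk above is the usual NamedTuple copy-with-one-field-overridden idiom. A small self-contained illustration (the Point type is made up for the example and is not part of dinglehopper):

from typing import NamedTuple


class Point(NamedTuple):
    x: int
    y: int
    label: str = ""


p = Point(1, 2, "start")
# Rebuild the tuple with one field swapped out, as the diff does for "metric":
q = Point(**{**p._asdict(), "label": "renamed"})
# NamedTuple._replace does the same thing more directly:
r = p._replace(label="renamed")
assert q == r == Point(1, 2, "renamed")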
@@ -1,7 +1,7 @@
 from collections import Counter
-from typing import Tuple, Union
+from typing import Union
 
-from .utils import bag_accuracy, Weights
+from .utils import bag_accuracy, MetricResult, Weights
 from .. import ExtractedText
 from ..normalize import words_normalized
 
@@ -10,21 +10,14 @@ def bag_of_words_accuracy(
     reference: Union[str, ExtractedText],
     compared: Union[str, ExtractedText],
     weights: Weights,
-) -> float:
-    acc, _ = bag_of_words_accuracy_n(reference, compared, weights)
-    return acc
-
-
-def bag_of_words_accuracy_n(
-    reference: Union[str, ExtractedText],
-    compared: Union[str, ExtractedText],
-    weights: Weights,
-) -> Tuple[float, int]:
+) -> MetricResult:
     if isinstance(reference, ExtractedText):
         reference = reference.text
     if isinstance(compared, ExtractedText):
         compared = compared.text
     reference_words = Counter(words_normalized(reference))
     compared_words = Counter(words_normalized(compared))
-    e, n = bag_accuracy(reference_words, compared_words, weights)
-    return (float("inf") if n == 0 else 1 - e / n), n
+    result = bag_accuracy(reference_words, compared_words, weights)
+    return MetricResult(
+        **{**result._asdict(), "metric": bag_of_words_accuracy.__name__}
+    )
@@ -1,5 +1,5 @@
 from collections import Counter
-from typing import NamedTuple, Tuple
+from typing import NamedTuple
 
 
 class Weights(NamedTuple):
@@ -10,9 +10,29 @@ class Weights(NamedTuple):
     replacements: int = 1
 
 
+class MetricResult(NamedTuple):
+    """Represent a result from a metric calculation."""
+
+    metric: str
+    weights: Weights
+    weighted_errors: int
+    reference_elements: int
+    compared_elements: int
+
+    @property
+    def accuracy(self) -> float:
+        return 1 - self.error_rate
+
+    @property
+    def error_rate(self) -> float:
+        if self.reference_elements <= 0:
+            return float("inf")
+        return self.weighted_errors / self.reference_elements
+
+
 def bag_accuracy(
     reference: Counter, compared: Counter, weights: Weights
-) -> Tuple[int, int]:
+) -> MetricResult:
     """Calculates the the weighted errors for two bags (Counter).
 
     Basic algorithm idea:
@@ -24,9 +44,10 @@ def bag_accuracy(
     :param reference: Bag used as reference (ground truth).
     :param compared: Bag used to compare (ocr).
     :param weights: Weights/costs for editing operations.
-    :return: weighted errors and number of elements in reference.
+    :return: Tuple representing the results of this metric.
     """
-    n = sum(reference.values())
+    n_ref = sum(reference.values())
+    n_cmp = sum(compared.values())
     deletes = sum((reference - compared).values())
     inserts = sum((compared - reference).values())
     replacements = 0
@@ -38,4 +59,10 @@ def bag_accuracy(
         + weights.inserts * inserts
         + weights.replacements * replacements
     )
-    return weighted_errors, n
+    return MetricResult(
+        metric=bag_accuracy.__name__,
+        weights=weights,
+        weighted_errors=weighted_errors,
+        reference_elements=n_ref,
+        compared_elements=n_cmp,
+    )
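A short sketch of what the new MetricResult gives callers, with illustrative numbers and assuming the module path qurator.dinglehopper.metrics.utils:

from qurator.dinglehopper.metrics.utils import MetricResult, Weights

weights = Weights(deletes=1, inserts=1, replacements=1)

result = MetricResult(
    metric="bag_accuracy",
    weights=weights,
    weighted_errors=1,
    reference_elements=4,
    compared_elements=4,
)
assert result.error_rate == 0.25   # weighted_errors / reference_elements
assert result.accuracy == 0.75     # 1 - error_rate

# An empty reference no longer needs special-casing at each call site:
# error_rate degenerates to infinity instead of raising ZeroDivisionError.
empty = MetricResult("bag_accuracy", weights, 0, 0, 0)
assert empty.error_rate == float("inf")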
@@ -4,7 +4,12 @@ from collections import Counter
 
 import pytest
 
-from ...metrics import bag_of_chars_accuracy_n, bag_of_words_accuracy_n, Weights
+from ...metrics import (
+    bag_of_chars_accuracy,
+    bag_of_words_accuracy,
+    MetricResult,
+    Weights,
+)
 from ...metrics.utils import bag_accuracy
 
 
@@ -18,75 +23,93 @@ def ex_weights():
     )
 
 
+def verify_metric_result(
+    result: MetricResult,
+    metric: str,
+    errors: int,
+    n_ref: int,
+    n_cmp: int,
+    weights: Weights,
+):
+    assert result.metric == metric
+    assert result.weights == weights
+    assert result.weighted_errors == errors
+    assert result.reference_elements == n_ref
+    assert result.compared_elements == n_cmp
+
+
+CASE_PARAMS = "s1,s2, s1_n, s2_n, ex_err"
+
 SIMPLE_CASES = (
-    ("", "", 0, (0, 0, 0)),
-    ("abc", "", 3, (3, 3, 3)),
-    ("", "abc", 0, (3, 0, 3)),
-    ("abc", "abc", 3, (0, 0, 0)),
-    ("abc", "ab", 3, (1, 1, 1)),
-    ("abc", "abcd", 3, (1, 0, 1)),
-    ("abc", "abd", 3, (1, 1, 2)),
+    ("", "", 0, 0, (0, 0, 0)),
+    ("abc", "", 3, 0, (3, 3, 3)),
+    ("", "abc", 0, 3, (3, 0, 3)),
+    ("abc", "abc", 3, 3, (0, 0, 0)),
+    ("abc", "ab", 3, 2, (1, 1, 1)),
+    ("abc", "abcd", 3, 4, (1, 0, 1)),
+    ("abc", "abd", 3, 3, (1, 1, 2)),
 )
 
 
 @pytest.mark.parametrize(
-    "s1,s2, ex_n, ex_err",
+    CASE_PARAMS,
     [
         *SIMPLE_CASES,
-        (("a", "b", "c", "d", "e"), ("a", "b", "c", "d", ("e", "´")), 5, (1, 1, 2)),
-        (range(5), range(6), 5, (1, 0, 1)),
+        (("a", "b", "c", "d", "e"), ("a", "b", "c", "d", ("e", "´")), 5, 5, (1, 1, 2)),
+        (range(5), range(6), 5, 6, (1, 0, 1)),
     ],
 )
-def test_bag_accuracy_algorithm(s1, s2, ex_n, ex_err, ex_weights):
+def test_bag_accuracy_algorithm(s1, s2, s1_n, s2_n, ex_err, ex_weights):
     """Test the main algorithm for calculating the bag accuracy."""
     for weights, expected_errors in zip(ex_weights, (0, *ex_err)):
-        e, n = bag_accuracy(Counter(s1), Counter(s2), weights=weights)
-        assert n == ex_n, f"{n} == {ex_n} for {weights}"
-        assert e == expected_errors, f"{e} == {expected_errors} for {weights}"
+        metric_result = bag_accuracy(Counter(s1), Counter(s2), weights=weights)
+        verify_metric_result(
+            metric_result, "bag_accuracy", expected_errors, s1_n, s2_n, weights
+        )
 
 
 @pytest.mark.parametrize(
-    "s1,s2, ex_n, ex_err",
+    CASE_PARAMS,
     [
         *SIMPLE_CASES,
-        ("Schlyñ", "Schlym̃", 6, (1, 1, 2)),
+        ("Schlyñ", "Schlym̃", 6, 6, (1, 1, 2)),
         (
             unicodedata.normalize("NFC", "Schlyñ lorem ipsum."),
             unicodedata.normalize("NFD", "Schlyñ lorem ipsum!"),
             19,
+            19,
             (1, 1, 2),
         ),
     ],
 )
-def test_bag_of_chars_accuracy_n(s1, s2, ex_n, ex_err, ex_weights):
+def test_bag_of_chars_accuracy(s1, s2, s1_n, s2_n, ex_err, ex_weights):
     """Test the special behaviour of the char differentiation.
 
     As the algorithm and the char normalization is implemented elsewhere
     we are currently only testing that the corresponding algorithms are called.
     """
     for weights, expected_errors in zip(ex_weights, (0, *ex_err)):
-        acc, n = bag_of_chars_accuracy_n(s1, s2, weights)
-        assert n == ex_n, f"{n} == {ex_n} for {weights}"
-        if ex_n == 0:
-            assert math.isinf(acc)
-        else:
-            assert acc == pytest.approx(1 - expected_errors / ex_n), f"w: {weights}"
+        result = bag_of_chars_accuracy(s1, s2, weights)
+        verify_metric_result(
+            result, "bag_of_chars_accuracy", expected_errors, s1_n, s2_n, weights
+        )
 
 
 @pytest.mark.parametrize(
-    "s1,s2, ex_n, ex_err",
+    CASE_PARAMS,
     [
         *SIMPLE_CASES,
-        ("Schlyñ", "Schlym̃", 6, (1, 1, 2)),
+        ("Schlyñ", "Schlym̃", 6, 6, (1, 1, 2)),
         (
             unicodedata.normalize("NFC", "Schlyñ lorem ipsum."),
             unicodedata.normalize("NFD", "Schlyñ lorem ipsum!"),
             3,
+            3,
             (0, 0, 0),
         ),
     ],
 )
-def test_bag_of_words_accuracy_n(s1, s2, ex_n, ex_err, ex_weights):
+def test_bag_of_words_accuracy(s1, s2, s1_n, s2_n, ex_err, ex_weights):
     """Test the special behaviour of the word differentiation.
 
     As the algorithm and the word splitting is implemented elsewhere
@@ -96,9 +119,7 @@ def test_bag_of_words_accuracy_n(s1, s2, ex_n, ex_err, ex_weights):
     s1 = " ".join(s1)
     s2 = " ".join(s2)
     for weights, expected_errors in zip(ex_weights, (0, *ex_err)):
-        acc, n = bag_of_words_accuracy_n(s1, s2, weights)
-        assert n == ex_n, f"{n} == {ex_n} for {weights}"
-        if ex_n == 0:
-            assert math.isinf(acc)
-        else:
-            assert acc == pytest.approx(1 - expected_errors / ex_n), f"w: {weights}"
+        result = bag_of_words_accuracy(s1, s2, weights)
+        verify_metric_result(
+            result, "bag_of_words_accuracy", expected_errors, s1_n, s2_n, weights
+        )
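The per-test accuracy checks (math.isinf plus pytest.approx against 1 - errors / n) disappear above because the equivalent check can now be phrased once against the result object. A hypothetical helper, not part of the commit, showing what such an assertion would look like with the new API:

import math

import pytest


def assert_expected_accuracy(result, expected_errors: int) -> None:
    """Mirror of the removed accuracy assertions, stated on a MetricResult."""
    if result.reference_elements == 0:
        assert math.isinf(result.error_rate)
    else:
        assert result.accuracy == pytest.approx(
            1 - expected_errors / result.reference_elements
        )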