Added comments

pull/60/head
Benjamin Rosemann 4 years ago
parent cee7b6891b
commit 40f23b8482

@ -9,6 +9,16 @@ from .utils import bag_accuracy, MetricResult, Weights
def bag_of_chars_accuracy( def bag_of_chars_accuracy(
reference: str, compared: str, weights: Weights = Weights(1, 0, 1) reference: str, compared: str, weights: Weights = Weights(1, 0, 1)
) -> MetricResult: ) -> MetricResult:
"""Compute Bag of Chars accuracy and error rate.
We are using bags (multisets, `Counter`) to calculate the errors.
See :func:`bag_accuracy` for details.
:param reference: String used as reference (e.g. ground truth).
:param compared: String that gets evaluated (e.g. ocr result).
:param weights: Weights/costs for editing operations.
:return: Class representing the results of this metric.
"""
reference_chars: Counter = Counter(grapheme_clusters(normalize("NFC", reference))) reference_chars: Counter = Counter(grapheme_clusters(normalize("NFC", reference)))
compared_chars: Counter = Counter(grapheme_clusters(normalize("NFC", compared))) compared_chars: Counter = Counter(grapheme_clusters(normalize("NFC", compared)))
return bag_accuracy( return bag_accuracy(

@ -7,6 +7,16 @@ from ..normalize import words_normalized
def bag_of_words_accuracy( def bag_of_words_accuracy(
reference: str, compared: str, weights: Weights = Weights(1, 0, 1) reference: str, compared: str, weights: Weights = Weights(1, 0, 1)
) -> MetricResult: ) -> MetricResult:
"""Compute Bag of Words accuracy and error rate.
We are using bags (multisets, `Counter`) to calculate the errors.
See :func:`bag_accuracy` for details.
:param reference: String used as reference (e.g. ground truth).
:param compared: String that gets evaluated (e.g. ocr result).
:param weights: Weights/costs for editing operations.
:return: Class representing the results of this metric.
"""
reference_words: Counter = Counter(words_normalized(reference)) reference_words: Counter = Counter(words_normalized(reference))
compared_words: Counter = Counter(words_normalized(compared)) compared_words: Counter = Counter(words_normalized(compared))
return bag_accuracy( return bag_accuracy(

@ -8,9 +8,15 @@ def character_accuracy(
) -> MetricResult: ) -> MetricResult:
"""Compute character accuracy and error rate. """Compute character accuracy and error rate.
:return: NamedTuple representing the results of this metric. We are using the Levenshtein distance between reference and compared.
"""
:param reference: String used as reference (e.g. ground truth).
:param compared: String that gets evaluated (e.g. ocr result).
:param weights: Weights/costs for editing operations (not supported yet).
:return: Class representing the results of this metric.
"""
if weights != Weights(1, 1, 1):
raise NotImplementedError("Setting weights is not supported yet.")
weighted_errors = distance(reference, compared) weighted_errors = distance(reference, compared)
n_ref = len(chars_normalized(reference)) n_ref = len(chars_normalized(reference))
n_cmp = len(chars_normalized(compared)) n_cmp = len(chars_normalized(compared))

@ -6,25 +6,35 @@ class Weights(NamedTuple):
"""Represent weights/costs for editing operations.""" """Represent weights/costs for editing operations."""
deletes: int = 1 deletes: int = 1
"""Cost for a delete ad -> a."""
inserts: int = 1 inserts: int = 1
"""Cost for an insert a -> ai."""
replacements: int = 1 replacements: int = 1
"""Cost for a replacement ab -> ar."""
class MetricResult(NamedTuple): class MetricResult(NamedTuple):
"""Represent a result from a metric calculation.""" """Represent a result from a metric calculation."""
metric: str metric: str
"""Name of the metric that calculated this result."""
weights: Weights weights: Weights
"""The `Weights`/costs used to calculate this result."""
weighted_errors: int weighted_errors: int
"""The weighted errors calculated by the metric."""
reference_elements: int reference_elements: int
"""Number of elements in the reference string."""
compared_elements: int compared_elements: int
"""Number of elements in the string that is evaluated."""
@property @property
def accuracy(self) -> float: def accuracy(self) -> float:
"""The accuracy calculated as 1 - errors / reference_elements."""
return 1 - self.error_rate return 1 - self.error_rate
@property @property
def error_rate(self) -> float: def error_rate(self) -> float:
"""The error rate calculated by errors / reference_elements."""
if self.reference_elements <= 0 and self.compared_elements <= 0: if self.reference_elements <= 0 and self.compared_elements <= 0:
return 0 return 0
elif self.reference_elements <= 0: elif self.reference_elements <= 0:
@ -34,7 +44,7 @@ class MetricResult(NamedTuple):
def get_dict(self) -> Dict: def get_dict(self) -> Dict:
"""Combines the properties to a dictionary. """Combines the properties to a dictionary.
We deviate from the builtin _asdict() function by including our properties. We deviate from the builtin `_asdict()` function by including our properties.
""" """
return { return {
**{key: value for key, value in self._asdict().items()}, **{key: value for key, value in self._asdict().items()},
@ -50,19 +60,19 @@ def bag_accuracy(
weights: Weights, weights: Weights,
metric: str = "bag_accuracy", metric: str = "bag_accuracy",
) -> MetricResult: ) -> MetricResult:
"""Calculates the weighted errors for two bags (Counter). """Calculates the weighted errors for two bags (Multiset, `Counter`).
Basic algorithm idea: Basic algorithm idea:
- All elements in reference not occurring in compared are considered deletes. - All elements in `reference` not occurring in `compared` are considered deletes.
- All elements in compared not occurring in reference are considered inserts. - All elements in `compared` not occurring in `reference` are considered inserts.
- When the cost for one replacement is lower than that of one insert and one delete - When the cost for one replacement is lower than that of one insert and one delete
we can substitute pairs of deletes and inserts with one replacement. we can substitute pairs of deletes and inserts with one replacement.
:param reference: Bag used as reference (ground truth). :param reference: Bag used as reference (ground truth).
:param compared: Bag used to compare (ocr). :param compared: Bag used to compare (ocr).
:param weights: Weights/costs for editing operations. :param weights: `Weights`/costs for editing operations.
:param metric: Name of the (original) metric. :param metric: Name of the (original) metric.
:return: NamedTuple representing the results of this metric. :return: `NamedTuple` representing the results of this metric.
""" """
n_ref = sum(reference.values()) n_ref = sum(reference.values())
n_cmp = sum(compared.values()) n_cmp = sum(compared.values())

@ -6,6 +6,18 @@ from ..normalize import words_normalized
def word_accuracy( def word_accuracy(
reference: str, compared: str, weights: Weights = Weights(1, 1, 1) reference: str, compared: str, weights: Weights = Weights(1, 1, 1)
) -> MetricResult: ) -> MetricResult:
"""Compute word accuracy and error rate.
We are using the Levenshtein distance between reference and compared.
:param reference: String used as reference (e.g. ground truth).
:param compared: String that gets evaluated (e.g. ocr result).
:param weights: Weights/costs for editing operations (not supported yet).
:return: Class representing the results of this metric.
"""
if weights != Weights(1, 1, 1):
raise NotImplementedError("Setting weights is not supported yet.")
reference_seq = list(words_normalized(reference)) reference_seq = list(words_normalized(reference))
compared_seq = list(words_normalized(compared)) compared_seq = list(words_normalized(compared))

Loading…
Cancel
Save