diff --git a/src/dinglehopper/character_error_rate.py b/src/dinglehopper/character_error_rate.py index 88a88f8..04e4bfe 100644 --- a/src/dinglehopper/character_error_rate.py +++ b/src/dinglehopper/character_error_rate.py @@ -20,14 +20,7 @@ def character_error_rate_n( :return: character error rate and length of the reference """ - d = distance(reference, compared) - n = len(reference) - - if d == 0: - return 0, n - if n == 0: - return float("inf"), n - return d / n, n + return distance(reference, compared), len(reference) # XXX Should we really count newlines here? diff --git a/src/dinglehopper/edit_distance.py b/src/dinglehopper/edit_distance.py index ec564ae..988849c 100644 --- a/src/dinglehopper/edit_distance.py +++ b/src/dinglehopper/edit_distance.py @@ -9,18 +9,18 @@ from .extracted_text import ExtractedText @multimethod -def distance(seq1: List[str], seq2: List[str]) -> int: +def distance(seq1: List[str], seq2: List[str]) -> float: """Compute the Levenshtein edit distance between two lists of grapheme clusters. This assumes that the grapheme clusters are already normalized. Use distance(str, str) instead if you need to compare two Unicode strings. """ - return Levenshtein.distance(seq1, seq2) + return Levenshtein.normalized_distance(seq1, seq2) @distance.register -def _(s1: str, s2: str) -> int: +def _(s1: str, s2: str) -> float: """Compute the Levenshtein edit distance between two Unicode strings Note that this is different from levenshtein() as this function knows about Unicode @@ -29,12 +29,12 @@ def _(s1: str, s2: str) -> int: """ seq1 = list(grapheme_clusters(unicodedata.normalize("NFC", s1))) seq2 = list(grapheme_clusters(unicodedata.normalize("NFC", s2))) - return Levenshtein.distance(seq1, seq2) + return Levenshtein.normalized_distance(seq1, seq2) @distance.register -def _(s1: ExtractedText, s2: ExtractedText) -> int: - return Levenshtein.distance(s1.grapheme_clusters, s2.grapheme_clusters) +def _(s1: ExtractedText, s2: ExtractedText) -> float: + return Levenshtein.normalized_distance(s1.grapheme_clusters, s2.grapheme_clusters) def editops(word1, word2): diff --git a/src/dinglehopper/word_error_rate.py b/src/dinglehopper/word_error_rate.py index ec039b3..abaa168 100644 --- a/src/dinglehopper/word_error_rate.py +++ b/src/dinglehopper/word_error_rate.py @@ -96,15 +96,10 @@ def _(reference: Iterable[T], compared: Iterable[T]) -> Tuple[float, int]: reference_seq = list(reference) compared_seq = list(compared) - d = Levenshtein.distance(reference_seq, compared_seq) + d = Levenshtein.normalized_distance(reference_seq, compared_seq) n = len(reference_seq) - if d == 0: - return 0, n - if n == 0: - return float("inf"), n - return d / n, n - + return d, n def word_error_rate(reference: T, compared: T) -> float: wer: float