use Levenshtein.normalized_distance instead of distance

Robert Sachunsky 1 month ago
parent 071e6a8bd1
commit ca5de5729d

@@ -20,14 +20,7 @@ def character_error_rate_n(
:return: character error rate and length of the reference :return: character error rate and length of the reference
""" """
d = distance(reference, compared) return distance(reference, compared), len(reference)
n = len(reference)
if d == 0:
return 0, n
if n == 0:
return float("inf"), n
return d / n, n
# XXX Should we really count newlines here? # XXX Should we really count newlines here?

@@ -9,18 +9,18 @@ from .extracted_text import ExtractedText
@multimethod @multimethod
def distance(seq1: List[str], seq2: List[str]) -> int: def distance(seq1: List[str], seq2: List[str]) -> float:
"""Compute the Levenshtein edit distance between two lists of grapheme clusters. """Compute the Levenshtein edit distance between two lists of grapheme clusters.
This assumes that the grapheme clusters are already normalized. This assumes that the grapheme clusters are already normalized.
Use distance(str, str) instead if you need to compare two Unicode strings. Use distance(str, str) instead if you need to compare two Unicode strings.
""" """
return Levenshtein.distance(seq1, seq2) return Levenshtein.normalized_distance(seq1, seq2)
@distance.register @distance.register
def _(s1: str, s2: str) -> int: def _(s1: str, s2: str) -> float:
"""Compute the Levenshtein edit distance between two Unicode strings """Compute the Levenshtein edit distance between two Unicode strings
Note that this is different from levenshtein() as this function knows about Unicode Note that this is different from levenshtein() as this function knows about Unicode
@@ -29,12 +29,12 @@ def _(s1: str, s2: str) -> int:
""" """
seq1 = list(grapheme_clusters(unicodedata.normalize("NFC", s1))) seq1 = list(grapheme_clusters(unicodedata.normalize("NFC", s1)))
seq2 = list(grapheme_clusters(unicodedata.normalize("NFC", s2))) seq2 = list(grapheme_clusters(unicodedata.normalize("NFC", s2)))
return Levenshtein.distance(seq1, seq2) return Levenshtein.normalized_distance(seq1, seq2)
@distance.register @distance.register
def _(s1: ExtractedText, s2: ExtractedText) -> int: def _(s1: ExtractedText, s2: ExtractedText) -> float:
return Levenshtein.distance(s1.grapheme_clusters, s2.grapheme_clusters) return Levenshtein.normalized_distance(s1.grapheme_clusters, s2.grapheme_clusters)
def editops(word1, word2): def editops(word1, word2):

@@ -96,15 +96,10 @@ def _(reference: Iterable[T], compared: Iterable[T]) -> Tuple[float, int]:
reference_seq = list(reference) reference_seq = list(reference)
compared_seq = list(compared) compared_seq = list(compared)
d = Levenshtein.distance(reference_seq, compared_seq) d = Levenshtein.normalized_distance(reference_seq, compared_seq)
n = len(reference_seq) n = len(reference_seq)
if d == 0: return d, n
return 0, n
if n == 0:
return float("inf"), n
return d / n, n
def word_error_rate(reference: T, compared: T) -> float: def word_error_rate(reference: T, compared: T) -> float:
wer: float wer: float

Loading…
Cancel
Save