🔍 mypy: Use a compatible syntax for multimethod

pull/111/head
Mike Gerber 12 months ago
parent 8166435958
commit ad316aeabc

@ -30,17 +30,15 @@ def character_error_rate_n(
# XXX Should we really count newlines here? # XXX Should we really count newlines here?
@multimethod @character_error_rate_n.register
def character_error_rate_n(reference: str, compared: str) -> Tuple[float, int]: def _(reference: str, compared: str) -> Tuple[float, int]:
seq1 = list(grapheme_clusters(unicodedata.normalize("NFC", reference))) seq1 = list(grapheme_clusters(unicodedata.normalize("NFC", reference)))
seq2 = list(grapheme_clusters(unicodedata.normalize("NFC", compared))) seq2 = list(grapheme_clusters(unicodedata.normalize("NFC", compared)))
return character_error_rate_n(seq1, seq2) return character_error_rate_n(seq1, seq2)
@multimethod @character_error_rate_n.register
def character_error_rate_n( def _(reference: ExtractedText, compared: ExtractedText) -> Tuple[float, int]:
reference: ExtractedText, compared: ExtractedText
) -> Tuple[float, int]:
return character_error_rate_n( return character_error_rate_n(
reference.grapheme_clusters, compared.grapheme_clusters reference.grapheme_clusters, compared.grapheme_clusters
) )

@ -19,8 +19,8 @@ def distance(seq1: List[str], seq2: List[str]):
return Levenshtein.distance(seq1, seq2) return Levenshtein.distance(seq1, seq2)
@multimethod @distance.register
def distance(s1: str, s2: str): def _(s1: str, s2: str):
"""Compute the Levenshtein edit distance between two Unicode strings """Compute the Levenshtein edit distance between two Unicode strings
Note that this is different from levenshtein() as this function knows about Unicode Note that this is different from levenshtein() as this function knows about Unicode
@ -32,8 +32,8 @@ def distance(s1: str, s2: str):
return Levenshtein.distance(seq1, seq2) return Levenshtein.distance(seq1, seq2)
@multimethod @distance.register
def distance(s1: ExtractedText, s2: ExtractedText): def _(s1: ExtractedText, s2: ExtractedText):
return Levenshtein.distance(s1.grapheme_clusters, s2.grapheme_clusters) return Levenshtein.distance(s1.grapheme_clusters, s2.grapheme_clusters)

@ -60,8 +60,8 @@ def words(s: str):
yield word yield word
@multimethod @words.register
def words(s: ExtractedText): def _(s: ExtractedText):
return words(s.text) return words(s.text)
@ -70,8 +70,8 @@ def words_normalized(s: str):
return words(unicodedata.normalize("NFC", s)) return words(unicodedata.normalize("NFC", s))
@multimethod @words_normalized.register
def words_normalized(s: ExtractedText): def _(s: ExtractedText):
return words_normalized(s.text) return words_normalized(s.text)
@ -82,15 +82,13 @@ def word_error_rate_n(reference: str, compared: str) -> Tuple[float, int]:
return word_error_rate_n(reference_seq, compared_seq) return word_error_rate_n(reference_seq, compared_seq)
@multimethod @word_error_rate_n.register
def word_error_rate_n( def _(reference: ExtractedText, compared: ExtractedText) -> Tuple[float, int]:
reference: ExtractedText, compared: ExtractedText
) -> Tuple[float, int]:
return word_error_rate_n(reference.text, compared.text) return word_error_rate_n(reference.text, compared.text)
@multimethod @word_error_rate_n.register
def word_error_rate_n(reference: Iterable, compared: Iterable) -> Tuple[float, int]: def _(reference: Iterable, compared: Iterable) -> Tuple[float, int]:
reference_seq = list(reference) reference_seq = list(reference)
compared_seq = list(compared) compared_seq = list(compared)

Loading…
Cancel
Save