diff --git a/qurator/dinglehopper/character_error_rate.py b/qurator/dinglehopper/character_error_rate.py index e99f391..29826e3 100644 --- a/qurator/dinglehopper/character_error_rate.py +++ b/qurator/dinglehopper/character_error_rate.py @@ -6,6 +6,7 @@ from typing import Tuple from uniseg.graphemecluster import grapheme_clusters from qurator.dinglehopper.edit_distance import distance +from qurator.dinglehopper.ocr_files import ExtractedText def character_error_rate_n(reference, compared) -> Tuple[float, int]: @@ -14,12 +15,13 @@ def character_error_rate_n(reference, compared) -> Tuple[float, int]: :return: character error rate and length of the reference """ + if isinstance(reference, str): + return character_error_rate_n( + ExtractedText.from_text(reference), + compared) + d = distance(reference, compared) - # XXX - from .cli import ExtractedText - if isinstance(reference, ExtractedText): - reference = reference.text - n = len(list(grapheme_clusters(unicodedata.normalize('NFC', reference)))) + n = len(list(grapheme_clusters(reference.text))) if d == 0: return 0, n diff --git a/qurator/dinglehopper/edit_distance.py b/qurator/dinglehopper/edit_distance.py index 284b676..bc607a9 100644 --- a/qurator/dinglehopper/edit_distance.py +++ b/qurator/dinglehopper/edit_distance.py @@ -7,6 +7,7 @@ from typing import Sequence, Tuple import numpy as np from uniseg.graphemecluster import grapheme_clusters +from .ocr_files import ExtractedText def levenshtein_matrix(seq1: Sequence, seq2: Sequence): """Compute the matrix commonly computed to produce the Levenshtein distance. @@ -75,12 +76,12 @@ def distance(s1, s2): Note that this is different from levenshtein() as this function knows about Unicode normalization and grapheme clusters. This should be the correct way to compare two Unicode strings. """ - # XXX - from .cli import ExtractedText + if isinstance(s1, ExtractedText): s1 = s1.text if isinstance(s2, ExtractedText): s2 = s2.text + s1 = list(grapheme_clusters(unicodedata.normalize('NFC', s1))) s2 = list(grapheme_clusters(unicodedata.normalize('NFC', s2))) return levenshtein(s1, s2)