From 7843824eafb5581a3ccdbc24284d049525fdc2f1 Mon Sep 17 00:00:00 2001 From: "Gerber, Mike" Date: Thu, 8 Oct 2020 10:47:20 +0200 Subject: [PATCH] =?UTF-8?q?=F0=9F=9A=A7=20dinglehopper:=20Support=20str=20?= =?UTF-8?q?&=20ExtractedText=20in=20CER=20and=20distance=20functions?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- qurator/dinglehopper/character_error_rate.py | 12 +++++++----- qurator/dinglehopper/edit_distance.py | 5 +++-- 2 files changed, 10 insertions(+), 7 deletions(-) diff --git a/qurator/dinglehopper/character_error_rate.py b/qurator/dinglehopper/character_error_rate.py index e99f391..29826e3 100644 --- a/qurator/dinglehopper/character_error_rate.py +++ b/qurator/dinglehopper/character_error_rate.py @@ -6,6 +6,7 @@ from typing import Tuple from uniseg.graphemecluster import grapheme_clusters from qurator.dinglehopper.edit_distance import distance +from qurator.dinglehopper.ocr_files import ExtractedText def character_error_rate_n(reference, compared) -> Tuple[float, int]: @@ -14,12 +15,13 @@ def character_error_rate_n(reference, compared) -> Tuple[float, int]: :return: character error rate and length of the reference """ + if isinstance(reference, str): + return character_error_rate_n( + ExtractedText.from_text(reference), + compared) + d = distance(reference, compared) - # XXX - from .cli import ExtractedText - if isinstance(reference, ExtractedText): - reference = reference.text - n = len(list(grapheme_clusters(unicodedata.normalize('NFC', reference)))) + n = len(list(grapheme_clusters(reference.text))) if d == 0: return 0, n diff --git a/qurator/dinglehopper/edit_distance.py b/qurator/dinglehopper/edit_distance.py index 284b676..bc607a9 100644 --- a/qurator/dinglehopper/edit_distance.py +++ b/qurator/dinglehopper/edit_distance.py @@ -7,6 +7,7 @@ from typing import Sequence, Tuple import numpy as np from uniseg.graphemecluster import grapheme_clusters +from .ocr_files import ExtractedText def levenshtein_matrix(seq1: Sequence, seq2: Sequence): """Compute the matrix commonly computed to produce the Levenshtein distance. @@ -75,12 +76,12 @@ def distance(s1, s2): Note that this is different from levenshtein() as this function knows about Unicode normalization and grapheme clusters. This should be the correct way to compare two Unicode strings. """ - # XXX - from .cli import ExtractedText + if isinstance(s1, ExtractedText): s1 = s1.text if isinstance(s2, ExtractedText): s2 = s2.text + s1 = list(grapheme_clusters(unicodedata.normalize('NFC', s1))) s2 = list(grapheme_clusters(unicodedata.normalize('NFC', s2))) return levenshtein(s1, s2)