From b14c35e14761f604bdaacf73181b3c4d6da03511 Mon Sep 17 00:00:00 2001 From: "Gerber, Mike" Date: Thu, 8 Oct 2020 12:15:58 +0200 Subject: [PATCH] =?UTF-8?q?=F0=9F=8E=A8=20dinglehopper:=20Use=20multimetho?= =?UTF-8?q?d=20to=20handle=20str=20vs=20ExtractedText?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- qurator/dinglehopper/character_error_rate.py | 16 +++--- qurator/dinglehopper/edit_distance.py | 19 +++---- qurator/dinglehopper/word_error_rate.py | 60 ++++++++++++-------- requirements.txt | 1 + 4 files changed, 54 insertions(+), 42 deletions(-) diff --git a/qurator/dinglehopper/character_error_rate.py b/qurator/dinglehopper/character_error_rate.py index 9f5fda0..998a3c2 100644 --- a/qurator/dinglehopper/character_error_rate.py +++ b/qurator/dinglehopper/character_error_rate.py @@ -3,25 +3,22 @@ from __future__ import division import unicodedata from typing import Tuple +from multimethod import multimethod from uniseg.graphemecluster import grapheme_clusters from qurator.dinglehopper.edit_distance import distance from qurator.dinglehopper.ocr_files import ExtractedText - -def character_error_rate_n(reference, compared) -> Tuple[float, int]: +@multimethod +def character_error_rate_n(reference: str, compared: str) -> Tuple[float, int]: """ Compute character error rate. :return: character error rate and length of the reference """ - if isinstance(reference, str): - return character_error_rate_n( - ExtractedText.from_str(reference), - compared) d = distance(reference, compared) - n = len(list(grapheme_clusters(reference.text))) + n = len(list(grapheme_clusters(unicodedata.normalize('NFC', reference)))) if d == 0: return 0, n @@ -32,6 +29,11 @@ def character_error_rate_n(reference, compared) -> Tuple[float, int]: # XXX Should we really count newlines here? +@multimethod +def character_error_rate_n(reference: ExtractedText, compared: ExtractedText) -> Tuple[float, int]: + return character_error_rate_n(reference.text, compared.text) + + def character_error_rate(reference, compared) -> float: """ Compute character error rate. diff --git a/qurator/dinglehopper/edit_distance.py b/qurator/dinglehopper/edit_distance.py index 88d3127..ed91443 100644 --- a/qurator/dinglehopper/edit_distance.py +++ b/qurator/dinglehopper/edit_distance.py @@ -5,6 +5,7 @@ from functools import partial, lru_cache from typing import Sequence, Tuple import numpy as np +from multimethod import multimethod from uniseg.graphemecluster import grapheme_clusters from .ocr_files import ExtractedText @@ -70,23 +71,21 @@ def levenshtein_matrix_cache_clear(): _levenshtein_matrix.cache_clear() -def distance(s1, s2): +@multimethod +def distance(s1: str, s2: str): """Compute the Levenshtein edit distance between two Unicode strings Note that this is different from levenshtein() as this function knows about Unicode normalization and grapheme clusters. This should be the correct way to compare two Unicode strings. """ + seq1 = list(grapheme_clusters(unicodedata.normalize('NFC', s1))) + seq2 = list(grapheme_clusters(unicodedata.normalize('NFC', s2))) + return levenshtein(seq1, seq2) - # XXX Implicit normalization - if isinstance(s1, str): - s1 = ExtractedText.from_str(s1) - if isinstance(s2, str): - s2 = ExtractedText.from_str(s2) - # s1 and s2 are now guaranteed (by ExtractedText) to be in NFC - seq1 = list(grapheme_clusters(s1.text)) - seq2 = list(grapheme_clusters(s2.text)) - return levenshtein(seq1, seq2) +@multimethod +def distance(s1: ExtractedText, s2: ExtractedText): + return distance(s1.text, s2.text) def seq_editops(seq1, seq2): diff --git a/qurator/dinglehopper/word_error_rate.py b/qurator/dinglehopper/word_error_rate.py index 64eba0a..95ea7f8 100644 --- a/qurator/dinglehopper/word_error_rate.py +++ b/qurator/dinglehopper/word_error_rate.py @@ -1,14 +1,19 @@ from __future__ import division import unicodedata -from typing import Tuple +from typing import Tuple, Iterable +from multimethod import multimethod import uniseg.wordbreak from .edit_distance import levenshtein +from .ocr_files import ExtractedText -def words(s): +@multimethod +def words(s: str): + """Extract words from a string""" + # Patch uniseg.wordbreak.word_break to deal with our private use characters. See also # https://www.unicode.org/Public/UCD/latest/ucd/auxiliary/WordBreakProperty.txt old_word_break = uniseg.wordbreak.word_break @@ -32,11 +37,6 @@ def words(s): cat = subcat[0] return cat in unwanted_categories or subcat in unwanted_subcategories - # XXX - from .cli import ExtractedText - if isinstance(s, ExtractedText): - s = s.text - # We follow Unicode Standard Annex #29 on Unicode Text Segmentation here: Split on word boundaries using # uniseg.wordbreak.words() and ignore all "words" that contain only whitespace, punctation "or similar characters." for word in uniseg.wordbreak.words(s): @@ -46,27 +46,37 @@ def words(s): yield word -def words_normalized(s): - # XXX - from .cli import ExtractedText - if isinstance(s, ExtractedText): - s = s.text +@multimethod +def words(s: ExtractedText): + return words(s.text) + + +@multimethod +def words_normalized(s: str): return words(unicodedata.normalize('NFC', s)) -def word_error_rate_n(reference, compared) -> Tuple[float, int]: - # XXX - from .cli import ExtractedText - if isinstance(reference, ExtractedText): - reference = reference.text - if isinstance(compared, ExtractedText): - compared = compared.text - if isinstance(reference, str): - reference_seq = list(words_normalized(reference)) - compared_seq = list(words_normalized(compared)) - else: - reference_seq = list(reference) - compared_seq = list(compared) +@multimethod +def words_normalized(s: ExtractedText): + return words_normalized(s.text) + + +@multimethod +def word_error_rate_n(reference: str, compared: str) -> Tuple[float, int]: + reference_seq = list(words_normalized(reference)) + compared_seq = list(words_normalized(compared)) + return word_error_rate_n(reference_seq, compared_seq) + + +@multimethod +def word_error_rate_n(reference: ExtractedText, compared: ExtractedText) -> Tuple[float, int]: + return word_error_rate_n(reference.text, compared.text) + + +@multimethod +def word_error_rate_n(reference: Iterable, compared: Iterable) -> Tuple[float, int]: + reference_seq = list(reference) + compared_seq = list(compared) d = levenshtein(reference_seq, compared_seq) n = len(reference_seq) diff --git a/requirements.txt b/requirements.txt index 846990b..287c959 100644 --- a/requirements.txt +++ b/requirements.txt @@ -7,3 +7,4 @@ colorama MarkupSafe ocrd >= 1.0.0b15 attrs +multimethod == 1.3 # latest version to officially support Python 3.5 \ No newline at end of file