🎨 dinglehopper: Use multimethod to handle str vs ExtractedText

pull/38/head
Gerber, Mike 4 years ago
parent a17ee2afec
commit b14c35e147

@ -3,25 +3,22 @@ from __future__ import division
import unicodedata import unicodedata
from typing import Tuple from typing import Tuple
from multimethod import multimethod
from uniseg.graphemecluster import grapheme_clusters from uniseg.graphemecluster import grapheme_clusters
from qurator.dinglehopper.edit_distance import distance from qurator.dinglehopper.edit_distance import distance
from qurator.dinglehopper.ocr_files import ExtractedText from qurator.dinglehopper.ocr_files import ExtractedText
@multimethod
def character_error_rate_n(reference, compared) -> Tuple[float, int]: def character_error_rate_n(reference: str, compared: str) -> Tuple[float, int]:
""" """
Compute character error rate. Compute character error rate.
:return: character error rate and length of the reference :return: character error rate and length of the reference
""" """
if isinstance(reference, str):
return character_error_rate_n(
ExtractedText.from_str(reference),
compared)
d = distance(reference, compared) d = distance(reference, compared)
n = len(list(grapheme_clusters(reference.text))) n = len(list(grapheme_clusters(unicodedata.normalize('NFC', reference))))
if d == 0: if d == 0:
return 0, n return 0, n
@ -32,6 +29,11 @@ def character_error_rate_n(reference, compared) -> Tuple[float, int]:
# XXX Should we really count newlines here? # XXX Should we really count newlines here?
@multimethod
def character_error_rate_n(reference: ExtractedText, compared: ExtractedText) -> Tuple[float, int]:
    """
    Compute character error rate on ExtractedText objects.

    Delegates to the str overload using the objects' plain-text content.

    :return: character error rate and length of the reference
    """
    return character_error_rate_n(reference.text, compared.text)
def character_error_rate(reference, compared) -> float: def character_error_rate(reference, compared) -> float:
""" """
Compute character error rate. Compute character error rate.

@ -5,6 +5,7 @@ from functools import partial, lru_cache
from typing import Sequence, Tuple from typing import Sequence, Tuple
import numpy as np import numpy as np
from multimethod import multimethod
from uniseg.graphemecluster import grapheme_clusters from uniseg.graphemecluster import grapheme_clusters
from .ocr_files import ExtractedText from .ocr_files import ExtractedText
@ -70,23 +71,21 @@ def levenshtein_matrix_cache_clear():
_levenshtein_matrix.cache_clear() _levenshtein_matrix.cache_clear()
@multimethod
def distance(s1: str, s2: str):
    """Compute the Levenshtein edit distance between two Unicode strings

    Note that this is different from levenshtein() as this function knows about Unicode normalization and grapheme
    clusters. This should be the correct way to compare two Unicode strings.
    """
    # Normalize to NFC first so canonically-equivalent strings compare equal,
    # then compare grapheme clusters (user-perceived characters), not code points.
    seq1 = list(grapheme_clusters(unicodedata.normalize('NFC', s1)))
    seq2 = list(grapheme_clusters(unicodedata.normalize('NFC', s2)))
    return levenshtein(seq1, seq2)
@multimethod
def distance(s1: ExtractedText, s2: ExtractedText):
    """Compute the Levenshtein edit distance between two ExtractedText objects.

    Delegates to the str overload using the objects' plain-text content
    (ExtractedText guarantees NFC-normalized text).
    """
    return distance(s1.text, s2.text)
def seq_editops(seq1, seq2): def seq_editops(seq1, seq2):

@ -1,14 +1,19 @@
from __future__ import division from __future__ import division
import unicodedata import unicodedata
from typing import Tuple from typing import Tuple, Iterable
from multimethod import multimethod
import uniseg.wordbreak import uniseg.wordbreak
from .edit_distance import levenshtein from .edit_distance import levenshtein
from .ocr_files import ExtractedText
def words(s): @multimethod
def words(s: str):
"""Extract words from a string"""
# Patch uniseg.wordbreak.word_break to deal with our private use characters. See also # Patch uniseg.wordbreak.word_break to deal with our private use characters. See also
# https://www.unicode.org/Public/UCD/latest/ucd/auxiliary/WordBreakProperty.txt # https://www.unicode.org/Public/UCD/latest/ucd/auxiliary/WordBreakProperty.txt
old_word_break = uniseg.wordbreak.word_break old_word_break = uniseg.wordbreak.word_break
@ -32,11 +37,6 @@ def words(s):
cat = subcat[0] cat = subcat[0]
return cat in unwanted_categories or subcat in unwanted_subcategories return cat in unwanted_categories or subcat in unwanted_subcategories
# XXX
from .cli import ExtractedText
if isinstance(s, ExtractedText):
s = s.text
# We follow Unicode Standard Annex #29 on Unicode Text Segmentation here: Split on word boundaries using # We follow Unicode Standard Annex #29 on Unicode Text Segmentation here: Split on word boundaries using
# uniseg.wordbreak.words() and ignore all "words" that contain only whitespace, punctation "or similar characters." # uniseg.wordbreak.words() and ignore all "words" that contain only whitespace, punctation "or similar characters."
for word in uniseg.wordbreak.words(s): for word in uniseg.wordbreak.words(s):
@ -46,27 +46,37 @@ def words(s):
yield word yield word
@multimethod
def words(s: ExtractedText):
    """Extract words from an ExtractedText object.

    Delegates to the str overload using the object's plain-text content.
    """
    return words(s.text)
@multimethod
def words_normalized(s: str):
    """Extract words from a string after normalizing it to NFC."""
    return words(unicodedata.normalize('NFC', s))
@multimethod
def words_normalized(s: ExtractedText):
    """Extract NFC-normalized words from an ExtractedText object.

    Delegates to the str overload using the object's plain-text content.
    """
    return words_normalized(s.text)
@multimethod
def word_error_rate_n(reference: str, compared: str) -> Tuple[float, int]:
    """
    Compute word error rate between two strings.

    Tokenizes both strings into normalized words, then delegates to the
    sequence (Iterable) overload.

    :return: word error rate and length (in words) of the reference
    """
    reference_seq = list(words_normalized(reference))
    compared_seq = list(words_normalized(compared))
    return word_error_rate_n(reference_seq, compared_seq)
@multimethod
def word_error_rate_n(reference: ExtractedText, compared: ExtractedText) -> Tuple[float, int]:
    """
    Compute word error rate between two ExtractedText objects.

    Delegates to the str overload using the objects' plain-text content.

    :return: word error rate and length (in words) of the reference
    """
    return word_error_rate_n(reference.text, compared.text)
@multimethod
def word_error_rate_n(reference: Iterable, compared: Iterable) -> Tuple[float, int]:
reference_seq = list(reference)
compared_seq = list(compared)
d = levenshtein(reference_seq, compared_seq) d = levenshtein(reference_seq, compared_seq)
n = len(reference_seq) n = len(reference_seq)

@ -7,3 +7,4 @@ colorama
MarkupSafe
ocrd >= 1.0.0b15
attrs
multimethod == 1.3 # latest version to officially support Python 3.5
Loading…
Cancel
Save