🎨 dinglehopper: Use multimethod to handle str vs ExtractedText

pull/38/head
Gerber, Mike 4 years ago
parent a17ee2afec
commit b14c35e147

@ -3,25 +3,22 @@ from __future__ import division
import unicodedata
from typing import Tuple
from multimethod import multimethod
from uniseg.graphemecluster import grapheme_clusters
from qurator.dinglehopper.edit_distance import distance
from qurator.dinglehopper.ocr_files import ExtractedText
def character_error_rate_n(reference, compared) -> Tuple[float, int]:
@multimethod
def character_error_rate_n(reference: str, compared: str) -> Tuple[float, int]:
"""
Compute character error rate.
:return: character error rate and length of the reference
"""
if isinstance(reference, str):
return character_error_rate_n(
ExtractedText.from_str(reference),
compared)
d = distance(reference, compared)
n = len(list(grapheme_clusters(reference.text)))
n = len(list(grapheme_clusters(unicodedata.normalize('NFC', reference))))
if d == 0:
return 0, n
@ -32,6 +29,11 @@ def character_error_rate_n(reference, compared) -> Tuple[float, int]:
# XXX Should we really count newlines here?
@multimethod
def character_error_rate_n(reference: ExtractedText, compared: ExtractedText) -> Tuple[float, int]:
    """Compute CER of two ExtractedText objects via the str implementation."""
    ref_text, cmp_text = reference.text, compared.text
    return character_error_rate_n(ref_text, cmp_text)
def character_error_rate(reference, compared) -> float:
"""
Compute character error rate.

@ -5,6 +5,7 @@ from functools import partial, lru_cache
from typing import Sequence, Tuple
import numpy as np
from multimethod import multimethod
from uniseg.graphemecluster import grapheme_clusters
from .ocr_files import ExtractedText
@ -70,23 +71,21 @@ def levenshtein_matrix_cache_clear():
_levenshtein_matrix.cache_clear()
def distance(s1, s2):
@multimethod
def distance(s1: str, s2: str):
"""Compute the Levenshtein edit distance between two Unicode strings
Note that this is different from levenshtein() as this function knows about Unicode normalization and grapheme
clusters. This should be the correct way to compare two Unicode strings.
"""
seq1 = list(grapheme_clusters(unicodedata.normalize('NFC', s1)))
seq2 = list(grapheme_clusters(unicodedata.normalize('NFC', s2)))
return levenshtein(seq1, seq2)
# XXX Implicit normalization
if isinstance(s1, str):
s1 = ExtractedText.from_str(s1)
if isinstance(s2, str):
s2 = ExtractedText.from_str(s2)
# s1 and s2 are now guaranteed (by ExtractedText) to be in NFC
seq1 = list(grapheme_clusters(s1.text))
seq2 = list(grapheme_clusters(s2.text))
return levenshtein(seq1, seq2)
@multimethod
def distance(s1: ExtractedText, s2: ExtractedText):
    """Compute the edit distance of two ExtractedText objects by their text content."""
    text1 = s1.text
    text2 = s2.text
    return distance(text1, text2)
def seq_editops(seq1, seq2):

@ -1,14 +1,19 @@
from __future__ import division
import unicodedata
from typing import Tuple
from typing import Tuple, Iterable
from multimethod import multimethod
import uniseg.wordbreak
from .edit_distance import levenshtein
from .ocr_files import ExtractedText
def words(s):
@multimethod
def words(s: str):
"""Extract words from a string"""
# Patch uniseg.wordbreak.word_break to deal with our private use characters. See also
# https://www.unicode.org/Public/UCD/latest/ucd/auxiliary/WordBreakProperty.txt
old_word_break = uniseg.wordbreak.word_break
@ -32,11 +37,6 @@ def words(s):
cat = subcat[0]
return cat in unwanted_categories or subcat in unwanted_subcategories
# XXX
from .cli import ExtractedText
if isinstance(s, ExtractedText):
s = s.text
# We follow Unicode Standard Annex #29 on Unicode Text Segmentation here: Split on word boundaries using
# uniseg.wordbreak.words() and ignore all "words" that contain only whitespace, punctuation "or similar characters."
for word in uniseg.wordbreak.words(s):
@ -46,27 +46,37 @@ def words(s):
yield word
def words_normalized(s):
# XXX
from .cli import ExtractedText
if isinstance(s, ExtractedText):
s = s.text
@multimethod
def words(s: ExtractedText):
    """Extract words from an ExtractedText object via the str implementation."""
    text = s.text
    return words(text)
@multimethod
def words_normalized(s: str):
    """Extract words from *s* after normalizing it to NFC."""
    normalized = unicodedata.normalize('NFC', s)
    return words(normalized)
def word_error_rate_n(reference, compared) -> Tuple[float, int]:
# XXX
from .cli import ExtractedText
if isinstance(reference, ExtractedText):
reference = reference.text
if isinstance(compared, ExtractedText):
compared = compared.text
if isinstance(reference, str):
reference_seq = list(words_normalized(reference))
compared_seq = list(words_normalized(compared))
else:
reference_seq = list(reference)
compared_seq = list(compared)
@multimethod
def words_normalized(s: ExtractedText):
    """Extract normalized words from an ExtractedText object via the str implementation."""
    text = s.text
    return words_normalized(text)
@multimethod
def word_error_rate_n(reference: str, compared: str) -> Tuple[float, int]:
    """Compute WER of two strings; splits into normalized words, then delegates
    to the Iterable implementation."""
    return word_error_rate_n(
        list(words_normalized(reference)),
        list(words_normalized(compared)))
@multimethod
def word_error_rate_n(reference: ExtractedText, compared: ExtractedText) -> Tuple[float, int]:
    """Compute WER of two ExtractedText objects by comparing their text content."""
    ref_text = reference.text
    cmp_text = compared.text
    return word_error_rate_n(ref_text, cmp_text)
@multimethod
def word_error_rate_n(reference: Iterable, compared: Iterable) -> Tuple[float, int]:
reference_seq = list(reference)
compared_seq = list(compared)
d = levenshtein(reference_seq, compared_seq)
n = len(reference_seq)

@ -7,3 +7,4 @@ colorama
MarkupSafe
ocrd >= 1.0.0b15
attrs
multimethod == 1.3 # latest version to officially support Python 3.5
Loading…
Cancel
Save