diff --git a/qurator/dinglehopper/align.py b/qurator/dinglehopper/align.py
index c7e7733..cc7230b 100644
--- a/qurator/dinglehopper/align.py
+++ b/qurator/dinglehopper/align.py
@@ -1,10 +1,11 @@
 from .edit_distance import *
+from .normalize import chars_normalized
 
 
 def align(t1, t2):
     """Align text."""
-    s1 = list(grapheme_clusters(unicodedata.normalize("NFC", t1)))
-    s2 = list(grapheme_clusters(unicodedata.normalize("NFC", t2)))
+    s1 = chars_normalized(t1)
+    s2 = chars_normalized(t2)
     return seq_align(s1, s2)
diff --git a/qurator/dinglehopper/edit_distance.py b/qurator/dinglehopper/edit_distance.py
index 0b9c8f4..6c459fa 100644
--- a/qurator/dinglehopper/edit_distance.py
+++ b/qurator/dinglehopper/edit_distance.py
@@ -1,16 +1,15 @@
 from __future__ import division, print_function
 
-import unicodedata
-from functools import partial, lru_cache
+from functools import lru_cache, partial
 from typing import Sequence, Tuple
 
 import numpy as np
 from multimethod import multimethod
-from uniseg.graphemecluster import grapheme_clusters
 from tqdm import tqdm
 
-from .extracted_text import ExtractedText
 from .config import Config
+from .extracted_text import ExtractedText
+from .normalize import chars_normalized
 
 
 def levenshtein_matrix(seq1: Sequence, seq2: Sequence):
@@ -82,8 +81,8 @@ def distance(s1: str, s2: str):
     Note that this is different from levenshtein() as this function knows about Unicode
     normalization and grapheme clusters. This should be the correct way to compare two
     Unicode strings.
     """
-    seq1 = list(grapheme_clusters(unicodedata.normalize("NFC", s1)))
-    seq2 = list(grapheme_clusters(unicodedata.normalize("NFC", s2)))
+    seq1 = chars_normalized(s1)
+    seq2 = chars_normalized(s2)
     return levenshtein(seq1, seq2)
@@ -139,6 +138,6 @@ def editops(word1, word2):
     Note that this returns indices to the _grapheme clusters_, not characters!
     """
-    word1 = list(grapheme_clusters(unicodedata.normalize("NFC", word1)))
-    word2 = list(grapheme_clusters(unicodedata.normalize("NFC", word2)))
+    word1 = chars_normalized(word1)
+    word2 = chars_normalized(word2)
    return seq_editops(word1, word2)
diff --git a/qurator/dinglehopper/metrics/__init__.py b/qurator/dinglehopper/metrics/__init__.py
index 9f370c4..ba9d140 100644
--- a/qurator/dinglehopper/metrics/__init__.py
+++ b/qurator/dinglehopper/metrics/__init__.py
@@ -1,2 +1,5 @@
+from .bag_of_chars_accuracy import *
+from .bag_of_words_accuracy import *
 from .character_error_rate import *
+from .utils import Weights
 from .word_error_rate import *
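All three call sites above now delegate NFC normalization and grapheme segmentation to the new `chars_normalized` helper (introduced in `normalize.py` further down). A minimal sketch of why the code counts grapheme clusters rather than code points, using only `uniseg` as the diff itself does:

```python
# Illustrative only: a combining mark with no precomposed NFC form makes the
# code-point count and the user-perceived character count diverge.
import unicodedata

from uniseg.graphemecluster import grapheme_clusters

s = "Schlym\u0303"  # "Schlym̃": 'm' + combining tilde survives NFC as two code points
assert len(s) == 7  # seven code points ...
assert len(list(grapheme_clusters(unicodedata.normalize("NFC", s)))) == 6  # ... six graphemes
```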
diff --git a/qurator/dinglehopper/metrics/bag_of_chars_accuracy.py b/qurator/dinglehopper/metrics/bag_of_chars_accuracy.py
new file mode 100644
index 0000000..dd6a030
--- /dev/null
+++ b/qurator/dinglehopper/metrics/bag_of_chars_accuracy.py
@@ -0,0 +1,35 @@
+from collections import Counter
+from typing import Tuple, Union
+from unicodedata import normalize
+
+from multimethod import multimethod
+from uniseg.graphemecluster import grapheme_clusters
+
+from .utils import bag_accuracy, Weights
+from .. import ExtractedText
+
+
+def bag_of_chars_accuracy(
+    reference: Union[str, ExtractedText],
+    compared: Union[str, ExtractedText],
+    weights: Weights,
+) -> float:
+    acc, _ = bag_of_chars_accuracy_n(reference, compared, weights)
+    return acc
+
+
+@multimethod
+def bag_of_chars_accuracy_n(
+    reference: str, compared: str, weights: Weights
+) -> Tuple[float, int]:
+    reference_chars = Counter(grapheme_clusters(normalize("NFC", reference)))
+    compared_chars = Counter(grapheme_clusters(normalize("NFC", compared)))
+    e, n = bag_accuracy(reference_chars, compared_chars, weights)
+    return (float("inf") if n == 0 else 1 - e / n), n
+
+
+@multimethod
+def bag_of_chars_accuracy_n(
+    reference: ExtractedText, compared: ExtractedText, weights: Weights
+) -> Tuple[float, int]:
+    return bag_of_chars_accuracy_n(reference.text, compared.text, weights)
diff --git a/qurator/dinglehopper/metrics/bag_of_words_accuracy.py b/qurator/dinglehopper/metrics/bag_of_words_accuracy.py
new file mode 100644
index 0000000..7e5f315
--- /dev/null
+++ b/qurator/dinglehopper/metrics/bag_of_words_accuracy.py
@@ -0,0 +1,30 @@
+from collections import Counter
+from typing import Tuple, Union
+
+from .utils import bag_accuracy, Weights
+from .. import ExtractedText
+from ..normalize import words_normalized
+
+
+def bag_of_words_accuracy(
+    reference: Union[str, ExtractedText],
+    compared: Union[str, ExtractedText],
+    weights: Weights,
+) -> float:
+    acc, _ = bag_of_words_accuracy_n(reference, compared, weights)
+    return acc
+
+
+def bag_of_words_accuracy_n(
+    reference: Union[str, ExtractedText],
+    compared: Union[str, ExtractedText],
+    weights: Weights,
+) -> Tuple[float, int]:
+    if isinstance(reference, ExtractedText):
+        reference = reference.text
+    if isinstance(compared, ExtractedText):
+        compared = compared.text
+    reference_words = Counter(words_normalized(reference))
+    compared_words = Counter(words_normalized(compared))
+    e, n = bag_accuracy(reference_words, compared_words, weights)
+    return (float("inf") if n == 0 else 1 - e / n), n
diff --git a/qurator/dinglehopper/metrics/character_error_rate.py b/qurator/dinglehopper/metrics/character_error_rate.py
index 4dae8ee..0e40c66 100644
--- a/qurator/dinglehopper/metrics/character_error_rate.py
+++ b/qurator/dinglehopper/metrics/character_error_rate.py
@@ -1,13 +1,12 @@
 from __future__ import division
 
-import unicodedata
 from typing import Tuple
 
 from multimethod import multimethod
-from uniseg.graphemecluster import grapheme_clusters
 
-from ..edit_distance import distance
+from .. import distance
 from ..extracted_text import ExtractedText
+from ..normalize import chars_normalized
 
 
 @multimethod
@@ -19,7 +18,7 @@ def character_error_rate_n(reference: str, compared: str) -> Tuple[float, int]:
     """
 
     d = distance(reference, compared)
-    n = len(list(grapheme_clusters(unicodedata.normalize("NFC", reference))))
+    n = len(chars_normalized(reference))
 
     if d == 0:
         return 0, n
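For review convenience, a usage sketch of the two new public entry points, not part of the diff. The import paths follow `metrics/__init__.py` above; the expected numbers follow from `bag_accuracy` in `metrics/utils.py` below:

```python
# Sketch: exercising the new bag metrics with unit edit costs.
from qurator.dinglehopper.metrics import (
    Weights,
    bag_of_chars_accuracy_n,
    bag_of_words_accuracy_n,
)

weights = Weights(deletes=1, inserts=1, replacements=1)

acc, n = bag_of_words_accuracy_n("Schlyñ lorem ipsum", "Schlyñ lorem ipsum", weights)
assert (acc, n) == (1.0, 3)  # identical word bags: no errors, three reference words

acc, n = bag_of_chars_accuracy_n("abc", "abd", weights)
assert n == 3 and acc == 1 - 1 / 3  # one delete/insert pair counts as one replacement
```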
diff --git a/qurator/dinglehopper/metrics/utils.py b/qurator/dinglehopper/metrics/utils.py
new file mode 100644
index 0000000..cfb764e
--- /dev/null
+++ b/qurator/dinglehopper/metrics/utils.py
@@ -0,0 +1,41 @@
+from collections import Counter
+from typing import NamedTuple, Tuple
+
+
+class Weights(NamedTuple):
+    """Represent weights/costs for editing operations."""
+
+    deletes: int = 1
+    inserts: int = 1
+    replacements: int = 1
+
+
+def bag_accuracy(
+    reference: Counter, compared: Counter, weights: Weights
+) -> Tuple[int, int]:
+    """Calculates the weighted errors for two bags (Counter).
+
+    Basic algorithm idea:
+     - All elements in reference not occurring in compared are considered deletes.
+     - All elements in compared not occurring in reference are considered inserts.
+     - When the cost for one replacement is lower than that of one insert and one
+       delete, we can substitute pairs of deletes and inserts with one replacement.
+
+    :param reference: Bag used as reference (ground truth).
+    :param compared: Bag used to compare (OCR).
+    :param weights: Weights/costs for editing operations.
+    :return: Weighted errors and number of elements in reference.
+    """
+    n = sum(reference.values())
+    deletes = sum((reference - compared).values())
+    inserts = sum((compared - reference).values())
+    replacements = 0
+    if weights.replacements < (weights.deletes + weights.inserts):
+        replacements = min(deletes, inserts)
+        deletes, inserts = max(deletes - inserts, 0), max(inserts - deletes, 0)
+    weighted_errors = (
+        weights.deletes * deletes
+        + weights.inserts * inserts
+        + weights.replacements * replacements
+    )
+    return weighted_errors, n
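The branch that converts delete/insert pairs into replacements is the subtle part of `bag_accuracy`. A small worked example (imports assume the module layout of this diff; the values match the `("abc", "abd")` case in the tests below):

```python
# Sketch: how bag_accuracy trades a delete/insert pair for one replacement.
from collections import Counter

from qurator.dinglehopper.metrics.utils import Weights, bag_accuracy

ref, ocr = Counter("abc"), Counter("abd")  # "c" only in ref, "d" only in ocr

# Unit costs: 1 (replacement) < 1 + 1, so the pair collapses into one replacement.
assert bag_accuracy(ref, ocr, Weights(deletes=1, inserts=1, replacements=1)) == (1, 3)

# A replacement costing 2 is no cheaper than delete + insert, so both are counted.
assert bag_accuracy(ref, ocr, Weights(deletes=1, inserts=1, replacements=2)) == (2, 3)
```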
diff --git a/qurator/dinglehopper/metrics/word_error_rate.py b/qurator/dinglehopper/metrics/word_error_rate.py
index 5a42eee..14d3784 100644
--- a/qurator/dinglehopper/metrics/word_error_rate.py
+++ b/qurator/dinglehopper/metrics/word_error_rate.py
@@ -1,65 +1,12 @@
 from __future__ import division
 
-import unicodedata
-from typing import Tuple, Iterable
-from multimethod import multimethod
+from typing import Iterable, Tuple
 
-import uniseg.wordbreak
+from multimethod import multimethod
 
 from ..edit_distance import levenshtein
-from .. import ExtractedText
-
-
-@multimethod
-def words(s: str):
-    """Extract words from a string"""
-
-    # Patch uniseg.wordbreak.word_break to deal with our private use characters. See also
-    # https://www.unicode.org/Public/UCD/latest/ucd/auxiliary/WordBreakProperty.txt
-    old_word_break = uniseg.wordbreak.word_break
-
-    def new_word_break(c, index=0):
-        if 0xE000 <= ord(c) <= 0xF8FF:  # Private Use Area
-            return "ALetter"
-        else:
-            return old_word_break(c, index)
-
-    uniseg.wordbreak.word_break = new_word_break
-
-    # Check if c is an unwanted character, i.e. whitespace, punctuation, or similar
-    def unwanted(c):
-
-        # See https://www.fileformat.info/info/unicode/category/index.htm
-        # and https://unicodebook.readthedocs.io/unicode.html#categories
-        unwanted_categories = "O", "M", "P", "Z", "S"
-        unwanted_subcategories = "Cc", "Cf"
-
-        subcat = unicodedata.category(c)
-        cat = subcat[0]
-        return cat in unwanted_categories or subcat in unwanted_subcategories
-
-    # We follow Unicode Standard Annex #29 on Unicode Text Segmentation here: Split on word boundaries using
-    # uniseg.wordbreak.words() and ignore all "words" that contain only whitespace, punctation "or similar characters."
-    for word in uniseg.wordbreak.words(s):
-        if all(unwanted(c) for c in word):
-            pass
-        else:
-            yield word
-
-
-@multimethod
-def words(s: ExtractedText):
-    return words(s.text)
-
-
-@multimethod
-def words_normalized(s: str):
-    return words(unicodedata.normalize("NFC", s))
-
-
-@multimethod
-def words_normalized(s: ExtractedText):
-    return words_normalized(s.text)
+from ..extracted_text import ExtractedText
+from ..normalize import words_normalized
 
 
 @multimethod
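`words` and `words_normalized` move essentially verbatim into `normalize.py` (next hunk), trading the `@multimethod` dispatch for a plain `isinstance` check. A hedged sketch of the segmentation behaviour the move preserves, using an illustrative sentence of my own:

```python
# Sketch: UAX #29 word segmentation, with whitespace- and punctuation-only
# "words" filtered out by the unwanted() predicate shown above.
from qurator.dinglehopper.normalize import words

assert list(words("Dies ist ein Beispielsatz!")) == ["Dies", "ist", "ein", "Beispielsatz"]
```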
diff --git a/qurator/dinglehopper/normalize.py b/qurator/dinglehopper/normalize.py
new file mode 100644
index 0000000..4ae6617
--- /dev/null
+++ b/qurator/dinglehopper/normalize.py
@@ -0,0 +1,61 @@
+import unicodedata
+from typing import Union
+
+import uniseg.wordbreak
+from uniseg.graphemecluster import grapheme_clusters
+
+from .extracted_text import ExtractedText
+
+
+def chars_normalized(s: Union[str, ExtractedText]):
+    """Normalize characters in string."""
+    if isinstance(s, ExtractedText):
+        s = s.text
+    return list(grapheme_clusters(unicodedata.normalize("NFC", s)))
+
+
+def words(s: Union[str, ExtractedText]):
+    """Extract words from a string"""
+
+    if isinstance(s, ExtractedText):
+        s = s.text
+
+    # Patch uniseg.wordbreak.word_break to deal with our private use characters. See also
+    # https://www.unicode.org/Public/UCD/latest/ucd/auxiliary/WordBreakProperty.txt
+    old_word_break = uniseg.wordbreak.word_break
+
+    def new_word_break(c, index=0):
+        if 0xE000 <= ord(c) <= 0xF8FF:  # Private Use Area
+            return "ALetter"
+        else:
+            return old_word_break(c, index)
+
+    uniseg.wordbreak.word_break = new_word_break
+
+    # Check if c is an unwanted character, i.e. whitespace, punctuation, or similar
+    def unwanted(c):
+
+        # See https://www.fileformat.info/info/unicode/category/index.htm
+        # and https://unicodebook.readthedocs.io/unicode.html#categories
+        unwanted_categories = "O", "M", "P", "Z", "S"
+        unwanted_subcategories = "Cc", "Cf"
+
+        subcat = unicodedata.category(c)
+        cat = subcat[0]
+        return cat in unwanted_categories or subcat in unwanted_subcategories
+
+    # We follow Unicode Standard Annex #29 on Unicode Text Segmentation here:
+    # Split on word boundaries using uniseg.wordbreak.words() and ignore all
+    # "words" that contain only whitespace, punctuation, or similar characters.
+    for word in uniseg.wordbreak.words(s):
+        if all(unwanted(c) for c in word):
+            pass
+        else:
+            yield word
+
+
+def words_normalized(s: Union[str, ExtractedText]):
+    """Extract words from string and normalize them."""
+    if isinstance(s, ExtractedText):
+        s = s.text
+    return words(unicodedata.normalize("NFC", s))
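One property worth calling out for reviewers: because every helper normalizes to NFC before segmenting, canonically equivalent spellings produce identical grapheme-cluster lists. A quick sketch, not part of the diff:

```python
# Sketch: NFC and NFD spellings of the same text yield the same grapheme clusters.
import unicodedata

from qurator.dinglehopper.normalize import chars_normalized

nfc = unicodedata.normalize("NFC", "Schlyñ")
nfd = unicodedata.normalize("NFD", "Schlyñ")
assert nfc != nfd  # different code point sequences ...
assert chars_normalized(nfc) == chars_normalized(nfd)  # ... same six clusters
```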
+ """ + if " " not in s1 and " " not in s2: + s1 = " ".join(s1) + s2 = " ".join(s2) + for weights, expected_errors in zip(ex_weights, (0, *ex_err)): + acc, n = bag_of_words_accuracy_n(s1, s2, weights) + assert n == ex_n, f"{n} == {ex_n} for {weights}" + if ex_n == 0: + assert math.isinf(acc) + else: + assert acc == pytest.approx(1 - expected_errors / ex_n), f"w: {weights}" diff --git a/qurator/dinglehopper/tests/metrics/test_integ_word_error_rate_ocr.py b/qurator/dinglehopper/tests/metrics/test_integ_word_error_rate_ocr.py index 1b8dd7e..9654061 100644 --- a/qurator/dinglehopper/tests/metrics/test_integ_word_error_rate_ocr.py +++ b/qurator/dinglehopper/tests/metrics/test_integ_word_error_rate_ocr.py @@ -5,8 +5,9 @@ import os import pytest from lxml import etree as ET -from ... import page_text, alto_text -from ...metrics import word_error_rate, words\ +from ... import alto_text, page_text +from ...metrics import word_error_rate +from ...normalize import words data_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "../", "data") diff --git a/qurator/dinglehopper/tests/metrics/test_word_error_rate.py b/qurator/dinglehopper/tests/metrics/test_word_error_rate.py index 36f2823..7e7d392 100644 --- a/qurator/dinglehopper/tests/metrics/test_word_error_rate.py +++ b/qurator/dinglehopper/tests/metrics/test_word_error_rate.py @@ -2,7 +2,8 @@ from __future__ import division, print_function import math -from ...metrics import word_error_rate, words +from ...metrics import word_error_rate +from ...normalize import words def test_words():