From 8cd624f79595a02f466ce821b6d5536270993e46 Mon Sep 17 00:00:00 2001
From: Benjamin Rosemann <benjamin.rosemann@la-bw.de>
Date: Tue, 8 Jun 2021 17:41:44 +0200
Subject: [PATCH] Add BoC and BoW metric

Also some refactoring for helper methods on normalization and word
splitting.
---
 qurator/dinglehopper/align.py                 |   5 +-
 qurator/dinglehopper/edit_distance.py         |  15 ++-
 qurator/dinglehopper/metrics/__init__.py      |   3 +
 .../metrics/bag_of_chars_accuracy.py          |  35 ++++++
 .../metrics/bag_of_words_accuracy.py          |  30 +++++
 .../metrics/character_error_rate.py           |   7 +-
 qurator/dinglehopper/metrics/utils.py         |  41 +++++++
 .../dinglehopper/metrics/word_error_rate.py   |  61 +---------
 qurator/dinglehopper/normalize.py             |  61 ++++++++++
 .../tests/metrics/test_bag_accuracy.py        | 104 ++++++++++++++++++
 .../metrics/test_integ_word_error_rate_ocr.py |   5 +-
 .../tests/metrics/test_word_error_rate.py     |   3 +-
 12 files changed, 296 insertions(+), 74 deletions(-)
 create mode 100644 qurator/dinglehopper/metrics/bag_of_chars_accuracy.py
 create mode 100644 qurator/dinglehopper/metrics/bag_of_words_accuracy.py
 create mode 100644 qurator/dinglehopper/metrics/utils.py
 create mode 100644 qurator/dinglehopper/normalize.py
 create mode 100644 qurator/dinglehopper/tests/metrics/test_bag_accuracy.py

diff --git a/qurator/dinglehopper/align.py b/qurator/dinglehopper/align.py
index c7e7733..cc7230b 100644
--- a/qurator/dinglehopper/align.py
+++ b/qurator/dinglehopper/align.py
@@ -1,10 +1,11 @@
 from .edit_distance import *
+from .normalize import chars_normalized
 
 
 def align(t1, t2):
     """Align text."""
-    s1 = list(grapheme_clusters(unicodedata.normalize("NFC", t1)))
-    s2 = list(grapheme_clusters(unicodedata.normalize("NFC", t2)))
+    s1 = chars_normalized(t1)
+    s2 = chars_normalized(t2)
     return seq_align(s1, s2)
 
 
diff --git a/qurator/dinglehopper/edit_distance.py b/qurator/dinglehopper/edit_distance.py
index 0b9c8f4..6c459fa 100644
--- a/qurator/dinglehopper/edit_distance.py
+++ b/qurator/dinglehopper/edit_distance.py
@@ -1,16 +1,15 @@
 from __future__ import division, print_function
 
-import unicodedata
-from functools import partial, lru_cache
+from functools import lru_cache, partial
 from typing import Sequence, Tuple
 
 import numpy as np
 from multimethod import multimethod
-from uniseg.graphemecluster import grapheme_clusters
 from tqdm import tqdm
 
-from .extracted_text import ExtractedText
 from .config import Config
+from .extracted_text import ExtractedText
+from .normalize import chars_normalized
 
 
 def levenshtein_matrix(seq1: Sequence, seq2: Sequence):
@@ -82,8 +81,8 @@ def distance(s1: str, s2: str):
     Note that this is different from levenshtein() as this function knows about Unicode normalization and grapheme
     clusters. This should be the correct way to compare two Unicode strings.
     """
-    seq1 = list(grapheme_clusters(unicodedata.normalize("NFC", s1)))
-    seq2 = list(grapheme_clusters(unicodedata.normalize("NFC", s2)))
+    seq1 = chars_normalized(s1)
+    seq2 = chars_normalized(s2)
     return levenshtein(seq1, seq2)
 
 
@@ -139,6 +138,6 @@ def editops(word1, word2):
 
     Note that this returns indices to the _grapheme clusters_, not characters!
     """
-    word1 = list(grapheme_clusters(unicodedata.normalize("NFC", word1)))
-    word2 = list(grapheme_clusters(unicodedata.normalize("NFC", word2)))
+    word1 = chars_normalized(word1)
+    word2 = chars_normalized(word2)
     return seq_editops(word1, word2)
diff --git a/qurator/dinglehopper/metrics/__init__.py b/qurator/dinglehopper/metrics/__init__.py
index 9f370c4..ba9d140 100644
--- a/qurator/dinglehopper/metrics/__init__.py
+++ b/qurator/dinglehopper/metrics/__init__.py
@@ -1,2 +1,5 @@
+from .bag_of_chars_accuracy import *
+from .bag_of_words_accuracy import *
 from .character_error_rate import *
+from .utils import Weights
 from .word_error_rate import *
diff --git a/qurator/dinglehopper/metrics/bag_of_chars_accuracy.py b/qurator/dinglehopper/metrics/bag_of_chars_accuracy.py
new file mode 100644
index 0000000..dd6a030
--- /dev/null
+++ b/qurator/dinglehopper/metrics/bag_of_chars_accuracy.py
@@ -0,0 +1,35 @@
+from collections import Counter
+from typing import Tuple, Union
+from unicodedata import normalize
+
+from multimethod import multimethod
+from uniseg.graphemecluster import grapheme_clusters
+
+from .utils import bag_accuracy, Weights
+from .. import ExtractedText
+
+
+def bag_of_chars_accuracy(
+    reference: Union[str, ExtractedText],
+    compared: Union[str, ExtractedText],
+    weights: Weights,
+) -> float:
+    acc, _ = bag_of_chars_accuracy_n(reference, compared, weights)
+    return acc
+
+
+@multimethod
+def bag_of_chars_accuracy_n(
+    reference: str, compared: str, weights: Weights
+) -> Tuple[float, int]:
+    reference_chars = Counter(grapheme_clusters(normalize("NFC", reference)))
+    compared_chars = Counter(grapheme_clusters(normalize("NFC", compared)))
+    e, n = bag_accuracy(reference_chars, compared_chars, weights)
+    return (float("inf") if n == 0 else 1 - e / n), n
+
+
+@multimethod
+def bag_of_chars_accuracy_n(
+    reference: ExtractedText, compared: ExtractedText, weights: Weights
+) -> Tuple[float, int]:
+    return bag_of_chars_accuracy_n(reference.text, compared.text, weights)
diff --git a/qurator/dinglehopper/metrics/bag_of_words_accuracy.py b/qurator/dinglehopper/metrics/bag_of_words_accuracy.py
new file mode 100644
index 0000000..7e5f315
--- /dev/null
+++ b/qurator/dinglehopper/metrics/bag_of_words_accuracy.py
@@ -0,0 +1,30 @@
+from collections import Counter
+from typing import Tuple, Union
+
+from .utils import bag_accuracy, Weights
+from .. import ExtractedText
+from ..normalize import words_normalized
+
+
+def bag_of_words_accuracy(
+    reference: Union[str, ExtractedText],
+    compared: Union[str, ExtractedText],
+    weights: Weights,
+) -> float:
+    acc, _ = bag_of_words_accuracy_n(reference, compared, weights)
+    return acc
+
+
+def bag_of_words_accuracy_n(
+    reference: Union[str, ExtractedText],
+    compared: Union[str, ExtractedText],
+    weights: Weights,
+) -> Tuple[float, int]:
+    if isinstance(reference, ExtractedText):
+        reference = reference.text
+    if isinstance(compared, ExtractedText):
+        compared = compared.text
+    reference_words = Counter(words_normalized(reference))
+    compared_words = Counter(words_normalized(compared))
+    e, n = bag_accuracy(reference_words, compared_words, weights)
+    return (float("inf") if n == 0 else 1 - e / n), n
diff --git a/qurator/dinglehopper/metrics/character_error_rate.py b/qurator/dinglehopper/metrics/character_error_rate.py
index 4dae8ee..0e40c66 100644
--- a/qurator/dinglehopper/metrics/character_error_rate.py
+++ b/qurator/dinglehopper/metrics/character_error_rate.py
@@ -1,13 +1,12 @@
 from __future__ import division
 
-import unicodedata
 from typing import Tuple
 
 from multimethod import multimethod
-from uniseg.graphemecluster import grapheme_clusters
 
-from ..edit_distance import distance
+from .. import distance
 from ..extracted_text import ExtractedText
+from ..normalize import chars_normalized
 
 
 @multimethod
@@ -19,7 +18,7 @@ def character_error_rate_n(reference: str, compared: str) -> Tuple[float, int]:
     """
 
     d = distance(reference, compared)
-    n = len(list(grapheme_clusters(unicodedata.normalize("NFC", reference))))
+    n = len(chars_normalized(reference))
 
     if d == 0:
         return 0, n
diff --git a/qurator/dinglehopper/metrics/utils.py b/qurator/dinglehopper/metrics/utils.py
new file mode 100644
index 0000000..cfb764e
--- /dev/null
+++ b/qurator/dinglehopper/metrics/utils.py
@@ -0,0 +1,41 @@
+from collections import Counter
+from typing import NamedTuple, Tuple
+
+
+class Weights(NamedTuple):
+    """Represent weights/costs for editing operations."""
+
+    deletes: int = 1
+    inserts: int = 1
+    replacements: int = 1
+
+
+def bag_accuracy(
+    reference: Counter, compared: Counter, weights: Weights
+) -> Tuple[int, int]:
+    """Calculates the the weighted errors for two bags (Counter).
+
+    Basic algorithm idea:
+     - All elements in reference not occurring in compared are considered deletes.
+     - All elements in compared not occurring in reference are considered inserts.
+     - When the cost for one replacement is lower than that of one insert and one delete
+       we can substitute pairs of deletes and inserts with one replacement.
+
+    :param reference: Bag used as reference (ground truth).
+    :param compared: Bag used to compare (ocr).
+    :param weights: Weights/costs for editing operations.
+    :return: weighted errors and number of elements in reference.
+    """
+    n = sum(reference.values())
+    deletes = sum((reference - compared).values())
+    inserts = sum((compared - reference).values())
+    replacements = 0
+    if weights.replacements < (weights.deletes + weights.inserts):
+        replacements = min(deletes, inserts)
+        deletes, inserts = max(deletes - inserts, 0), max(inserts - deletes, 0)
+    weighted_errors = (
+        weights.deletes * deletes
+        + weights.inserts * inserts
+        + weights.replacements * replacements
+    )
+    return weighted_errors, n
diff --git a/qurator/dinglehopper/metrics/word_error_rate.py b/qurator/dinglehopper/metrics/word_error_rate.py
index 5a42eee..14d3784 100644
--- a/qurator/dinglehopper/metrics/word_error_rate.py
+++ b/qurator/dinglehopper/metrics/word_error_rate.py
@@ -1,65 +1,12 @@
 from __future__ import division
 
-import unicodedata
-from typing import Tuple, Iterable
-from multimethod import multimethod
+from typing import Iterable, Tuple
 
-import uniseg.wordbreak
+from multimethod import multimethod
 
 from ..edit_distance import levenshtein
-from .. import ExtractedText
-
-
-@multimethod
-def words(s: str):
-    """Extract words from a string"""
-
-    # Patch uniseg.wordbreak.word_break to deal with our private use characters. See also
-    # https://www.unicode.org/Public/UCD/latest/ucd/auxiliary/WordBreakProperty.txt
-    old_word_break = uniseg.wordbreak.word_break
-
-    def new_word_break(c, index=0):
-        if 0xE000 <= ord(c) <= 0xF8FF:  # Private Use Area
-            return "ALetter"
-        else:
-            return old_word_break(c, index)
-
-    uniseg.wordbreak.word_break = new_word_break
-
-    # Check if c is an unwanted character, i.e. whitespace, punctuation, or similar
-    def unwanted(c):
-
-        # See https://www.fileformat.info/info/unicode/category/index.htm
-        # and https://unicodebook.readthedocs.io/unicode.html#categories
-        unwanted_categories = "O", "M", "P", "Z", "S"
-        unwanted_subcategories = "Cc", "Cf"
-
-        subcat = unicodedata.category(c)
-        cat = subcat[0]
-        return cat in unwanted_categories or subcat in unwanted_subcategories
-
-    # We follow Unicode Standard Annex #29 on Unicode Text Segmentation here: Split on word boundaries using
-    # uniseg.wordbreak.words() and ignore all "words" that contain only whitespace, punctation "or similar characters."
-    for word in uniseg.wordbreak.words(s):
-        if all(unwanted(c) for c in word):
-            pass
-        else:
-            yield word
-
-
-@multimethod
-def words(s: ExtractedText):
-    return words(s.text)
-
-
-@multimethod
-def words_normalized(s: str):
-    return words(unicodedata.normalize("NFC", s))
-
-
-@multimethod
-def words_normalized(s: ExtractedText):
-    return words_normalized(s.text)
+from ..extracted_text import ExtractedText
+from ..normalize import words_normalized
 
 
 @multimethod
diff --git a/qurator/dinglehopper/normalize.py b/qurator/dinglehopper/normalize.py
new file mode 100644
index 0000000..4ae6617
--- /dev/null
+++ b/qurator/dinglehopper/normalize.py
@@ -0,0 +1,61 @@
+import unicodedata
+from typing import Union
+
+import uniseg.wordbreak
+from uniseg.graphemecluster import grapheme_clusters
+
+from .extracted_text import ExtractedText
+
+
+def chars_normalized(s: Union[str, ExtractedText]):
+    """Normalize characters in string."""
+    if isinstance(s, ExtractedText):
+        s = s.text
+    return list(grapheme_clusters(unicodedata.normalize("NFC", s)))
+
+
+def words(s: Union[str, ExtractedText]):
+    """Extract words from a string"""
+
+    if isinstance(s, ExtractedText):
+        s = s.text
+
+    # Patch uniseg.wordbreak.word_break to deal with our private use characters. See also
+    # https://www.unicode.org/Public/UCD/latest/ucd/auxiliary/WordBreakProperty.txt
+    old_word_break = uniseg.wordbreak.word_break
+
+    def new_word_break(c, index=0):
+        if 0xE000 <= ord(c) <= 0xF8FF:  # Private Use Area
+            return "ALetter"
+        else:
+            return old_word_break(c, index)
+
+    uniseg.wordbreak.word_break = new_word_break
+
+    # Check if c is an unwanted character, i.e. whitespace, punctuation, or similar
+    def unwanted(c):
+
+        # See https://www.fileformat.info/info/unicode/category/index.htm
+        # and https://unicodebook.readthedocs.io/unicode.html#categories
+        unwanted_categories = "O", "M", "P", "Z", "S"
+        unwanted_subcategories = "Cc", "Cf"
+
+        subcat = unicodedata.category(c)
+        cat = subcat[0]
+        return cat in unwanted_categories or subcat in unwanted_subcategories
+
+    # We follow Unicode Standard Annex #29 on Unicode Text Segmentation here:
+    # Split on word boundaries using uniseg.wordbreak.words() and ignore all
+    # "words" that contain only whitespace, punctuation "or similar characters."
+    for word in uniseg.wordbreak.words(s):
+        if all(unwanted(c) for c in word):
+            pass
+        else:
+            yield word
+
+
+def words_normalized(s: Union[str, ExtractedText]):
+    """Extract words from string and normalize them."""
+    if isinstance(s, ExtractedText):
+        s = s.text
+    return words(unicodedata.normalize("NFC", s))
diff --git a/qurator/dinglehopper/tests/metrics/test_bag_accuracy.py b/qurator/dinglehopper/tests/metrics/test_bag_accuracy.py
new file mode 100644
index 0000000..345e0bd
--- /dev/null
+++ b/qurator/dinglehopper/tests/metrics/test_bag_accuracy.py
@@ -0,0 +1,104 @@
+import math
+import unicodedata
+from collections import Counter
+
+import pytest
+
+from ...metrics import bag_of_chars_accuracy_n, bag_of_words_accuracy_n, Weights
+from ...metrics.utils import bag_accuracy
+
+
+@pytest.fixture
+def ex_weights():
+    return (
+        Weights(deletes=0, inserts=0, replacements=0),
+        Weights(deletes=1, inserts=1, replacements=1),
+        Weights(deletes=1, inserts=0, replacements=1),
+        Weights(deletes=1, inserts=1, replacements=2),
+    )
+
+
+SIMPLE_CASES = (
+    ("", "", 0, (0, 0, 0)),
+    ("abc", "", 3, (3, 3, 3)),
+    ("", "abc", 0, (3, 0, 3)),
+    ("abc", "abc", 3, (0, 0, 0)),
+    ("abc", "ab", 3, (1, 1, 1)),
+    ("abc", "abcd", 3, (1, 0, 1)),
+    ("abc", "abd", 3, (1, 1, 2)),
+)
+
+
+@pytest.mark.parametrize(
+    "s1,s2, ex_n, ex_err",
+    [
+        *SIMPLE_CASES,
+        (("a", "b", "c", "d", "e"), ("a", "b", "c", "d", ("e", "´")), 5, (1, 1, 2)),
+        (range(5), range(6), 5, (1, 0, 1)),
+    ],
+)
+def test_bag_accuracy_algorithm(s1, s2, ex_n, ex_err, ex_weights):
+    """Test the main algorithm for calculating the bag accuracy."""
+    for weights, expected_errors in zip(ex_weights, (0, *ex_err)):
+        e, n = bag_accuracy(Counter(s1), Counter(s2), weights=weights)
+        assert n == ex_n, f"{n} == {ex_n} for {weights}"
+        assert e == expected_errors, f"{e} == {expected_errors} for {weights}"
+
+
+@pytest.mark.parametrize(
+    "s1,s2, ex_n, ex_err",
+    [
+        *SIMPLE_CASES,
+        ("Schlyñ", "Schlym̃", 6, (1, 1, 2)),
+        (
+            unicodedata.normalize("NFC", "Schlyñ lorem ipsum."),
+            unicodedata.normalize("NFD", "Schlyñ lorem ipsum!"),
+            19,
+            (1, 1, 2),
+        ),
+    ],
+)
+def test_bag_of_chars_accuracy_n(s1, s2, ex_n, ex_err, ex_weights):
+    """Test the special behaviour of the char differentiation.
+
+    As the algorithm and the char normalization is implemented elsewhere
+    we are currently only testing that the corresponding algorithms are called.
+    """
+    for weights, expected_errors in zip(ex_weights, (0, *ex_err)):
+        acc, n = bag_of_chars_accuracy_n(s1, s2, weights)
+        assert n == ex_n, f"{n} == {ex_n} for {weights}"
+        if ex_n == 0:
+            assert math.isinf(acc)
+        else:
+            assert acc == pytest.approx(1 - expected_errors / ex_n), f"w: {weights}"
+
+
+@pytest.mark.parametrize(
+    "s1,s2, ex_n, ex_err",
+    [
+        *SIMPLE_CASES,
+        ("Schlyñ", "Schlym̃", 6, (1, 1, 2)),
+        (
+            unicodedata.normalize("NFC", "Schlyñ lorem ipsum."),
+            unicodedata.normalize("NFD", "Schlyñ lorem ipsum!"),
+            3,
+            (0, 0, 0),
+        ),
+    ],
+)
+def test_bag_of_words_accuracy_n(s1, s2, ex_n, ex_err, ex_weights):
+    """Test the special behaviour of the word differentiation.
+
+    As the algorithm and the word splitting is implemented elsewhere
+    we are currently only testing that the corresponding algorithms are called.
+    """
+    if " " not in s1 and " " not in s2:
+        s1 = " ".join(s1)
+        s2 = " ".join(s2)
+    for weights, expected_errors in zip(ex_weights, (0, *ex_err)):
+        acc, n = bag_of_words_accuracy_n(s1, s2, weights)
+        assert n == ex_n, f"{n} == {ex_n} for {weights}"
+        if ex_n == 0:
+            assert math.isinf(acc)
+        else:
+            assert acc == pytest.approx(1 - expected_errors / ex_n), f"w: {weights}"
diff --git a/qurator/dinglehopper/tests/metrics/test_integ_word_error_rate_ocr.py b/qurator/dinglehopper/tests/metrics/test_integ_word_error_rate_ocr.py
index 1b8dd7e..9654061 100644
--- a/qurator/dinglehopper/tests/metrics/test_integ_word_error_rate_ocr.py
+++ b/qurator/dinglehopper/tests/metrics/test_integ_word_error_rate_ocr.py
@@ -5,8 +5,9 @@ import os
 import pytest
 from lxml import etree as ET
 
-from ... import page_text, alto_text
-from ...metrics import word_error_rate, words\
+from ... import alto_text, page_text
+from ...metrics import word_error_rate
+from ...normalize import words
 
 data_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "../", "data")
 
diff --git a/qurator/dinglehopper/tests/metrics/test_word_error_rate.py b/qurator/dinglehopper/tests/metrics/test_word_error_rate.py
index 36f2823..7e7d392 100644
--- a/qurator/dinglehopper/tests/metrics/test_word_error_rate.py
+++ b/qurator/dinglehopper/tests/metrics/test_word_error_rate.py
@@ -2,7 +2,8 @@ from __future__ import division, print_function
 
 import math
 
-from ...metrics import word_error_rate, words
+from ...metrics import word_error_rate
+from ...normalize import words
 
 
 def test_words():