Mirror of https://github.com/qurator-spk/dinglehopper.git
Add BoC and BoW metric
Also some refactoring of the helper methods for normalization and word splitting.
parent 4ccae9432d
commit 8cd624f795
12 changed files with 296 additions and 74 deletions
@@ -1,10 +1,11 @@
 from .edit_distance import *
+from .normalize import chars_normalized


 def align(t1, t2):
     """Align text."""
-    s1 = list(grapheme_clusters(unicodedata.normalize("NFC", t1)))
-    s2 = list(grapheme_clusters(unicodedata.normalize("NFC", t2)))
+    s1 = chars_normalized(t1)
+    s2 = chars_normalized(t2)
     return seq_align(s1, s2)
@@ -1,16 +1,15 @@
 from __future__ import division, print_function

 import unicodedata
-from functools import partial, lru_cache
+from functools import lru_cache, partial
 from typing import Sequence, Tuple

 import numpy as np
 from multimethod import multimethod
-from uniseg.graphemecluster import grapheme_clusters
 from tqdm import tqdm

-from .extracted_text import ExtractedText
 from .config import Config
+from .extracted_text import ExtractedText
+from .normalize import chars_normalized


 def levenshtein_matrix(seq1: Sequence, seq2: Sequence):
@@ -82,8 +81,8 @@ def distance(s1: str, s2: str):
     Note that this is different from levenshtein() as this function knows about Unicode normalization and grapheme
     clusters. This should be the correct way to compare two Unicode strings.
     """
-    seq1 = list(grapheme_clusters(unicodedata.normalize("NFC", s1)))
-    seq2 = list(grapheme_clusters(unicodedata.normalize("NFC", s2)))
+    seq1 = chars_normalized(s1)
+    seq2 = chars_normalized(s2)
     return levenshtein(seq1, seq2)
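For illustration (not part of the diff): a minimal sketch of why the switch to chars_normalized() preserves behaviour, assuming the package is importable as qurator.dinglehopper:

from qurator.dinglehopper.normalize import chars_normalized

nfc = "Schly\u00f1"    # "ñ" as one precomposed code point (NFC)
nfd = "Schlyn\u0303"   # "n" + combining tilde (NFD)

# chars_normalized() NFC-normalizes first and then splits into grapheme
# clusters, so both spellings yield the same six clusters:
assert chars_normalized(nfc) == chars_normalized(nfd)
assert len(chars_normalized(nfc)) == 6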
@@ -139,6 +138,6 @@ def editops(word1, word2):

     Note that this returns indices to the _grapheme clusters_, not characters!
     """
-    word1 = list(grapheme_clusters(unicodedata.normalize("NFC", word1)))
-    word2 = list(grapheme_clusters(unicodedata.normalize("NFC", word2)))
+    word1 = chars_normalized(word1)
+    word2 = chars_normalized(word2)
     return seq_editops(word1, word2)
@@ -1,2 +1,5 @@
+from .bag_of_chars_accuracy import *
+from .bag_of_words_accuracy import *
 from .character_error_rate import *
+from .utils import Weights
 from .word_error_rate import *
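With these re-exports, the new metrics are available directly from the metrics package. A usage sketch (the strings are made up; the import path is assumed from the file layout above):

from qurator.dinglehopper.metrics import (
    Weights,
    bag_of_chars_accuracy,
    bag_of_words_accuracy,
)

gt, ocr = "ein Haus", "ein Haus"  # hypothetical ground truth and OCR text
weights = Weights()               # defaults: deletes=1, inserts=1, replacements=1

print(bag_of_chars_accuracy(gt, ocr, weights))  # 1.0 for identical inputs
print(bag_of_words_accuracy(gt, ocr, weights))  # 1.0 for identical inputs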
qurator/dinglehopper/metrics/bag_of_chars_accuracy.py (new file, 35 lines)

@@ -0,0 +1,35 @@
+from collections import Counter
+from typing import Tuple, Union
+from unicodedata import normalize
+
+from multimethod import multimethod
+from uniseg.graphemecluster import grapheme_clusters
+
+from .utils import bag_accuracy, Weights
+from .. import ExtractedText
+
+
+def bag_of_chars_accuracy(
+    reference: Union[str, ExtractedText],
+    compared: Union[str, ExtractedText],
+    weights: Weights,
+) -> float:
+    acc, _ = bag_of_chars_accuracy_n(reference, compared, weights)
+    return acc
+
+
+@multimethod
+def bag_of_chars_accuracy_n(
+    reference: str, compared: str, weights: Weights
+) -> Tuple[float, int]:
+    reference_chars = Counter(grapheme_clusters(normalize("NFC", reference)))
+    compared_chars = Counter(grapheme_clusters(normalize("NFC", compared)))
+    e, n = bag_accuracy(reference_chars, compared_chars, weights)
+    return (float("inf") if n == 0 else 1 - e / n), n
+
+
+@multimethod
+def bag_of_chars_accuracy_n(
+    reference: ExtractedText, compared: ExtractedText, weights: Weights
+) -> Tuple[float, int]:
+    return bag_of_chars_accuracy_n(reference.text, compared.text, weights)
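One consequence of comparing bags (multisets) rather than sequences, shown as a small sketch (inputs made up): character order has no effect on the score.

from qurator.dinglehopper.metrics import Weights, bag_of_chars_accuracy_n

# Transposed characters still form the same bag, so accuracy stays 1.0:
acc, n = bag_of_chars_accuracy_n("abcd", "abdc", Weights())
print(acc, n)  # 1.0 4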
qurator/dinglehopper/metrics/bag_of_words_accuracy.py (new file, 30 lines)

@@ -0,0 +1,30 @@
+from collections import Counter
+from typing import Tuple, Union
+
+from .utils import bag_accuracy, Weights
+from .. import ExtractedText
+from ..normalize import words_normalized
+
+
+def bag_of_words_accuracy(
+    reference: Union[str, ExtractedText],
+    compared: Union[str, ExtractedText],
+    weights: Weights,
+) -> float:
+    acc, _ = bag_of_words_accuracy_n(reference, compared, weights)
+    return acc
+
+
+def bag_of_words_accuracy_n(
+    reference: Union[str, ExtractedText],
+    compared: Union[str, ExtractedText],
+    weights: Weights,
+) -> Tuple[float, int]:
+    if isinstance(reference, ExtractedText):
+        reference = reference.text
+    if isinstance(compared, ExtractedText):
+        compared = compared.text
+    reference_words = Counter(words_normalized(reference))
+    compared_words = Counter(words_normalized(compared))
+    e, n = bag_accuracy(reference_words, compared_words, weights)
+    return (float("inf") if n == 0 else 1 - e / n), n
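The same holds for words, as a quick sketch (example sentence made up): only the multiset of normalized words matters, not their order.

from qurator.dinglehopper.metrics import Weights, bag_of_words_accuracy_n

acc, n = bag_of_words_accuracy_n("lorem ipsum dolor", "dolor lorem ipsum", Weights())
print(acc, n)  # 1.0 3, although every word changed position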
@@ -1,13 +1,12 @@
 from __future__ import division

-import unicodedata
 from typing import Tuple

 from multimethod import multimethod
-from uniseg.graphemecluster import grapheme_clusters

-from ..edit_distance import distance
+from .. import distance
 from ..extracted_text import ExtractedText
+from ..normalize import chars_normalized


 @multimethod
@@ -19,7 +18,7 @@ def character_error_rate_n(reference: str, compared: str) -> Tuple[float, int]:
     """

     d = distance(reference, compared)
-    n = len(list(grapheme_clusters(unicodedata.normalize("NFC", reference))))
+    n = len(chars_normalized(reference))

     if d == 0:
         return 0, n
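For context (not part of the diff), the cluster-based counting in action, reusing the "Schlyñ"/"Schlym̃" pair from the new tests and assuming character_error_rate_n is re-exported via the star import above:

from qurator.dinglehopper.metrics import character_error_rate_n

# "Schlym̃" ends in "m" + combining tilde: one replaced grapheme cluster
# out of six in the reference, so the CER is 1/6:
cer, n = character_error_rate_n("Schlyñ", "Schlym̃")
print(cer, n)  # 0.1666... 6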
qurator/dinglehopper/metrics/utils.py (new file, 41 lines)

@@ -0,0 +1,41 @@
+from collections import Counter
+from typing import NamedTuple, Tuple
+
+
+class Weights(NamedTuple):
+    """Represent weights/costs for editing operations."""
+
+    deletes: int = 1
+    inserts: int = 1
+    replacements: int = 1
+
+
+def bag_accuracy(
+    reference: Counter, compared: Counter, weights: Weights
+) -> Tuple[int, int]:
+    """Calculate the weighted errors for two bags (Counters).
+
+    Basic algorithm idea:
+     - All elements in reference not occurring in compared are considered deletes.
+     - All elements in compared not occurring in reference are considered inserts.
+     - When the cost for one replacement is lower than that of one insert and one
+       delete, we can substitute pairs of deletes and inserts with one replacement.
+
+    :param reference: Bag used as reference (ground truth).
+    :param compared: Bag used to compare (OCR).
+    :param weights: Weights/costs for editing operations.
+    :return: Weighted errors and number of elements in reference.
+    """
+    n = sum(reference.values())
+    deletes = sum((reference - compared).values())
+    inserts = sum((compared - reference).values())
+    replacements = 0
+    if weights.replacements < (weights.deletes + weights.inserts):
+        replacements = min(deletes, inserts)
+        deletes, inserts = max(deletes - inserts, 0), max(inserts - deletes, 0)
+    weighted_errors = (
+        weights.deletes * deletes
+        + weights.inserts * inserts
+        + weights.replacements * replacements
+    )
+    return weighted_errors, n
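To make the replacement folding concrete, a worked sketch (inputs made up):

from collections import Counter
from qurator.dinglehopper.metrics.utils import Weights, bag_accuracy

ref, ocr = Counter("abc"), Counter("abd")

# "c" is missing from the OCR bag (a delete) and "d" is extra (an insert).
# With the default weights (1/1/1) a replacement is cheaper than a
# delete + insert pair, so the two are folded into one replacement:
print(bag_accuracy(ref, ocr, Weights()))         # (1, 3)

# With replacements costing 2, folding no longer pays off and the delete
# and the insert are counted separately:
print(bag_accuracy(ref, ocr, Weights(1, 1, 2)))  # (2, 3)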
@@ -1,65 +1,12 @@
 from __future__ import division

-import unicodedata
-from typing import Tuple, Iterable
+from typing import Iterable, Tuple

 from multimethod import multimethod

-import uniseg.wordbreak
-
 from ..edit_distance import levenshtein
-from .. import ExtractedText
-
-
-@multimethod
-def words(s: str):
-    """Extract words from a string"""
-
-    # Patch uniseg.wordbreak.word_break to deal with our private use characters. See also
-    # https://www.unicode.org/Public/UCD/latest/ucd/auxiliary/WordBreakProperty.txt
-    old_word_break = uniseg.wordbreak.word_break
-
-    def new_word_break(c, index=0):
-        if 0xE000 <= ord(c) <= 0xF8FF:  # Private Use Area
-            return "ALetter"
-        else:
-            return old_word_break(c, index)
-
-    uniseg.wordbreak.word_break = new_word_break
-
-    # Check if c is an unwanted character, i.e. whitespace, punctuation, or similar
-    def unwanted(c):
-
-        # See https://www.fileformat.info/info/unicode/category/index.htm
-        # and https://unicodebook.readthedocs.io/unicode.html#categories
-        unwanted_categories = "O", "M", "P", "Z", "S"
-        unwanted_subcategories = "Cc", "Cf"
-
-        subcat = unicodedata.category(c)
-        cat = subcat[0]
-        return cat in unwanted_categories or subcat in unwanted_subcategories
-
-    # We follow Unicode Standard Annex #29 on Unicode Text Segmentation here: Split on word boundaries using
-    # uniseg.wordbreak.words() and ignore all "words" that contain only whitespace, punctation "or similar characters."
-    for word in uniseg.wordbreak.words(s):
-        if all(unwanted(c) for c in word):
-            pass
-        else:
-            yield word
-
-
-@multimethod
-def words(s: ExtractedText):
-    return words(s.text)
-
-
-@multimethod
-def words_normalized(s: str):
-    return words(unicodedata.normalize("NFC", s))
-
-
-@multimethod
-def words_normalized(s: ExtractedText):
-    return words_normalized(s.text)
+from ..extracted_text import ExtractedText
+from ..normalize import words_normalized


 @multimethod
qurator/dinglehopper/normalize.py (new file, 61 lines)

@@ -0,0 +1,61 @@
+import unicodedata
+from typing import Union
+
+import uniseg.wordbreak
+from uniseg.graphemecluster import grapheme_clusters
+
+from .extracted_text import ExtractedText
+
+
+def chars_normalized(s: Union[str, ExtractedText]):
+    """Normalize characters in string."""
+    if isinstance(s, ExtractedText):
+        s = s.text
+    return list(grapheme_clusters(unicodedata.normalize("NFC", s)))
+
+
+def words(s: Union[str, ExtractedText]):
+    """Extract words from a string"""
+
+    if isinstance(s, ExtractedText):
+        s = s.text
+
+    # Patch uniseg.wordbreak.word_break to deal with our private use characters. See also
+    # https://www.unicode.org/Public/UCD/latest/ucd/auxiliary/WordBreakProperty.txt
+    old_word_break = uniseg.wordbreak.word_break
+
+    def new_word_break(c, index=0):
+        if 0xE000 <= ord(c) <= 0xF8FF:  # Private Use Area
+            return "ALetter"
+        else:
+            return old_word_break(c, index)
+
+    uniseg.wordbreak.word_break = new_word_break
+
+    # Check if c is an unwanted character, i.e. whitespace, punctuation, or similar
+    def unwanted(c):
+
+        # See https://www.fileformat.info/info/unicode/category/index.htm
+        # and https://unicodebook.readthedocs.io/unicode.html#categories
+        unwanted_categories = "O", "M", "P", "Z", "S"
+        unwanted_subcategories = "Cc", "Cf"
+
+        subcat = unicodedata.category(c)
+        cat = subcat[0]
+        return cat in unwanted_categories or subcat in unwanted_subcategories
+
+    # We follow Unicode Standard Annex #29 on Unicode Text Segmentation here:
+    # Split on word boundaries using uniseg.wordbreak.words() and ignore all
+    # "words" that contain only whitespace, punctuation "or similar characters."
+    for word in uniseg.wordbreak.words(s):
+        if all(unwanted(c) for c in word):
+            pass
+        else:
+            yield word
+
+
+def words_normalized(s: Union[str, ExtractedText]):
+    """Extract words from string and normalize them."""
+    if isinstance(s, ExtractedText):
+        s = s.text
+    return words(unicodedata.normalize("NFC", s))
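A short sketch of the helpers' behaviour (example strings made up):

from qurator.dinglehopper.normalize import words, words_normalized

# Punctuation-only "words" are dropped, so the comma and the final period
# do not show up in the result:
print(list(words("Dies ist ein Beispiel, sagte er.")))
# ['Dies', 'ist', 'ein', 'Beispiel', 'sagte', 'er']

# words_normalized() additionally NFC-normalizes the input first:
print(list(words_normalized("Schlyn\u0303 lorem")))  # ['Schlyñ', 'lorem']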
qurator/dinglehopper/tests/metrics/test_bag_accuracy.py (new file, 104 lines)

@@ -0,0 +1,104 @@
+import math
+import unicodedata
+from collections import Counter
+
+import pytest
+
+from ...metrics import bag_of_chars_accuracy_n, bag_of_words_accuracy_n, Weights
+from ...metrics.utils import bag_accuracy
+
+
+@pytest.fixture
+def ex_weights():
+    return (
+        Weights(deletes=0, inserts=0, replacements=0),
+        Weights(deletes=1, inserts=1, replacements=1),
+        Weights(deletes=1, inserts=0, replacements=1),
+        Weights(deletes=1, inserts=1, replacements=2),
+    )
+
+
+SIMPLE_CASES = (
+    ("", "", 0, (0, 0, 0)),
+    ("abc", "", 3, (3, 3, 3)),
+    ("", "abc", 0, (3, 0, 3)),
+    ("abc", "abc", 3, (0, 0, 0)),
+    ("abc", "ab", 3, (1, 1, 1)),
+    ("abc", "abcd", 3, (1, 0, 1)),
+    ("abc", "abd", 3, (1, 1, 2)),
+)
+
+
+@pytest.mark.parametrize(
+    "s1,s2, ex_n, ex_err",
+    [
+        *SIMPLE_CASES,
+        (("a", "b", "c", "d", "e"), ("a", "b", "c", "d", ("e", "´")), 5, (1, 1, 2)),
+        (range(5), range(6), 5, (1, 0, 1)),
+    ],
+)
+def test_bag_accuracy_algorithm(s1, s2, ex_n, ex_err, ex_weights):
+    """Test the main algorithm for calculating the bag accuracy."""
+    for weights, expected_errors in zip(ex_weights, (0, *ex_err)):
+        e, n = bag_accuracy(Counter(s1), Counter(s2), weights=weights)
+        assert n == ex_n, f"{n} == {ex_n} for {weights}"
+        assert e == expected_errors, f"{e} == {expected_errors} for {weights}"
+
+
+@pytest.mark.parametrize(
+    "s1,s2, ex_n, ex_err",
+    [
+        *SIMPLE_CASES,
+        ("Schlyñ", "Schlym̃", 6, (1, 1, 2)),
+        (
+            unicodedata.normalize("NFC", "Schlyñ lorem ipsum."),
+            unicodedata.normalize("NFD", "Schlyñ lorem ipsum!"),
+            19,
+            (1, 1, 2),
+        ),
+    ],
+)
+def test_bag_of_chars_accuracy_n(s1, s2, ex_n, ex_err, ex_weights):
+    """Test the special behaviour of the char differentiation.
+
+    As the algorithm and the char normalization are implemented elsewhere,
+    we currently only test that the corresponding algorithms are called.
+    """
+    for weights, expected_errors in zip(ex_weights, (0, *ex_err)):
+        acc, n = bag_of_chars_accuracy_n(s1, s2, weights)
+        assert n == ex_n, f"{n} == {ex_n} for {weights}"
+        if ex_n == 0:
+            assert math.isinf(acc)
+        else:
+            assert acc == pytest.approx(1 - expected_errors / ex_n), f"w: {weights}"
+
+
+@pytest.mark.parametrize(
+    "s1,s2, ex_n, ex_err",
+    [
+        *SIMPLE_CASES,
+        ("Schlyñ", "Schlym̃", 6, (1, 1, 2)),
+        (
+            unicodedata.normalize("NFC", "Schlyñ lorem ipsum."),
+            unicodedata.normalize("NFD", "Schlyñ lorem ipsum!"),
+            3,
+            (0, 0, 0),
+        ),
+    ],
+)
+def test_bag_of_words_accuracy_n(s1, s2, ex_n, ex_err, ex_weights):
+    """Test the special behaviour of the word differentiation.
+
+    As the algorithm and the word splitting are implemented elsewhere,
+    we currently only test that the corresponding algorithms are called.
+    """
+    if " " not in s1 and " " not in s2:
+        s1 = " ".join(s1)
+        s2 = " ".join(s2)
+    for weights, expected_errors in zip(ex_weights, (0, *ex_err)):
+        acc, n = bag_of_words_accuracy_n(s1, s2, weights)
+        assert n == ex_n, f"{n} == {ex_n} for {weights}"
+        if ex_n == 0:
+            assert math.isinf(acc)
+        else:
+            assert acc == pytest.approx(1 - expected_errors / ex_n), f"w: {weights}"
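To clarify how the parametrization lines up, a sketch spelling out the ("abc", "abd") case: ex_err holds one expected error count per non-trivial Weights variant, and the all-zero Weights always expects 0 errors, hence zip(ex_weights, (0, *ex_err)) in the tests above.

from collections import Counter
from qurator.dinglehopper.metrics import Weights
from qurator.dinglehopper.metrics.utils import bag_accuracy

ref, ocr = Counter("abc"), Counter("abd")
assert bag_accuracy(ref, ocr, Weights(0, 0, 0)) == (0, 3)  # zero weights: 0 errors
assert bag_accuracy(ref, ocr, Weights(1, 1, 1)) == (1, 3)  # folded into 1 replacement
assert bag_accuracy(ref, ocr, Weights(1, 0, 1)) == (1, 3)  # free inserts: 1 delete left
assert bag_accuracy(ref, ocr, Weights(1, 1, 2)) == (2, 3)  # no folding: delete + insert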
@@ -5,8 +5,9 @@ import os
 import pytest
 from lxml import etree as ET

-from ... import page_text, alto_text
-from ...metrics import word_error_rate, words
+from ... import alto_text, page_text
+from ...metrics import word_error_rate
+from ...normalize import words


 data_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "../", "data")
@@ -2,7 +2,8 @@ from __future__ import division, print_function

 import math

-from ...metrics import word_error_rate, words
+from ...metrics import word_error_rate
+from ...normalize import words


 def test_words():