Add BoC and BoW metrics

Also some refactoring for helper methods on normalization and word
splitting.
pull/60/head
Benjamin Rosemann 4 years ago
parent 4ccae9432d
commit 8cd624f795

@@ -1,10 +1,11 @@
 from .edit_distance import *
+from .normalize import chars_normalized


 def align(t1, t2):
     """Align text."""
-    s1 = list(grapheme_clusters(unicodedata.normalize("NFC", t1)))
-    s2 = list(grapheme_clusters(unicodedata.normalize("NFC", t2)))
+    s1 = chars_normalized(t1)
+    s2 = chars_normalized(t2)
     return seq_align(s1, s2)

@@ -1,16 +1,15 @@
 from __future__ import division, print_function

-import unicodedata
-from functools import partial, lru_cache
+from functools import lru_cache, partial
 from typing import Sequence, Tuple

 import numpy as np
 from multimethod import multimethod
-from uniseg.graphemecluster import grapheme_clusters
 from tqdm import tqdm

-from .extracted_text import ExtractedText
 from .config import Config
+from .extracted_text import ExtractedText
+from .normalize import chars_normalized


 def levenshtein_matrix(seq1: Sequence, seq2: Sequence):
@@ -82,8 +81,8 @@ def distance(s1: str, s2: str):
     Note that this is different from levenshtein() as this function knows about Unicode normalization and grapheme
     clusters. This should be the correct way to compare two Unicode strings.
     """
-    seq1 = list(grapheme_clusters(unicodedata.normalize("NFC", s1)))
-    seq2 = list(grapheme_clusters(unicodedata.normalize("NFC", s2)))
+    seq1 = chars_normalized(s1)
+    seq2 = chars_normalized(s2)
     return levenshtein(seq1, seq2)
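The docstring's claim is worth pinning down with a concrete case: after this refactoring, distance() still compares NFC-normalized grapheme clusters, so strings differing only in normalization form are equal, and a combining sequence counts as a single edit. A minimal sketch, with the qurator.dinglehopper package path assumed from the relative imports in this diff:

    import unicodedata
    from qurator.dinglehopper.edit_distance import distance

    nfc = unicodedata.normalize("NFC", "Schlyñ")  # precomposed ñ: 6 code points
    nfd = unicodedata.normalize("NFD", "Schlyñ")  # n + combining tilde: 7 code points
    assert len(nfc) != len(nfd)
    assert distance(nfc, nfd) == 0  # equal after NFC normalization

    # m̃ (m + combining tilde) is one grapheme cluster, so this is one replacement
    assert distance("Schlyñ", "Schlym̃") == 1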
@@ -139,6 +138,6 @@ def editops(word1, word2):
     Note that this returns indices to the _grapheme clusters_, not characters!
     """
-    word1 = list(grapheme_clusters(unicodedata.normalize("NFC", word1)))
-    word2 = list(grapheme_clusters(unicodedata.normalize("NFC", word2)))
+    word1 = chars_normalized(word1)
+    word2 = chars_normalized(word2)
     return seq_editops(word1, word2)

@@ -1,2 +1,5 @@
+from .bag_of_chars_accuracy import *
+from .bag_of_words_accuracy import *
 from .character_error_rate import *
+from .utils import Weights
 from .word_error_rate import *

@@ -0,0 +1,35 @@
+from collections import Counter
+from typing import Tuple, Union
+from unicodedata import normalize
+
+from multimethod import multimethod
+from uniseg.graphemecluster import grapheme_clusters
+
+from .utils import bag_accuracy, Weights
+from .. import ExtractedText
+
+
+def bag_of_chars_accuracy(
+    reference: Union[str, ExtractedText],
+    compared: Union[str, ExtractedText],
+    weights: Weights,
+) -> float:
+    acc, _ = bag_of_chars_accuracy_n(reference, compared, weights)
+    return acc
+
+
+@multimethod
+def bag_of_chars_accuracy_n(
+    reference: str, compared: str, weights: Weights
+) -> Tuple[float, int]:
+    reference_chars = Counter(grapheme_clusters(normalize("NFC", reference)))
+    compared_chars = Counter(grapheme_clusters(normalize("NFC", compared)))
+    e, n = bag_accuracy(reference_chars, compared_chars, weights)
+    return (float("inf") if n == 0 else 1 - e / n), n
+
+
+@multimethod
+def bag_of_chars_accuracy_n(
+    reference: ExtractedText, compared: ExtractedText, weights: Weights
+) -> Tuple[float, int]:
+    return bag_of_chars_accuracy_n(reference.text, compared.text, weights)
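As a usage sketch for the new metric (package path assumed, weights spelled out for clarity): the bag of characters ignores order entirely, scores 1 - e/n over the reference's grapheme clusters, and returns float("inf") when the reference is empty.

    import math
    from qurator.dinglehopper.metrics import bag_of_chars_accuracy, Weights

    w = Weights(deletes=1, inserts=1, replacements=1)

    # Same multiset of characters, different order: perfect score
    assert bag_of_chars_accuracy("abc", "cba", w) == 1.0

    # One differing character out of three reference characters: 1 - 1/3
    assert abs(bag_of_chars_accuracy("abc", "abd", w) - 2 / 3) < 1e-9

    # Empty reference: n == 0, so the accuracy is infinite by convention
    assert math.isinf(bag_of_chars_accuracy("", "abc", w))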

@@ -0,0 +1,30 @@
+from collections import Counter
+from typing import Tuple, Union
+
+from .utils import bag_accuracy, Weights
+from .. import ExtractedText
+from ..normalize import words_normalized
+
+
+def bag_of_words_accuracy(
+    reference: Union[str, ExtractedText],
+    compared: Union[str, ExtractedText],
+    weights: Weights,
+) -> float:
+    acc, _ = bag_of_words_accuracy_n(reference, compared, weights)
+    return acc
+
+
+def bag_of_words_accuracy_n(
+    reference: Union[str, ExtractedText],
+    compared: Union[str, ExtractedText],
+    weights: Weights,
+) -> Tuple[float, int]:
+    if isinstance(reference, ExtractedText):
+        reference = reference.text
+    if isinstance(compared, ExtractedText):
+        compared = compared.text
+    reference_words = Counter(words_normalized(reference))
+    compared_words = Counter(words_normalized(compared))
+    e, n = bag_accuracy(reference_words, compared_words, weights)
+    return (float("inf") if n == 0 else 1 - e / n), n
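A corresponding sketch for the word variant, under the same path assumptions: words are split and normalized by words_normalized(), so word order and punctuation-only tokens do not affect the score.

    from qurator.dinglehopper.metrics import bag_of_words_accuracy, Weights

    w = Weights()  # defaults: deletes=1, inserts=1, replacements=1

    # Reordered words produce identical bags
    assert bag_of_words_accuracy("lorem ipsum dolor", "dolor lorem ipsum", w) == 1.0

    # One of three reference words replaced (punctuation is dropped): 1 - 1/3
    acc = bag_of_words_accuracy("lorem ipsum dolor.", "lorem ipsum dolores!", w)
    assert abs(acc - 2 / 3) < 1e-9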

@@ -1,13 +1,12 @@
 from __future__ import division

-import unicodedata
 from typing import Tuple

 from multimethod import multimethod
-from uniseg.graphemecluster import grapheme_clusters

-from ..edit_distance import distance
+from .. import distance
 from ..extracted_text import ExtractedText
+from ..normalize import chars_normalized


 @multimethod
@@ -19,7 +18,7 @@ def character_error_rate_n(reference: str, compared: str) -> Tuple[float, int]:
     """
     d = distance(reference, compared)
-    n = len(list(grapheme_clusters(unicodedata.normalize("NFC", reference))))
+    n = len(chars_normalized(reference))

     if d == 0:
         return 0, n

@@ -0,0 +1,41 @@
+from collections import Counter
+from typing import NamedTuple, Tuple
+
+
+class Weights(NamedTuple):
+    """Represent weights/costs for editing operations."""
+
+    deletes: int = 1
+    inserts: int = 1
+    replacements: int = 1
+
+
+def bag_accuracy(
+    reference: Counter, compared: Counter, weights: Weights
+) -> Tuple[int, int]:
+    """Calculate the weighted errors for two bags (Counter).
+
+    Basic algorithm idea:
+     - All elements in reference not occurring in compared are considered deletes.
+     - All elements in compared not occurring in reference are considered inserts.
+     - When the cost for one replacement is lower than that of one insert and one delete
+       we can substitute pairs of deletes and inserts with one replacement.
+
+    :param reference: Bag used as reference (ground truth).
+    :param compared: Bag used to compare (ocr).
+    :param weights: Weights/costs for editing operations.
+    :return: weighted errors and number of elements in reference.
+    """
+    n = sum(reference.values())
+    deletes = sum((reference - compared).values())
+    inserts = sum((compared - reference).values())
+    replacements = 0
+    if weights.replacements < (weights.deletes + weights.inserts):
+        replacements = min(deletes, inserts)
+        deletes, inserts = max(deletes - inserts, 0), max(inserts - deletes, 0)
+    weighted_errors = (
+        weights.deletes * deletes
+        + weights.inserts * inserts
+        + weights.replacements * replacements
+    )
+    return weighted_errors, n
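To trace the delete/insert pairing once by hand, a worked example (not part of the commit; import path assumed):

    from collections import Counter
    from qurator.dinglehopper.metrics.utils import bag_accuracy, Weights

    reference = Counter("abcc")  # {'a': 1, 'b': 1, 'c': 2}
    compared = Counter("abd")    # {'a': 1, 'b': 1, 'd': 1}

    # reference - compared = {'c': 2} -> 2 deletes
    # compared - reference = {'d': 1} -> 1 insert
    # With default weights, one replacement (cost 1) is cheaper than a delete
    # plus an insert (cost 2), so one pair collapses into a replacement:
    # 1 delete + 0 inserts + 1 replacement -> weighted errors 2, n = 4
    assert bag_accuracy(reference, compared, Weights()) == (2, 4)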

@@ -1,65 +1,12 @@
 from __future__ import division

-import unicodedata
-from typing import Tuple, Iterable
+from typing import Iterable, Tuple

 from multimethod import multimethod
-import uniseg.wordbreak

 from ..edit_distance import levenshtein
-from .. import ExtractedText
+from ..extracted_text import ExtractedText
+from ..normalize import words_normalized


-@multimethod
-def words(s: str):
-    """Extract words from a string"""
-
-    # Patch uniseg.wordbreak.word_break to deal with our private use characters. See also
-    # https://www.unicode.org/Public/UCD/latest/ucd/auxiliary/WordBreakProperty.txt
-    old_word_break = uniseg.wordbreak.word_break
-
-    def new_word_break(c, index=0):
-        if 0xE000 <= ord(c) <= 0xF8FF:  # Private Use Area
-            return "ALetter"
-        else:
-            return old_word_break(c, index)
-
-    uniseg.wordbreak.word_break = new_word_break
-
-    # Check if c is an unwanted character, i.e. whitespace, punctuation, or similar
-    def unwanted(c):
-
-        # See https://www.fileformat.info/info/unicode/category/index.htm
-        # and https://unicodebook.readthedocs.io/unicode.html#categories
-        unwanted_categories = "O", "M", "P", "Z", "S"
-        unwanted_subcategories = "Cc", "Cf"
-
-        subcat = unicodedata.category(c)
-        cat = subcat[0]
-        return cat in unwanted_categories or subcat in unwanted_subcategories
-
-    # We follow Unicode Standard Annex #29 on Unicode Text Segmentation here: Split on word boundaries using
-    # uniseg.wordbreak.words() and ignore all "words" that contain only whitespace, punctation "or similar characters."
-    for word in uniseg.wordbreak.words(s):
-        if all(unwanted(c) for c in word):
-            pass
-        else:
-            yield word
-
-
-@multimethod
-def words(s: ExtractedText):
-    return words(s.text)
-
-
-@multimethod
-def words_normalized(s: str):
-    return words(unicodedata.normalize("NFC", s))
-
-
-@multimethod
-def words_normalized(s: ExtractedText):
-    return words_normalized(s.text)
-
-
 @multimethod

@@ -0,0 +1,61 @@
+import unicodedata
+from typing import Union
+
+import uniseg.wordbreak
+from uniseg.graphemecluster import grapheme_clusters
+
+from .extracted_text import ExtractedText
+
+
+def chars_normalized(s: Union[str, ExtractedText]):
+    """Normalize characters in string."""
+    if isinstance(s, ExtractedText):
+        s = s.text
+    return list(grapheme_clusters(unicodedata.normalize("NFC", s)))
+
+
+def words(s: Union[str, ExtractedText]):
+    """Extract words from a string"""
+    if isinstance(s, ExtractedText):
+        s = s.text
+
+    # Patch uniseg.wordbreak.word_break to deal with our private use characters. See also
+    # https://www.unicode.org/Public/UCD/latest/ucd/auxiliary/WordBreakProperty.txt
+    old_word_break = uniseg.wordbreak.word_break
+
+    def new_word_break(c, index=0):
+        if 0xE000 <= ord(c) <= 0xF8FF:  # Private Use Area
+            return "ALetter"
+        else:
+            return old_word_break(c, index)
+
+    uniseg.wordbreak.word_break = new_word_break
+
+    # Check if c is an unwanted character, i.e. whitespace, punctuation, or similar
+    def unwanted(c):
+
+        # See https://www.fileformat.info/info/unicode/category/index.htm
+        # and https://unicodebook.readthedocs.io/unicode.html#categories
+        unwanted_categories = "O", "M", "P", "Z", "S"
+        unwanted_subcategories = "Cc", "Cf"
+
+        subcat = unicodedata.category(c)
+        cat = subcat[0]
+        return cat in unwanted_categories or subcat in unwanted_subcategories
+
+    # We follow Unicode Standard Annex #29 on Unicode Text Segmentation here:
+    # Split on word boundaries using uniseg.wordbreak.words() and ignore all
+    # "words" that contain only whitespace, punctuation "or similar characters."
+    for word in uniseg.wordbreak.words(s):
+        if all(unwanted(c) for c in word):
+            pass
+        else:
+            yield word
+
+
+def words_normalized(s: Union[str, ExtractedText]):
+    """Extract words from string and normalize them."""
+    if isinstance(s, ExtractedText):
+        s = s.text
+    return words(unicodedata.normalize("NFC", s))
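The two helpers can be exercised directly (package path assumed; the decomposed input below uses an explicit combining tilde):

    from qurator.dinglehopper.normalize import chars_normalized, words_normalized

    # NFC recomposes "n" + U+0303 into "ñ", and grapheme clustering keeps it
    # as a single element of the result list
    assert chars_normalized("Schlyn\u0303") == ["S", "c", "h", "l", "y", "ñ"]

    # words() is a generator; UAX #29 segmentation plus the unwanted() filter
    # drops whitespace-only and punctuation-only tokens
    assert list(words_normalized("Schlyñ lorem ipsum.")) == ["Schlyñ", "lorem", "ipsum"]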

@@ -0,0 +1,104 @@
+import math
+import unicodedata
+from collections import Counter
+
+import pytest
+
+from ...metrics import bag_of_chars_accuracy_n, bag_of_words_accuracy_n, Weights
+from ...metrics.utils import bag_accuracy
+
+
+@pytest.fixture
+def ex_weights():
+    return (
+        Weights(deletes=0, inserts=0, replacements=0),
+        Weights(deletes=1, inserts=1, replacements=1),
+        Weights(deletes=1, inserts=0, replacements=1),
+        Weights(deletes=1, inserts=1, replacements=2),
+    )
+
+
+SIMPLE_CASES = (
+    ("", "", 0, (0, 0, 0)),
+    ("abc", "", 3, (3, 3, 3)),
+    ("", "abc", 0, (3, 0, 3)),
+    ("abc", "abc", 3, (0, 0, 0)),
+    ("abc", "ab", 3, (1, 1, 1)),
+    ("abc", "abcd", 3, (1, 0, 1)),
+    ("abc", "abd", 3, (1, 1, 2)),
+)
+
+
+@pytest.mark.parametrize(
+    "s1,s2, ex_n, ex_err",
+    [
+        *SIMPLE_CASES,
+        (("a", "b", "c", "d", "e"), ("a", "b", "c", "d", ("e", "´")), 5, (1, 1, 2)),
+        (range(5), range(6), 5, (1, 0, 1)),
+    ],
+)
+def test_bag_accuracy_algorithm(s1, s2, ex_n, ex_err, ex_weights):
+    """Test the main algorithm for calculating the bag accuracy."""
+    for weights, expected_errors in zip(ex_weights, (0, *ex_err)):
+        e, n = bag_accuracy(Counter(s1), Counter(s2), weights=weights)
+        assert n == ex_n, f"{n} == {ex_n} for {weights}"
+        assert e == expected_errors, f"{e} == {expected_errors} for {weights}"
+
+
+@pytest.mark.parametrize(
+    "s1,s2, ex_n, ex_err",
+    [
+        *SIMPLE_CASES,
+        ("Schlyñ", "Schlym̃", 6, (1, 1, 2)),
+        (
+            unicodedata.normalize("NFC", "Schlyñ lorem ipsum."),
+            unicodedata.normalize("NFD", "Schlyñ lorem ipsum!"),
+            19,
+            (1, 1, 2),
+        ),
+    ],
+)
+def test_bag_of_chars_accuracy_n(s1, s2, ex_n, ex_err, ex_weights):
+    """Test the special behaviour of the char differentiation.
+
+    As the algorithm and the char normalization are implemented elsewhere,
+    we currently only test that the corresponding algorithms are called.
+    """
+    for weights, expected_errors in zip(ex_weights, (0, *ex_err)):
+        acc, n = bag_of_chars_accuracy_n(s1, s2, weights)
+        assert n == ex_n, f"{n} == {ex_n} for {weights}"
+        if ex_n == 0:
+            assert math.isinf(acc)
+        else:
+            assert acc == pytest.approx(1 - expected_errors / ex_n), f"w: {weights}"
+
+
+@pytest.mark.parametrize(
+    "s1,s2, ex_n, ex_err",
+    [
+        *SIMPLE_CASES,
+        ("Schlyñ", "Schlym̃", 6, (1, 1, 2)),
+        (
+            unicodedata.normalize("NFC", "Schlyñ lorem ipsum."),
+            unicodedata.normalize("NFD", "Schlyñ lorem ipsum!"),
+            3,
+            (0, 0, 0),
+        ),
+    ],
+)
+def test_bag_of_words_accuracy_n(s1, s2, ex_n, ex_err, ex_weights):
+    """Test the special behaviour of the word differentiation.
+
+    As the algorithm and the word splitting are implemented elsewhere,
+    we currently only test that the corresponding algorithms are called.
+    """
+    if " " not in s1 and " " not in s2:
+        s1 = " ".join(s1)
+        s2 = " ".join(s2)
+    for weights, expected_errors in zip(ex_weights, (0, *ex_err)):
+        acc, n = bag_of_words_accuracy_n(s1, s2, weights)
+        assert n == ex_n, f"{n} == {ex_n} for {weights}"
+        if ex_n == 0:
+            assert math.isinf(acc)
+        else:
+            assert acc == pytest.approx(1 - expected_errors / ex_n), f"w: {weights}"

@@ -5,8 +5,9 @@ import os
 import pytest
 from lxml import etree as ET

-from ... import page_text, alto_text
-from ...metrics import word_error_rate, words
+from ... import alto_text, page_text
+from ...metrics import word_error_rate
+from ...normalize import words

 data_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "../", "data")

@@ -2,7 +2,8 @@ from __future__ import division, print_function

 import math

-from ...metrics import word_error_rate, words
+from ...metrics import word_error_rate
+from ...normalize import words


 def test_words():
