Add BoC and BoW metric
Also some refactoring for helper methods on normalization and word splitting.

Branch: pull/60/head
parent 4ccae9432d
commit 8cd624f795
__init__.py
@@ -1,2 +1,5 @@
+from .bag_of_chars_accuracy import *
+from .bag_of_words_accuracy import *
 from .character_error_rate import *
+from .utils import Weights
 from .word_error_rate import *
bag_of_chars_accuracy.py
@@ -0,0 +1,35 @@
+from collections import Counter
+from typing import Tuple, Union
+from unicodedata import normalize
+
+from multimethod import multimethod
+from uniseg.graphemecluster import grapheme_clusters
+
+from .utils import bag_accuracy, Weights
+from .. import ExtractedText
+
+
+def bag_of_chars_accuracy(
+    reference: Union[str, ExtractedText],
+    compared: Union[str, ExtractedText],
+    weights: Weights,
+) -> float:
+    acc, _ = bag_of_chars_accuracy_n(reference, compared, weights)
+    return acc
+
+
+@multimethod
+def bag_of_chars_accuracy_n(
+    reference: str, compared: str, weights: Weights
+) -> Tuple[float, int]:
+    reference_chars = Counter(grapheme_clusters(normalize("NFC", reference)))
+    compared_chars = Counter(grapheme_clusters(normalize("NFC", compared)))
+    e, n = bag_accuracy(reference_chars, compared_chars, weights)
+    return (float("inf") if n == 0 else 1 - e / n), n
+
+
+@multimethod
+def bag_of_chars_accuracy_n(
+    reference: ExtractedText, compared: ExtractedText, weights: Weights
+) -> Tuple[float, int]:
+    return bag_of_chars_accuracy_n(reference.text, compared.text, weights)
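For a quick sanity check of the new metric, a hypothetical interpreter session (not part of the commit; the exact import path depends on where the metrics package lives in this repository): with default weights, one differing grapheme cluster out of four pairs a delete with an insert into a single replacement, giving 1 - 1/4. The @multimethod overloads also accept two ExtractedText arguments and delegate to the str implementation via .text.

>>> from dinglehopper.metrics import bag_of_chars_accuracy, Weights  # path assumed
>>> bag_of_chars_accuracy("abcd", "abce", Weights())
0.75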
bag_of_words_accuracy.py
@@ -0,0 +1,30 @@
+from collections import Counter
+from typing import Tuple, Union
+
+from .utils import bag_accuracy, Weights
+from .. import ExtractedText
+from ..normalize import words_normalized
+
+
+def bag_of_words_accuracy(
+    reference: Union[str, ExtractedText],
+    compared: Union[str, ExtractedText],
+    weights: Weights,
+) -> float:
+    acc, _ = bag_of_words_accuracy_n(reference, compared, weights)
+    return acc
+
+
+def bag_of_words_accuracy_n(
+    reference: Union[str, ExtractedText],
+    compared: Union[str, ExtractedText],
+    weights: Weights,
+) -> Tuple[float, int]:
+    if isinstance(reference, ExtractedText):
+        reference = reference.text
+    if isinstance(compared, ExtractedText):
+        compared = compared.text
+    reference_words = Counter(words_normalized(reference))
+    compared_words = Counter(words_normalized(compared))
+    e, n = bag_accuracy(reference_words, compared_words, weights)
+    return (float("inf") if n == 0 else 1 - e / n), n
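An analogous sketch for the word-level metric (again a hypothetical session with an assumed import path): one wrong word out of four yields an accuracy of 0.75, and bag_of_words_accuracy_n additionally reports the reference size n.

>>> from dinglehopper.metrics import bag_of_words_accuracy_n, Weights  # path assumed
>>> bag_of_words_accuracy_n("the quick brown fox", "the quick brown fix", Weights())
(0.75, 4)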
utils.py
@@ -0,0 +1,41 @@
+from collections import Counter
+from typing import NamedTuple, Tuple
+
+
+class Weights(NamedTuple):
+    """Represent weights/costs for editing operations."""
+
+    deletes: int = 1
+    inserts: int = 1
+    replacements: int = 1
+
+
+def bag_accuracy(
+    reference: Counter, compared: Counter, weights: Weights
+) -> Tuple[int, int]:
+    """Calculate the weighted errors for two bags (Counter).
+
+    Basic algorithm idea:
+    - All elements in reference not occurring in compared are considered deletes.
+    - All elements in compared not occurring in reference are considered inserts.
+    - When the cost for one replacement is lower than that of one insert and one
+      delete, we can substitute pairs of deletes and inserts with one replacement.
+
+    :param reference: Bag used as reference (ground truth).
+    :param compared: Bag used to compare (OCR).
+    :param weights: Weights/costs for editing operations.
+    :return: Weighted errors and number of elements in reference.
+    """
+    n = sum(reference.values())
+    deletes = sum((reference - compared).values())
+    inserts = sum((compared - reference).values())
+    replacements = 0
+    if weights.replacements < (weights.deletes + weights.inserts):
+        replacements = min(deletes, inserts)
+        deletes, inserts = max(deletes - inserts, 0), max(inserts - deletes, 0)
+    weighted_errors = (
+        weights.deletes * deletes
+        + weights.inserts * inserts
+        + weights.replacements * replacements
+    )
+    return weighted_errors, n
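To make the substitution rule concrete, a small worked example (hypothetical session using only the definitions from this hunk): Counter("aab") vs. Counter("abc") leaves one surplus "a" (a delete) and one surplus "c" (an insert). With the default weights the pair is counted as one replacement; once a replacement costs as much as a delete plus an insert, both are counted separately.

>>> from collections import Counter
>>> bag_accuracy(Counter("aab"), Counter("abc"), Weights())
(1, 3)
>>> bag_accuracy(Counter("aab"), Counter("abc"), Weights(replacements=2))
(2, 3)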
normalize.py
@@ -0,0 +1,61 @@
+import unicodedata
+from typing import Union
+
+import uniseg.wordbreak
+from uniseg.graphemecluster import grapheme_clusters
+
+from .extracted_text import ExtractedText
+
+
+def chars_normalized(s: Union[str, ExtractedText]):
+    """Normalize characters in string."""
+    if isinstance(s, ExtractedText):
+        s = s.text
+    return list(grapheme_clusters(unicodedata.normalize("NFC", s)))
+
+
+def words(s: Union[str, ExtractedText]):
+    """Extract words from a string."""
+
+    if isinstance(s, ExtractedText):
+        s = s.text
+
+    # Patch uniseg.wordbreak.word_break to deal with our private use characters. See also
+    # https://www.unicode.org/Public/UCD/latest/ucd/auxiliary/WordBreakProperty.txt
+    old_word_break = uniseg.wordbreak.word_break
+
+    def new_word_break(c, index=0):
+        if 0xE000 <= ord(c) <= 0xF8FF:  # Private Use Area
+            return "ALetter"
+        else:
+            return old_word_break(c, index)
+
+    uniseg.wordbreak.word_break = new_word_break
+
+    # Check if c is an unwanted character, i.e. whitespace, punctuation, or similar
+    def unwanted(c):
+
+        # See https://www.fileformat.info/info/unicode/category/index.htm
+        # and https://unicodebook.readthedocs.io/unicode.html#categories
+        unwanted_categories = "O", "M", "P", "Z", "S"
+        unwanted_subcategories = "Cc", "Cf"
+
+        subcat = unicodedata.category(c)
+        cat = subcat[0]
+        return cat in unwanted_categories or subcat in unwanted_subcategories
+
+    # We follow Unicode Standard Annex #29 on Unicode Text Segmentation here:
+    # split on word boundaries using uniseg.wordbreak.words() and ignore all
+    # "words" that contain only whitespace, punctuation or similar characters.
+    for word in uniseg.wordbreak.words(s):
+        if all(unwanted(c) for c in word):
+            pass
+        else:
+            yield word
+
+
+def words_normalized(s: Union[str, ExtractedText]):
+    """Extract words from string and normalize them."""
+    if isinstance(s, ExtractedText):
+        s = s.text
+    return words(unicodedata.normalize("NFC", s))
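Behaviour of the refactored word splitting, as a hypothetical session: segmentation follows UAX #29, and the uniseg "words" consisting only of whitespace or punctuation are dropped, so the trailing "!" and the spaces never reach the metric; words_normalized additionally applies NFC normalization first.

>>> list(words("Dies ist ein Beispielsatz!"))
['Dies', 'ist', 'ein', 'Beispielsatz']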