Fixed some flake8 and mypy issues.

pull/60/head
Benjamin Rosemann 4 years ago
parent a44a3d4bf2
commit 714b569195

@@ -1,3 +0,0 @@
-from .ocr_files import *
-from .extracted_text import *
-from .align import *

@ -1,4 +1,4 @@
from .edit_distance import * from .edit_distance import seq_editops
from .normalize import chars_normalized from .normalize import chars_normalized

@@ -36,13 +36,9 @@ def gen_diff_report(gt_in, ocr_in, css_prefix, joiner, none):
             html_custom_attrs += 'data-toggle="tooltip" title="{}"'.format(id_)

         if css_classes:
-            return '<span class="{css_classes}" {html_custom_attrs}>{html_t}</span>'.format(
-                css_classes=css_classes,
-                html_t=html_t,
-                html_custom_attrs=html_custom_attrs,
-            )
+            return f"<span class=\"{css_classes}\" {html_custom_attrs}>{html_t}</span>"
         else:
-            return "{html_t}".format(html_t=html_t)
+            return f"{html_t}"

     if isinstance(gt_in, ExtractedText):
         if not isinstance(ocr_in, ExtractedText):

@@ -1,9 +1,6 @@
-import os
-
 import click
 from ocrd_utils import initLogging

-from .extracted_text import ExtractedText
 from .ocr_files import extract

@@ -12,15 +12,16 @@ from .normalize import chars_normalized
 def levenshtein_matrix(seq1: Sequence, seq2: Sequence):
     """Compute the matrix commonly computed to produce the Levenshtein distance.

-    This is also known as the Wagner-Fischer algorithm. The matrix element at the bottom right contains the desired
-    edit distance.
+    This is also known as the Wagner-Fischer algorithm. The matrix element at the bottom
+    right contains the desired edit distance.

-    This algorithm is implemented here because we need an implementation that can work with sequences other than
-    strings, e.g. lists of grapheme clusters or lists of word strings.
+    This algorithm is implemented here because we need an implementation that can work
+    with sequences other than strings, e.g. lists of grapheme clusters or lists of word
+    strings.
     """

-    # Internally, we use a cached version. As the cache only works on hashable parameters, we convert the input
-    # sequences to tuples to make them hashable.
+    # Internally, we use a cached version. As the cache only works on hashable
+    # parameters, we convert the input sequences to tuples to make them hashable.
     return _levenshtein_matrix(tuple(seq1), tuple(seq2))
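
Note: for readers unfamiliar with Wagner-Fischer, here is a compact standalone sketch of the recurrence this docstring describes. It is illustrative only (not the project's cached implementation) and works on arbitrary sequences, as the docstring requires:

import numpy as np

def wagner_fischer(seq1, seq2):
    # D[i, j] = edit distance between seq1[:i] and seq2[:j]
    m, n = len(seq1), len(seq2)
    D = np.zeros((m + 1, n + 1), int)
    D[:, 0] = np.arange(m + 1)  # delete all of seq1[:i]
    D[0, :] = np.arange(n + 1)  # insert all of seq2[:j]
    for i in range(1, m + 1):
        for j in range(1, n + 1):
            cost = 0 if seq1[i - 1] == seq2[j - 1] else 1
            D[i, j] = min(
                D[i - 1, j] + 1,         # deletion
                D[i, j - 1] + 1,         # insertion
                D[i - 1, j - 1] + cost,  # substitution (or match)
            )
    return D  # D[m, n] is the Levenshtein distance

# Works on sequences other than strings, e.g. lists of word strings:
assert wagner_fischer(["foo", "bar"], ["foo", "baz"])[-1, -1] == 1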
@@ -28,7 +29,8 @@ def levenshtein_matrix(seq1: Sequence, seq2: Sequence):
 def _levenshtein_matrix(seq1: Tuple, seq2: Tuple):
     """Compute the matrix commonly computed to produce the Levenshtein distance.

-    This is an LRU-cached function not meant to be used directly. Use levenshtein_matrix() instead.
+    This is an LRU-cached function not meant to be used directly.
+    Use levenshtein_matrix() instead.
     """
     m = len(seq1)
     n = len(seq2)
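
Note: the LRU caching mentioned here is also why the public wrapper converts its inputs to tuples; functools.lru_cache hashes its arguments, and lists are unhashable. A minimal sketch of the wrapper pattern (names illustrative, body stubbed):

from functools import lru_cache

@lru_cache(maxsize=10)
def _cached(seq1: tuple, seq2: tuple):
    # Stand-in for the expensive matrix computation; results are
    # memoized per (seq1, seq2) pair because tuples are hashable.
    return len(seq1) + len(seq2)

def cached(seq1, seq2):
    # Public wrapper: accept any sequence, convert to hashable tuples.
    return _cached(tuple(seq1), tuple(seq2))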
@@ -36,7 +38,7 @@ def _levenshtein_matrix(seq1: Tuple, seq2: Tuple):
     def from_to(start, stop):
         return range(start, stop + 1, 1)

-    D = np.zeros((m + 1, n + 1), np.int)
+    D = np.zeros((m + 1, n + 1), int)
     D[0, 0] = 0
     for i in from_to(1, m):
         D[i, 0] = i
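
Note: np.int was only ever an alias for the Python builtin int, and NumPy 1.20 deprecated it (it was later removed), so passing int directly is the forward-compatible spelling with identical behavior:

import numpy as np

D = np.zeros((3, 4), int)          # same dtype np.int used to resolve to
assert D.dtype == np.dtype(int)    # platform default integer, e.g. int64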
@@ -75,8 +77,9 @@ def levenshtein_matrix_cache_clear():
 def distance(s1: str, s2: str):
     """Compute the Levenshtein edit distance between two Unicode strings

-    Note that this is different from levenshtein() as this function knows about Unicode normalization and grapheme
-    clusters. This should be the correct way to compare two Unicode strings.
+    Note that this is different from levenshtein() as this function knows about Unicode
+    normalization and grapheme clusters. This should be the correct way to compare two
+    Unicode strings.
     """
     seq1 = chars_normalized(s1)
     seq2 = chars_normalized(s2)
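
Note: a small example of why normalization and grapheme clusters matter, using the uniseg package imported elsewhere in this diff. A precomposed character and its decomposed form compare unequal as raw strings but are equal after NFC normalization, and a combining mark stays attached to its base character when counting grapheme clusters:

import unicodedata
from uniseg.graphemecluster import grapheme_clusters

precomposed = "Schlyñ"        # ñ as U+00F1
decomposed = "Schlyn\u0303"   # n followed by combining tilde
assert precomposed != decomposed  # plain == compares code points
assert precomposed == unicodedata.normalize("NFC", decomposed)
assert len(list(grapheme_clusters(decomposed))) == 6  # ñ is one cluster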
@@ -87,8 +90,8 @@ def seq_editops(seq1, seq2):
     """
     Return sequence of edit operations transforming one sequence to another.

-    This aims to return the same/similar results as python-Levenshtein's editops(), just generalized to arbitrary
-    sequences.
+    This aims to return the same/similar results as python-Levenshtein's editops(),
+    just generalized to arbitrary sequences.
     """
     seq1 = list(seq1)
     seq2 = list(seq2)

@ -1,5 +1,14 @@
from .bag_of_chars_accuracy import * from .bag_of_chars_accuracy import bag_of_chars_accuracy
from .bag_of_words_accuracy import * from .bag_of_words_accuracy import bag_of_words_accuracy
from .character_accuracy import * from .character_accuracy import character_accuracy
from .utils import MetricResult, Weights from .utils import MetricResult, Weights
from .word_accuracy import * from .word_accuracy import word_accuracy
__all__ = [
"bag_of_chars_accuracy",
"bag_of_words_accuracy",
"character_accuracy",
"word_accuracy",
"MetricResult",
"Weights",
]
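
Note: swapping the star imports for explicit names addresses flake8's F403 ("import *" used), and the new __all__ both documents the public surface and marks the re-exported names as used, avoiding F401 for imports that are not otherwise referenced (pyflakes counts __all__ entries as uses). A minimal sketch of the pattern, with a hypothetical module name:

# mypkg/metrics/__init__.py -- hypothetical package layout
from .character_accuracy import character_accuracy  # explicit, not "import *"

# Listing a name in __all__ counts as a "use" for pyflakes/flake8 (no F401)
# and pins down what "from mypkg.metrics import *" re-exports.
__all__ = ["character_accuracy"]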

@@ -9,9 +9,8 @@ from .utils import bag_accuracy, MetricResult, Weights
 def bag_of_chars_accuracy(
     reference: str, compared: str, weights: Weights = Weights(1, 0, 1)
 ) -> MetricResult:
-    reference_chars = Counter(grapheme_clusters(normalize("NFC", reference)))
-    compared_chars = Counter(grapheme_clusters(normalize("NFC", compared)))
-    result = bag_accuracy(reference_chars, compared_chars, weights)
-    return MetricResult(
-        **{**result._asdict(), "metric": bag_of_chars_accuracy.__name__}
+    reference_chars: Counter = Counter(grapheme_clusters(normalize("NFC", reference)))
+    compared_chars: Counter = Counter(grapheme_clusters(normalize("NFC", compared)))
+    return bag_accuracy(
+        reference_chars, compared_chars, weights, bag_of_chars_accuracy.__name__
     )

@@ -7,9 +7,8 @@ from ..normalize import words_normalized
 def bag_of_words_accuracy(
     reference: str, compared: str, weights: Weights = Weights(1, 0, 1)
 ) -> MetricResult:
-    reference_words = Counter(words_normalized(reference))
-    compared_words = Counter(words_normalized(compared))
-    result = bag_accuracy(reference_words, compared_words, weights)
-    return MetricResult(
-        **{**result._asdict(), "metric": bag_of_words_accuracy.__name__}
+    reference_words: Counter = Counter(words_normalized(reference))
+    compared_words: Counter = Counter(words_normalized(compared))
+    return bag_accuracy(
+        reference_words, compared_words, weights, bag_of_words_accuracy.__name__
     )

@ -1,7 +1,5 @@
from __future__ import division
from .utils import MetricResult, Weights from .utils import MetricResult, Weights
from .. import distance from ..edit_distance import distance
from ..normalize import chars_normalized from ..normalize import chars_normalized

@@ -37,10 +37,7 @@ class MetricResult(NamedTuple):
         We deviate from the builtin _asdict() function by including our properties.
         """
         return {
-            **{
-                key: value
-                for key, value in self._asdict().items()
-            },
+            **{key: value for key, value in self._asdict().items()},
             "accuracy": self.accuracy,
             "error_rate": self.error_rate,
             "weights": self.weights._asdict(),
@@ -48,7 +45,10 @@ class MetricResult(NamedTuple):
 def bag_accuracy(
-    reference: Counter, compared: Counter, weights: Weights
+    reference: Counter,
+    compared: Counter,
+    weights: Weights,
+    metric: str = "bag_accuracy",
 ) -> MetricResult:
     """Calculates the weighted errors for two bags (Counter).
@@ -61,6 +61,7 @@ def bag_accuracy(
     :param reference: Bag used as reference (ground truth).
     :param compared: Bag used to compare (ocr).
     :param weights: Weights/costs for editing operations.
+    :param metric: Name of the (original) metric.
     :return: NamedTuple representing the results of this metric.
     """
     n_ref = sum(reference.values())
@@ -77,7 +78,7 @@ def bag_accuracy(
         + weights.replacements * replacements
     )
     return MetricResult(
-        metric=bag_accuracy.__name__,
+        metric=metric,
         weights=weights,
         weighted_errors=weighted_errors,
         reference_elements=n_ref,
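
Note: for intuition, bag accuracy compares multisets rather than sequences, so word or character order does not matter. A simplified standalone sketch using Counter arithmetic (the real bag_accuracy additionally applies the Weights for deletes/inserts/replacements and returns a MetricResult):

from collections import Counter

def bag_errors(reference: Counter, compared: Counter):
    # Counter subtraction keeps only positive counts: elements missing
    # from the comparison are deletions, surplus elements are insertions.
    deletes = sum((reference - compared).values())
    inserts = sum((compared - reference).values())
    return deletes, inserts, sum(reference.values())

ref = Counter(["the", "quick", "fox"])
ocr = Counter(["quick", "the", "fux"])  # order is irrelevant for bags
print(bag_errors(ref, ocr))             # (1, 1, 3)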

@@ -1,26 +1,19 @@
 import unicodedata
-from typing import Union

 import uniseg.wordbreak
 from uniseg.graphemecluster import grapheme_clusters

-from .extracted_text import ExtractedText

-def chars_normalized(s: Union[str, ExtractedText]):
+def chars_normalized(s: str):
     """Normalize characters in string."""
-    if isinstance(s, ExtractedText):
-        s = s.text
     return list(grapheme_clusters(unicodedata.normalize("NFC", s)))


-def words(s: Union[str, ExtractedText]):
+def words(s: str):
     """Extract words from a string"""
-    if isinstance(s, ExtractedText):
-        s = s.text
-
-    # Patch uniseg.wordbreak.word_break to deal with our private use characters. See also
+    # Patch uniseg.wordbreak.word_break to deal with our private use characters.
+    # See also
     # https://www.unicode.org/Public/UCD/latest/ucd/auxiliary/WordBreakProperty.txt
     old_word_break = uniseg.wordbreak.word_break
@@ -54,8 +47,6 @@ def words(s: Union[str, ExtractedText]):
             yield word


-def words_normalized(s: Union[str, ExtractedText]):
+def words_normalized(s: str):
     """Extract words from string and normalize them."""
-    if isinstance(s, ExtractedText):
-        s = s.text
     return words(unicodedata.normalize("NFC", s))
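
Note: word extraction here is Unicode-aware rather than a plain str.split(). A rough sketch of the idea using uniseg (a simplification of the module's patched word_break handling; the filter below uses isalnum instead of the module's category-based test):

import unicodedata
import uniseg.wordbreak

def simple_words(s: str):
    # Break on Unicode word boundaries, then keep only segments that
    # contain at least one letter or digit, dropping punctuation and
    # whitespace segments.
    for segment in uniseg.wordbreak.words(unicodedata.normalize("NFC", s)):
        if any(c.isalnum() for c in segment):
            yield segment

print(list(simple_words("Dies ist ein Beispiel, nicht wahr?")))
# ['Dies', 'ist', 'ein', 'Beispiel', 'nicht', 'wahr']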

@@ -1,7 +1,4 @@
-from __future__ import division, print_function
-
 from typing import Iterator
-from warnings import warn
 import sys

 from lxml import etree as ET
@@ -13,7 +10,8 @@ from .extracted_text import ExtractedText, normalize_sbb
 def alto_namespace(tree: ET.ElementTree) -> str:
     """Return the ALTO namespace used in the given ElementTree.

-    This relies on the assumption that, in any given ALTO file, the root element has the local name "alto". We do not
-    check if the file uses any valid ALTO namespace.
+    This relies on the assumption that, in any given ALTO file,
+    the root element has the local name "alto". We do not
+    check if the file uses any valid ALTO namespace.
     """
     root_name = ET.QName(tree.getroot().tag)
@@ -47,8 +45,9 @@ def alto_text(tree):
 def page_namespace(tree):
     """Return the PAGE content namespace used in the given ElementTree.

-    This relies on the assumption that, in any given PAGE content file, the root element has the local name "PcGts". We
-    do not check if the file uses any valid PAGE namespace.
+    This relies on the assumption that, in any given PAGE content file,
+    the root element has the local name "PcGts".
+    We do not check if the file uses any valid PAGE namespace.
     """
     root_name = ET.QName(tree.getroot().tag)
     if root_name.localname == "PcGts":
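
Note: both helpers use the same namespace-sniffing trick via lxml's QName; a compact, self-contained illustration (the sample XML below is made up):

from lxml import etree as ET

xml = b'<alto xmlns="http://www.loc.gov/standards/alto/ns-v3#"/>'
tree = ET.ElementTree(ET.fromstring(xml))
qname = ET.QName(tree.getroot().tag)   # tag is "{namespace}localname"
print(qname.localname)   # 'alto'
print(qname.namespace)   # 'http://www.loc.gov/standards/alto/ns-v3#'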
@@ -97,14 +96,18 @@ def extract_texts_from_reading_order_group(group, tree, nsmap, textequiv_level):
         ro_children = filter(lambda child: "index" in child.attrib.keys(), ro_children)
         ro_children = sorted(ro_children, key=lambda child: int(child.attrib["index"]))
-    elif ET.QName(group.tag).localname in ["UnorderedGroup","UnorderedGroupIndexed"]:
+    elif ET.QName(group.tag).localname in ["UnorderedGroup", "UnorderedGroupIndexed"]:
         ro_children = list(group)
     else:
         raise NotImplementedError

     for ro_child in ro_children:
-        if ET.QName(ro_child.tag).localname in ["OrderedGroup", "OrderedGroupIndexed", "UnorderedGroup", "UnorderedGroupIndexed"]:
+        if ET.QName(ro_child.tag).localname in [
+            "OrderedGroup",
+            "OrderedGroupIndexed",
+            "UnorderedGroup",
+            "UnorderedGroupIndexed",
+        ]:
             regions.extend(
                 extract_texts_from_reading_order_group(
                     ro_child, tree, nsmap, textequiv_level

@@ -77,5 +77,6 @@ class OcrdDinglehopperEvaluate(Processor):
         # Clear cache between files
         levenshtein_matrix_cache_clear()

+
 if __name__ == "__main__":
     ocrd_dinglehopper()

@@ -6,7 +6,8 @@ import pytest
 from lxml import etree as ET
 from uniseg.graphemecluster import grapheme_clusters

-from .. import seq_align, ExtractedText
+from ..align import seq_align
+from ..extracted_text import ExtractedText


 def test_text():

@@ -6,8 +6,8 @@ import pytest
 from lxml import etree as ET
 from uniseg.graphemecluster import grapheme_clusters

-from ... import page_text, alto_text
 from ...metrics import character_accuracy
+from ...ocr_files import alto_text, page_text

 data_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "../", "data")

@@ -5,9 +5,9 @@ import os
 import pytest
 from lxml import etree as ET

-from ... import alto_text, page_text
 from ...metrics import word_accuracy
 from ...normalize import words
+from ...ocr_files import alto_text, page_text

 data_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "../", "data")

@@ -1,5 +1,6 @@
 from .util import unzip
-from .. import align, seq_align, distance
+from ..align import align, seq_align
+from ..edit_distance import distance


 def test_left_empty():

@@ -1,8 +1,6 @@
-from __future__ import division, print_function
-
 import unicodedata

-from .. import levenshtein, distance
+from ..edit_distance import levenshtein, distance


 def test_levenshtein():

@@ -1,6 +1,6 @@
 import unicodedata

-from .. import seq_editops, editops
+from ..edit_distance import seq_editops, editops


 def test_trivial():

@@ -1,11 +1,10 @@
-from __future__ import division, print_function
-
 import os

 import pytest
 from lxml import etree as ET

-from .. import align, page_text
+from ..align import align
+from ..ocr_files import page_text

 data_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "data")

@@ -5,7 +5,8 @@ import os
 import pytest
 from lxml import etree as ET

-from .. import distance, page_text, alto_text
+from ..edit_distance import distance
+from ..ocr_files import alto_text, page_text

 data_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "data")

@@ -3,7 +3,7 @@ import os
 import pytest
 from lxml import etree as ET

-from .. import page_text
+from ..ocr_files import page_text


 @pytest.mark.parametrize(

@ -1,13 +1,12 @@
import os import os
import re import re
import lxml.etree as ET
import textwrap import textwrap
import pytest import lxml.etree as ET
from .util import working_directory from .util import working_directory
from .. import alto_namespace, alto_text, page_namespace, page_text, plain_text, text from ..ocr_files import alto_namespace, alto_text, page_namespace, page_text, \
plain_text, text
data_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "data") data_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "data")
