Fixed some flake8 and mypy issues.

pull/60/head
Benjamin Rosemann 4 years ago
parent a44a3d4bf2
commit 714b569195

@ -1,3 +0,0 @@
from .ocr_files import *
from .extracted_text import *
from .align import *

@ -1,4 +1,4 @@
from .edit_distance import *
from .edit_distance import seq_editops
from .normalize import chars_normalized

@ -36,13 +36,9 @@ def gen_diff_report(gt_in, ocr_in, css_prefix, joiner, none):
html_custom_attrs += 'data-toggle="tooltip" title="{}"'.format(id_)
if css_classes:
return '<span class="{css_classes}" {html_custom_attrs}>{html_t}</span>'.format(
css_classes=css_classes,
html_t=html_t,
html_custom_attrs=html_custom_attrs,
)
return f"<span class=\"{css_classes}\" {html_custom_attrs}>{html_t}</span>"
else:
return "{html_t}".format(html_t=html_t)
return f"{html_t}"
if isinstance(gt_in, ExtractedText):
if not isinstance(ocr_in, ExtractedText):

@ -1,9 +1,6 @@
import os
import click
from ocrd_utils import initLogging
from .extracted_text import ExtractedText
from .ocr_files import extract

@ -12,15 +12,16 @@ from .normalize import chars_normalized
def levenshtein_matrix(seq1: Sequence, seq2: Sequence):
"""Compute the matrix commonly computed to produce the Levenshtein distance.
This is also known as the Wagner-Fischer algorithm. The matrix element at the bottom right contains the desired
edit distance.
This is also known as the Wagner-Fischer algorithm. The matrix element at the bottom
right contains the desired edit distance.
This algorithm is implemented here because we need an implementation that can work with sequences other than
strings, e.g. lists of grapheme clusters or lists of word strings.
This algorithm is implemented here because we need an implementation that can work
with sequences other than strings, e.g. lists of grapheme clusters or lists of word
strings.
"""
# Internally, we use a cached version. As the cache only works on hashable parameters, we convert the input
# sequences to tuples to make them hashable.
# Internally, we use a cached version. As the cache only works on hashable
# parameters, we convert the input sequences to tuples to make them hashable.
return _levenshtein_matrix(tuple(seq1), tuple(seq2))
@ -28,7 +29,8 @@ def levenshtein_matrix(seq1: Sequence, seq2: Sequence):
def _levenshtein_matrix(seq1: Tuple, seq2: Tuple):
"""Compute the matrix commonly computed to produce the Levenshtein distance.
This is a LRU cached function not meant to be used directly. Use levenshtein_matrix() instead.
This is an LRU-cached function not meant to be used directly.
Use levenshtein_matrix() instead.
"""
m = len(seq1)
n = len(seq2)
@ -36,7 +38,7 @@ def _levenshtein_matrix(seq1: Tuple, seq2: Tuple):
def from_to(start, stop):
return range(start, stop + 1, 1)
D = np.zeros((m + 1, n + 1), np.int)
D = np.zeros((m + 1, n + 1), int)
D[0, 0] = 0
for i in from_to(1, m):
D[i, 0] = i
@ -75,8 +77,9 @@ def levenshtein_matrix_cache_clear():
def distance(s1: str, s2: str):
"""Compute the Levenshtein edit distance between two Unicode strings
Note that this is different from levenshtein() as this function knows about Unicode normalization and grapheme
clusters. This should be the correct way to compare two Unicode strings.
Note that this is different from levenshtein() as this function knows about Unicode
normalization and grapheme clusters. This should be the correct way to compare two
Unicode strings.
"""
seq1 = chars_normalized(s1)
seq2 = chars_normalized(s2)
@ -87,8 +90,8 @@ def seq_editops(seq1, seq2):
"""
Return sequence of edit operations transforming one sequence to another.
This aims to return the same/similar results as python-Levenshtein's editops(), just generalized to arbitrary
sequences.
This aims to return the same/similar results as python-Levenshtein's editops(),
just generalized to arbitrary sequences.
"""
seq1 = list(seq1)
seq2 = list(seq2)

@ -1,5 +1,14 @@
from .bag_of_chars_accuracy import *
from .bag_of_words_accuracy import *
from .character_accuracy import *
from .bag_of_chars_accuracy import bag_of_chars_accuracy
from .bag_of_words_accuracy import bag_of_words_accuracy
from .character_accuracy import character_accuracy
from .utils import MetricResult, Weights
from .word_accuracy import *
from .word_accuracy import word_accuracy
__all__ = [
"bag_of_chars_accuracy",
"bag_of_words_accuracy",
"character_accuracy",
"word_accuracy",
"MetricResult",
"Weights",
]

@ -9,9 +9,8 @@ from .utils import bag_accuracy, MetricResult, Weights
def bag_of_chars_accuracy(
reference: str, compared: str, weights: Weights = Weights(1, 0, 1)
) -> MetricResult:
reference_chars = Counter(grapheme_clusters(normalize("NFC", reference)))
compared_chars = Counter(grapheme_clusters(normalize("NFC", compared)))
result = bag_accuracy(reference_chars, compared_chars, weights)
return MetricResult(
**{**result._asdict(), "metric": bag_of_chars_accuracy.__name__}
reference_chars: Counter = Counter(grapheme_clusters(normalize("NFC", reference)))
compared_chars: Counter = Counter(grapheme_clusters(normalize("NFC", compared)))
return bag_accuracy(
reference_chars, compared_chars, weights, bag_of_chars_accuracy.__name__
)

@ -7,9 +7,8 @@ from ..normalize import words_normalized
def bag_of_words_accuracy(
reference: str, compared: str, weights: Weights = Weights(1, 0, 1)
) -> MetricResult:
reference_words = Counter(words_normalized(reference))
compared_words = Counter(words_normalized(compared))
result = bag_accuracy(reference_words, compared_words, weights)
return MetricResult(
**{**result._asdict(), "metric": bag_of_words_accuracy.__name__}
reference_words: Counter = Counter(words_normalized(reference))
compared_words: Counter = Counter(words_normalized(compared))
return bag_accuracy(
reference_words, compared_words, weights, bag_of_words_accuracy.__name__
)

@ -1,7 +1,5 @@
from __future__ import division
from .utils import MetricResult, Weights
from .. import distance
from ..edit_distance import distance
from ..normalize import chars_normalized

@ -37,10 +37,7 @@ class MetricResult(NamedTuple):
We deviate from the builtin _asdict() function by including our properties.
"""
return {
**{
key: value
for key, value in self._asdict().items()
},
**{key: value for key, value in self._asdict().items()},
"accuracy": self.accuracy,
"error_rate": self.error_rate,
"weights": self.weights._asdict(),
@ -48,7 +45,10 @@ class MetricResult(NamedTuple):
def bag_accuracy(
reference: Counter, compared: Counter, weights: Weights
reference: Counter,
compared: Counter,
weights: Weights,
metric: str = "bag_accuracy",
) -> MetricResult:
"""Calculates the weighted errors for two bags (Counter).
@ -61,6 +61,7 @@ def bag_accuracy(
:param reference: Bag used as reference (ground truth).
:param compared: Bag used to compare (ocr).
:param weights: Weights/costs for editing operations.
:param metric: Name of the (original) metric.
:return: NamedTuple representing the results of this metric.
"""
n_ref = sum(reference.values())
@ -77,7 +78,7 @@ def bag_accuracy(
+ weights.replacements * replacements
)
return MetricResult(
metric=bag_accuracy.__name__,
metric=metric,
weights=weights,
weighted_errors=weighted_errors,
reference_elements=n_ref,

@ -1,26 +1,19 @@
import unicodedata
from typing import Union
import uniseg.wordbreak
from uniseg.graphemecluster import grapheme_clusters
from .extracted_text import ExtractedText
def chars_normalized(s: Union[str, ExtractedText]):
def chars_normalized(s: str):
"""Normalize characters in string."""
if isinstance(s, ExtractedText):
s = s.text
return list(grapheme_clusters(unicodedata.normalize("NFC", s)))
def words(s: Union[str, ExtractedText]):
def words(s: str):
"""Extract words from a string"""
if isinstance(s, ExtractedText):
s = s.text
# Patch uniseg.wordbreak.word_break to deal with our private use characters. See also
# Patch uniseg.wordbreak.word_break to deal with our private use characters.
# See also
# https://www.unicode.org/Public/UCD/latest/ucd/auxiliary/WordBreakProperty.txt
old_word_break = uniseg.wordbreak.word_break
@ -54,8 +47,6 @@ def words(s: Union[str, ExtractedText]):
yield word
def words_normalized(s: Union[str, ExtractedText]):
def words_normalized(s: str):
"""Extract words from string and normalize them."""
if isinstance(s, ExtractedText):
s = s.text
return words(unicodedata.normalize("NFC", s))

@ -1,7 +1,4 @@
from __future__ import division, print_function
from typing import Iterator
from warnings import warn
import sys
from lxml import etree as ET
@ -13,7 +10,8 @@ from .extracted_text import ExtractedText, normalize_sbb
def alto_namespace(tree: ET.ElementTree) -> str:
"""Return the ALTO namespace used in the given ElementTree.
This relies on the assumption that, in any given ALTO file, the root element has the local name "alto". We do not
This relies on the assumption that, in any given ALTO file,
the root element has the local name "alto". We do not
check if the file uses any valid ALTO namespace.
"""
root_name = ET.QName(tree.getroot().tag)
@ -47,8 +45,9 @@ def alto_text(tree):
def page_namespace(tree):
"""Return the PAGE content namespace used in the given ElementTree.
This relies on the assumption that, in any given PAGE content file, the root element has the local name "PcGts". We
do not check if the files uses any valid PAGE namespace.
This relies on the assumption that, in any given PAGE content file,
the root element has the local name "PcGts".
We do not check if the file uses any valid PAGE namespace.
"""
root_name = ET.QName(tree.getroot().tag)
if root_name.localname == "PcGts":
@ -102,9 +101,13 @@ def extract_texts_from_reading_order_group(group, tree, nsmap, textequiv_level):
else:
raise NotImplementedError
for ro_child in ro_children:
if ET.QName(ro_child.tag).localname in ["OrderedGroup", "OrderedGroupIndexed", "UnorderedGroup", "UnorderedGroupIndexed"]:
if ET.QName(ro_child.tag).localname in [
"OrderedGroup",
"OrderedGroupIndexed",
"UnorderedGroup",
"UnorderedGroupIndexed",
]:
regions.extend(
extract_texts_from_reading_order_group(
ro_child, tree, nsmap, textequiv_level

@ -77,5 +77,6 @@ class OcrdDinglehopperEvaluate(Processor):
# Clear cache between files
levenshtein_matrix_cache_clear()
if __name__ == "__main__":
ocrd_dinglehopper()

@ -6,7 +6,8 @@ import pytest
from lxml import etree as ET
from uniseg.graphemecluster import grapheme_clusters
from .. import seq_align, ExtractedText
from ..align import seq_align
from ..extracted_text import ExtractedText
def test_text():

@ -6,8 +6,8 @@ import pytest
from lxml import etree as ET
from uniseg.graphemecluster import grapheme_clusters
from ... import page_text, alto_text
from ...metrics import character_accuracy
from ...ocr_files import alto_text, page_text
data_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "../", "data")

@ -5,9 +5,9 @@ import os
import pytest
from lxml import etree as ET
from ... import alto_text, page_text
from ...metrics import word_accuracy
from ...normalize import words
from ...ocr_files import alto_text, page_text
data_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "../", "data")

@ -1,5 +1,6 @@
from .util import unzip
from .. import align, seq_align, distance
from ..align import align, seq_align
from ..edit_distance import distance
def test_left_empty():

@ -1,8 +1,6 @@
from __future__ import division, print_function
import unicodedata
from .. import levenshtein, distance
from ..edit_distance import levenshtein, distance
def test_levenshtein():

@ -1,6 +1,6 @@
import unicodedata
from .. import seq_editops, editops
from ..edit_distance import seq_editops, editops
def test_trivial():

@ -1,11 +1,10 @@
from __future__ import division, print_function
import os
import pytest
from lxml import etree as ET
from .. import align, page_text
from ..align import align
from ..ocr_files import page_text
data_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "data")

@ -5,7 +5,8 @@ import os
import pytest
from lxml import etree as ET
from .. import distance, page_text, alto_text
from ..edit_distance import distance
from ..ocr_files import alto_text, page_text
data_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "data")

@ -3,7 +3,7 @@ import os
import pytest
from lxml import etree as ET
from .. import page_text
from ..ocr_files import page_text
@pytest.mark.parametrize(

@ -1,13 +1,12 @@
import os
import re
import lxml.etree as ET
import textwrap
import pytest
import lxml.etree as ET
from .util import working_directory
from .. import alto_namespace, alto_text, page_namespace, page_text, plain_text, text
from ..ocr_files import alto_namespace, alto_text, page_namespace, page_text, \
plain_text, text
data_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "data")

Loading…
Cancel
Save