Mirror of https://github.com/qurator-spk/dinglehopper.git

Fixed some flake8 and mypy issues.

This commit is contained in:
Benjamin Rosemann 2021-06-11 16:09:19 +02:00
parent a44a3d4bf2
commit 714b569195
23 changed files with 81 additions and 88 deletions

View file

@@ -1,3 +0,0 @@
from .ocr_files import *
from .extracted_text import *
from .align import *

View file

@@ -1,4 +1,4 @@
from .edit_distance import *
from .edit_distance import seq_editops
from .normalize import chars_normalized

View file

@@ -36,13 +36,9 @@ def gen_diff_report(gt_in, ocr_in, css_prefix, joiner, none):
html_custom_attrs += 'data-toggle="tooltip" title="{}"'.format(id_)
if css_classes:
return '<span class="{css_classes}" {html_custom_attrs}>{html_t}</span>'.format(
css_classes=css_classes,
html_t=html_t,
html_custom_attrs=html_custom_attrs,
)
return f"<span class=\"{css_classes}\" {html_custom_attrs}>{html_t}</span>"
else:
return "{html_t}".format(html_t=html_t)
return f"{html_t}"
if isinstance(gt_in, ExtractedText):
if not isinstance(ocr_in, ExtractedText):

View file

@@ -1,9 +1,6 @@
import os
import click
from ocrd_utils import initLogging
from .extracted_text import ExtractedText
from .ocr_files import extract

View file

@@ -12,15 +12,16 @@ from .normalize import chars_normalized
def levenshtein_matrix(seq1: Sequence, seq2: Sequence):
"""Compute the matrix commonly computed to produce the Levenshtein distance.
This is also known as the Wagner-Fischer algorithm. The matrix element at the bottom right contains the desired
edit distance.
This is also known as the Wagner-Fischer algorithm. The matrix element at the bottom
right contains the desired edit distance.
This algorithm is implemented here because we need an implementation that can work with sequences other than
strings, e.g. lists of grapheme clusters or lists of word strings.
This algorithm is implemented here because we need an implementation that can work
with sequences other than strings, e.g. lists of grapheme clusters or lists of word
strings.
"""
# Internally, we use a cached version. As the cache only works on hashable parameters, we convert the input
# sequences to tuples to make them hashable.
# Internally, we use a cached version. As the cache only works on hashable
# parameters, we convert the input sequences to tuples to make them hashable.
return _levenshtein_matrix(tuple(seq1), tuple(seq2))
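For readers who want the recurrence spelled out, here is a minimal, self-contained sketch of the Wagner-Fischer matrix the docstring above describes. It is illustrative only; the project's implementation additionally caches results and is fed grapheme clusters or word lists rather than raw strings.

import numpy as np

def wagner_fischer(seq1, seq2):
    # D[i, j] is the edit distance between seq1[:i] and seq2[:j];
    # the bottom-right element D[m, n] is the full edit distance.
    m, n = len(seq1), len(seq2)
    D = np.zeros((m + 1, n + 1), int)
    D[:, 0] = range(m + 1)
    D[0, :] = range(n + 1)
    for i in range(1, m + 1):
        for j in range(1, n + 1):
            cost = 0 if seq1[i - 1] == seq2[j - 1] else 1
            D[i, j] = min(
                D[i - 1, j] + 1,         # deletion
                D[i, j - 1] + 1,         # insertion
                D[i - 1, j - 1] + cost,  # substitution (or match)
            )
    return D

# wagner_fischer("flaw", "lawn")[-1, -1] == 2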
@@ -28,7 +29,8 @@ def levenshtein_matrix(seq1: Sequence, seq2: Sequence):
def _levenshtein_matrix(seq1: Tuple, seq2: Tuple):
"""Compute the matrix commonly computed to produce the Levenshtein distance.
This is a LRU cached function not meant to be used directly. Use levenshtein_matrix() instead.
This is a LRU cached function not meant to be used directly.
Use levenshtein_matrix() instead.
"""
m = len(seq1)
n = len(seq2)
@@ -36,7 +38,7 @@ def _levenshtein_matrix(seq1: Tuple, seq2: Tuple):
def from_to(start, stop):
return range(start, stop + 1, 1)
D = np.zeros((m + 1, n + 1), np.int)
D = np.zeros((m + 1, n + 1), int)
D[0, 0] = 0
for i in from_to(1, m):
D[i, 0] = i
@@ -75,8 +77,9 @@ def levenshtein_matrix_cache_clear():
def distance(s1: str, s2: str):
"""Compute the Levenshtein edit distance between two Unicode strings
Note that this is different from levenshtein() as this function knows about Unicode normalization and grapheme
clusters. This should be the correct way to compare two Unicode strings.
Note that this is different from levenshtein() as this function knows about Unicode
normalization and grapheme clusters. This should be the correct way to compare two
Unicode strings.
"""
seq1 = chars_normalized(s1)
seq2 = chars_normalized(s2)
@@ -87,8 +90,8 @@ def seq_editops(seq1, seq2):
"""
Return sequence of edit operations transforming one sequence to another.
This aims to return the same/similar results as python-Levenshtein's editops(), just generalized to arbitrary
sequences.
This aims to return the same/similar results as python-Levenshtein's editops(),
just generalized to arbitrary sequences.
"""
seq1 = list(seq1)
seq2 = list(seq2)
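A rough usage sketch for seq_editops(). The import path and the exact tuple layout are assumptions: the docstring only promises results similar to python-Levenshtein's editops(), whose convention is (operation, source_position, destination_position).

from dinglehopper.edit_distance import seq_editops  # import path assumed

# Unlike python-Levenshtein, this also works on non-string sequences,
# e.g. lists of word strings.
print(seq_editops("karolin", "kathrin"))
print(seq_editops(["the", "quick", "fox"], ["the", "lazy", "fox"]))
# Each entry should be a (tag, source_pos, dest_pos) tuple with tag being
# "replace", "insert" or "delete" (assumption, see above).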

View file

@@ -1,5 +1,14 @@
from .bag_of_chars_accuracy import *
from .bag_of_words_accuracy import *
from .character_accuracy import *
from .bag_of_chars_accuracy import bag_of_chars_accuracy
from .bag_of_words_accuracy import bag_of_words_accuracy
from .character_accuracy import character_accuracy
from .utils import MetricResult, Weights
from .word_accuracy import *
from .word_accuracy import word_accuracy
__all__ = [
"bag_of_chars_accuracy",
"bag_of_words_accuracy",
"character_accuracy",
"word_accuracy",
"MetricResult",
"Weights",
]

View file

@@ -9,9 +9,8 @@ from .utils import bag_accuracy, MetricResult, Weights
def bag_of_chars_accuracy(
reference: str, compared: str, weights: Weights = Weights(1, 0, 1)
) -> MetricResult:
reference_chars = Counter(grapheme_clusters(normalize("NFC", reference)))
compared_chars = Counter(grapheme_clusters(normalize("NFC", compared)))
result = bag_accuracy(reference_chars, compared_chars, weights)
return MetricResult(
**{**result._asdict(), "metric": bag_of_chars_accuracy.__name__}
reference_chars: Counter = Counter(grapheme_clusters(normalize("NFC", reference)))
compared_chars: Counter = Counter(grapheme_clusters(normalize("NFC", compared)))
return bag_accuracy(
reference_chars, compared_chars, weights, bag_of_chars_accuracy.__name__
)

View file

@@ -7,9 +7,8 @@ from ..normalize import words_normalized
def bag_of_words_accuracy(
reference: str, compared: str, weights: Weights = Weights(1, 0, 1)
) -> MetricResult:
reference_words = Counter(words_normalized(reference))
compared_words = Counter(words_normalized(compared))
result = bag_accuracy(reference_words, compared_words, weights)
return MetricResult(
**{**result._asdict(), "metric": bag_of_words_accuracy.__name__}
reference_words: Counter = Counter(words_normalized(reference))
compared_words: Counter = Counter(words_normalized(compared))
return bag_accuracy(
reference_words, compared_words, weights, bag_of_words_accuracy.__name__
)

View file

@@ -1,7 +1,5 @@
from __future__ import division
from .utils import MetricResult, Weights
from .. import distance
from ..edit_distance import distance
from ..normalize import chars_normalized

View file

@@ -37,10 +37,7 @@ class MetricResult(NamedTuple):
We deviate from the builtin _asdict() function by including our properties.
"""
return {
**{
key: value
for key, value in self._asdict().items()
},
**{key: value for key, value in self._asdict().items()},
"accuracy": self.accuracy,
"error_rate": self.error_rate,
"weights": self.weights._asdict(),
@@ -48,7 +45,10 @@ class MetricResult(NamedTuple):
def bag_accuracy(
reference: Counter, compared: Counter, weights: Weights
reference: Counter,
compared: Counter,
weights: Weights,
metric: str = "bag_accuracy",
) -> MetricResult:
"""Calculates the the weighted errors for two bags (Counter).
@@ -61,6 +61,7 @@ def bag_accuracy(
:param reference: Bag used as reference (ground truth).
:param compared: Bag used to compare (ocr).
:param weights: Weights/costs for editing operations.
:param metric: Name of the (original) metric.
:return: NamedTuple representing the results of this metric.
"""
n_ref = sum(reference.values())
@@ -77,7 +78,7 @@ def bag_accuracy(
+ weights.replacements * replacements
)
return MetricResult(
metric=bag_accuracy.__name__,
metric=metric,
weights=weights,
weighted_errors=weighted_errors,
reference_elements=n_ref,
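To make the weighted-error idea concrete, an illustrative sketch of comparing two bags (the pairing of deletes/inserts/replacements and the weight names are assumptions, not the exact implementation above):

from collections import Counter

def bag_errors_sketch(reference: Counter, compared: Counter,
                      w_deletes=1, w_inserts=0, w_replacements=1):
    missing = sum((reference - compared).values())  # in reference, not in compared
    extra = sum((compared - reference).values())    # in compared, not in reference
    # Pair up one missing with one extra element as a replacement;
    # whatever is left over counts as pure deletions or insertions.
    replacements = min(missing, extra)
    deletes = missing - replacements
    inserts = extra - replacements
    weighted_errors = (w_deletes * deletes + w_inserts * inserts
                       + w_replacements * replacements)
    n_ref = sum(reference.values())
    accuracy = 1 - weighted_errors / n_ref if n_ref else None
    return weighted_errors, accuracy

reference = Counter("the quick brown fox".split())
compared = Counter("the quick brown fix".split())
print(bag_errors_sketch(reference, compared))  # one replacement -> (1, 0.75)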

View file

@@ -1,26 +1,19 @@
import unicodedata
from typing import Union
import uniseg.wordbreak
from uniseg.graphemecluster import grapheme_clusters
from .extracted_text import ExtractedText
def chars_normalized(s: Union[str, ExtractedText]):
def chars_normalized(s: str):
"""Normalize characters in string."""
if isinstance(s, ExtractedText):
s = s.text
return list(grapheme_clusters(unicodedata.normalize("NFC", s)))
def words(s: Union[str, ExtractedText]):
def words(s: str):
"""Extract words from a string"""
if isinstance(s, ExtractedText):
s = s.text
# Patch uniseg.wordbreak.word_break to deal with our private use characters. See also
# Patch uniseg.wordbreak.word_break to deal with our private use characters.
# See also
# https://www.unicode.org/Public/UCD/latest/ucd/auxiliary/WordBreakProperty.txt
old_word_break = uniseg.wordbreak.word_break
@@ -54,8 +47,6 @@ def words(s: Union[str, ExtractedText]):
yield word
def words_normalized(s: Union[str, ExtractedText]):
def words_normalized(s: str):
"""Extract words from string and normalize them."""
if isinstance(s, ExtractedText):
s = s.text
return words(unicodedata.normalize("NFC", s))
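A small usage sketch for the helpers above, using only the standard library and uniseg. The behaviour of uniseg.wordbreak.words() described in the comments is an assumption about that library:

import unicodedata
import uniseg.wordbreak
from uniseg.graphemecluster import grapheme_clusters

text = "Extract words – and characters."
nfc = unicodedata.normalize("NFC", text)

# chars_normalized(): one entry per user-perceived character (grapheme cluster).
print(list(grapheme_clusters(nfc)))

# uniseg segments the string at Unicode word boundaries; spaces and punctuation
# come out as segments of their own, which words() presumably filters away
# (assumption, the filtering is not shown in this hunk).
print(list(uniseg.wordbreak.words(nfc)))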

View file

@@ -1,7 +1,4 @@
from __future__ import division, print_function
from typing import Iterator
from warnings import warn
import sys
from lxml import etree as ET
@@ -13,7 +10,8 @@ from .extracted_text import ExtractedText, normalize_sbb
def alto_namespace(tree: ET.ElementTree) -> str:
"""Return the ALTO namespace used in the given ElementTree.
This relies on the assumption that, in any given ALTO file, the root element has the local name "alto". We do not
This relies on the assumption that, in any given ALTO file,
the root element has the local name "alto". We do not
check if the files uses any valid ALTO namespace.
"""
root_name = ET.QName(tree.getroot().tag)
@@ -47,8 +45,9 @@ def alto_text(tree):
def page_namespace(tree):
"""Return the PAGE content namespace used in the given ElementTree.
This relies on the assumption that, in any given PAGE content file, the root element has the local name "PcGts". We
do not check if the files uses any valid PAGE namespace.
This relies on the assumption that, in any given PAGE content file,
the root element has the local name "PcGts".
We do not check if the files uses any valid PAGE namespace.
"""
root_name = ET.QName(tree.getroot().tag)
if root_name.localname == "PcGts":
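The namespace detection these docstrings describe boils down to reading the qualified name of the root element. A minimal illustration with lxml (the sample document is made up):

from lxml import etree as ET

sample = b'<alto xmlns="http://www.loc.gov/standards/alto/ns-v3#"/>'
tree = ET.ElementTree(ET.fromstring(sample))

root_name = ET.QName(tree.getroot().tag)
print(root_name.localname)   # "alto" -> treat the file as ALTO
print(root_name.namespace)   # the namespace URI actually used in the file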
@@ -102,9 +101,13 @@ def extract_texts_from_reading_order_group(group, tree, nsmap, textequiv_level):
else:
raise NotImplementedError
for ro_child in ro_children:
if ET.QName(ro_child.tag).localname in ["OrderedGroup", "OrderedGroupIndexed", "UnorderedGroup", "UnorderedGroupIndexed"]:
if ET.QName(ro_child.tag).localname in [
"OrderedGroup",
"OrderedGroupIndexed",
"UnorderedGroup",
"UnorderedGroupIndexed",
]:
regions.extend(
extract_texts_from_reading_order_group(
ro_child, tree, nsmap, textequiv_level

View file

@@ -77,5 +77,6 @@ class OcrdDinglehopperEvaluate(Processor):
# Clear cache between files
levenshtein_matrix_cache_clear()
if __name__ == "__main__":
ocrd_dinglehopper()

View file

@@ -6,7 +6,8 @@ import pytest
from lxml import etree as ET
from uniseg.graphemecluster import grapheme_clusters
from .. import seq_align, ExtractedText
from ..align import seq_align
from ..extracted_text import ExtractedText
def test_text():

View file

@@ -6,8 +6,8 @@ import pytest
from lxml import etree as ET
from uniseg.graphemecluster import grapheme_clusters
from ... import page_text, alto_text
from ...metrics import character_accuracy
from ...ocr_files import alto_text, page_text
data_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "../", "data")

View file

@@ -5,9 +5,9 @@ import os
import pytest
from lxml import etree as ET
from ... import alto_text, page_text
from ...metrics import word_accuracy
from ...normalize import words
from ...ocr_files import alto_text, page_text
data_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "../", "data")

View file

@@ -1,5 +1,6 @@
from .util import unzip
from .. import align, seq_align, distance
from ..align import align, seq_align
from ..edit_distance import distance
def test_left_empty():

View file

@@ -1,8 +1,6 @@
from __future__ import division, print_function
import unicodedata
from .. import levenshtein, distance
from ..edit_distance import levenshtein, distance
def test_levenshtein():

View file

@@ -1,6 +1,6 @@
import unicodedata
from .. import seq_editops, editops
from ..edit_distance import seq_editops, editops
def test_trivial():

View file

@@ -1,11 +1,10 @@
from __future__ import division, print_function
import os
import pytest
from lxml import etree as ET
from .. import align, page_text
from ..align import align
from ..ocr_files import page_text
data_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "data")

View file

@@ -5,7 +5,8 @@ import os
import pytest
from lxml import etree as ET
from .. import distance, page_text, alto_text
from ..edit_distance import distance
from ..ocr_files import alto_text, page_text
data_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "data")

View file

@@ -3,7 +3,7 @@ import os
import pytest
from lxml import etree as ET
from .. import page_text
from ..ocr_files import page_text
@pytest.mark.parametrize(

View file

@@ -1,13 +1,12 @@
import os
import re
import lxml.etree as ET
import textwrap
import pytest
import lxml.etree as ET
from .util import working_directory
from .. import alto_namespace, alto_text, page_namespace, page_text, plain_text, text
from ..ocr_files import alto_namespace, alto_text, page_namespace, page_text, \
plain_text, text
data_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "data")