mirror of
https://github.com/qurator-spk/dinglehopper.git
synced 2025-06-08 19:30:01 +02:00
Fixed some flake8 and mypy issues.
This commit is contained in:
parent
a44a3d4bf2
commit
714b569195
23 changed files with 81 additions and 88 deletions
|
@ -1,3 +0,0 @@
|
|||
from .ocr_files import *
|
||||
from .extracted_text import *
|
||||
from .align import *
|
|
@ -1,4 +1,4 @@
|
|||
from .edit_distance import *
|
||||
from .edit_distance import seq_editops
|
||||
from .normalize import chars_normalized
|
||||
|
||||
|
||||
|
|
|
@ -36,13 +36,9 @@ def gen_diff_report(gt_in, ocr_in, css_prefix, joiner, none):
|
|||
html_custom_attrs += 'data-toggle="tooltip" title="{}"'.format(id_)
|
||||
|
||||
if css_classes:
|
||||
return '<span class="{css_classes}" {html_custom_attrs}>{html_t}</span>'.format(
|
||||
css_classes=css_classes,
|
||||
html_t=html_t,
|
||||
html_custom_attrs=html_custom_attrs,
|
||||
)
|
||||
return f"<span class=\"{css_classes}\" {html_custom_attrs}>{html_t}</span>"
|
||||
else:
|
||||
return "{html_t}".format(html_t=html_t)
|
||||
return f"{html_t}"
|
||||
|
||||
if isinstance(gt_in, ExtractedText):
|
||||
if not isinstance(ocr_in, ExtractedText):
|
||||
|
|
|
@ -1,9 +1,6 @@
|
|||
import os
|
||||
|
||||
import click
|
||||
from ocrd_utils import initLogging
|
||||
|
||||
from .extracted_text import ExtractedText
|
||||
from .ocr_files import extract
|
||||
|
||||
|
||||
|
|
|
@ -12,15 +12,16 @@ from .normalize import chars_normalized
|
|||
|
||||
def levenshtein_matrix(seq1: Sequence, seq2: Sequence):
|
||||
"""Compute the matrix commonly computed to produce the Levenshtein distance.
|
||||
This is also known as the Wagner-Fischer algorithm. The matrix element at the bottom right contains the desired
|
||||
edit distance.
|
||||
This is also known as the Wagner-Fischer algorithm. The matrix element at the bottom
|
||||
right contains the desired edit distance.
|
||||
|
||||
This algorithm is implemented here because we need an implementation that can work with sequences other than
|
||||
strings, e.g. lists of grapheme clusters or lists of word strings.
|
||||
This algorithm is implemented here because we need an implementation that can work
|
||||
with sequences other than strings, e.g. lists of grapheme clusters or lists of word
|
||||
strings.
|
||||
"""
|
||||
|
||||
# Internally, we use a cached version. As the cache only works on hashable parameters, we convert the input
|
||||
# sequences to tuples to make them hashable.
|
||||
# Internally, we use a cached version. As the cache only works on hashable
|
||||
# parameters, we convert the input sequences to tuples to make them hashable.
|
||||
return _levenshtein_matrix(tuple(seq1), tuple(seq2))
|
||||
|
||||
|
||||
|
@ -28,7 +29,8 @@ def levenshtein_matrix(seq1: Sequence, seq2: Sequence):
|
|||
def _levenshtein_matrix(seq1: Tuple, seq2: Tuple):
|
||||
"""Compute the matrix commonly computed to produce the Levenshtein distance.
|
||||
|
||||
This is a LRU cached function not meant to be used directly. Use levenshtein_matrix() instead.
|
||||
This is a LRU cached function not meant to be used directly.
|
||||
Use levenshtein_matrix() instead.
|
||||
"""
|
||||
m = len(seq1)
|
||||
n = len(seq2)
|
||||
|
@ -36,7 +38,7 @@ def _levenshtein_matrix(seq1: Tuple, seq2: Tuple):
|
|||
def from_to(start, stop):
|
||||
return range(start, stop + 1, 1)
|
||||
|
||||
D = np.zeros((m + 1, n + 1), np.int)
|
||||
D = np.zeros((m + 1, n + 1), int)
|
||||
D[0, 0] = 0
|
||||
for i in from_to(1, m):
|
||||
D[i, 0] = i
|
||||
|
@ -75,8 +77,9 @@ def levenshtein_matrix_cache_clear():
|
|||
def distance(s1: str, s2: str):
|
||||
"""Compute the Levenshtein edit distance between two Unicode strings
|
||||
|
||||
Note that this is different from levenshtein() as this function knows about Unicode normalization and grapheme
|
||||
clusters. This should be the correct way to compare two Unicode strings.
|
||||
Note that this is different from levenshtein() as this function knows about Unicode
|
||||
normalization and grapheme clusters. This should be the correct way to compare two
|
||||
Unicode strings.
|
||||
"""
|
||||
seq1 = chars_normalized(s1)
|
||||
seq2 = chars_normalized(s2)
|
||||
|
@ -87,8 +90,8 @@ def seq_editops(seq1, seq2):
|
|||
"""
|
||||
Return sequence of edit operations transforming one sequence to another.
|
||||
|
||||
This aims to return the same/similar results as python-Levenshtein's editops(), just generalized to arbitrary
|
||||
sequences.
|
||||
This aims to return the same/similar results as python-Levenshtein's editops(),
|
||||
just generalized to arbitrary sequences.
|
||||
"""
|
||||
seq1 = list(seq1)
|
||||
seq2 = list(seq2)
|
||||
|
|
|
@ -1,5 +1,14 @@
|
|||
from .bag_of_chars_accuracy import *
|
||||
from .bag_of_words_accuracy import *
|
||||
from .character_accuracy import *
|
||||
from .bag_of_chars_accuracy import bag_of_chars_accuracy
|
||||
from .bag_of_words_accuracy import bag_of_words_accuracy
|
||||
from .character_accuracy import character_accuracy
|
||||
from .utils import MetricResult, Weights
|
||||
from .word_accuracy import *
|
||||
from .word_accuracy import word_accuracy
|
||||
|
||||
__all__ = [
|
||||
"bag_of_chars_accuracy",
|
||||
"bag_of_words_accuracy",
|
||||
"character_accuracy",
|
||||
"word_accuracy",
|
||||
"MetricResult",
|
||||
"Weights",
|
||||
]
|
||||
|
|
|
@ -9,9 +9,8 @@ from .utils import bag_accuracy, MetricResult, Weights
|
|||
def bag_of_chars_accuracy(
|
||||
reference: str, compared: str, weights: Weights = Weights(1, 0, 1)
|
||||
) -> MetricResult:
|
||||
reference_chars = Counter(grapheme_clusters(normalize("NFC", reference)))
|
||||
compared_chars = Counter(grapheme_clusters(normalize("NFC", compared)))
|
||||
result = bag_accuracy(reference_chars, compared_chars, weights)
|
||||
return MetricResult(
|
||||
**{**result._asdict(), "metric": bag_of_chars_accuracy.__name__}
|
||||
reference_chars: Counter = Counter(grapheme_clusters(normalize("NFC", reference)))
|
||||
compared_chars: Counter = Counter(grapheme_clusters(normalize("NFC", compared)))
|
||||
return bag_accuracy(
|
||||
reference_chars, compared_chars, weights, bag_of_chars_accuracy.__name__
|
||||
)
|
||||
|
|
|
@ -7,9 +7,8 @@ from ..normalize import words_normalized
|
|||
def bag_of_words_accuracy(
|
||||
reference: str, compared: str, weights: Weights = Weights(1, 0, 1)
|
||||
) -> MetricResult:
|
||||
reference_words = Counter(words_normalized(reference))
|
||||
compared_words = Counter(words_normalized(compared))
|
||||
result = bag_accuracy(reference_words, compared_words, weights)
|
||||
return MetricResult(
|
||||
**{**result._asdict(), "metric": bag_of_words_accuracy.__name__}
|
||||
reference_words: Counter = Counter(words_normalized(reference))
|
||||
compared_words: Counter = Counter(words_normalized(compared))
|
||||
return bag_accuracy(
|
||||
reference_words, compared_words, weights, bag_of_words_accuracy.__name__
|
||||
)
|
||||
|
|
|
@ -1,7 +1,5 @@
|
|||
from __future__ import division
|
||||
|
||||
from .utils import MetricResult, Weights
|
||||
from .. import distance
|
||||
from ..edit_distance import distance
|
||||
from ..normalize import chars_normalized
|
||||
|
||||
|
||||
|
|
|
@ -37,10 +37,7 @@ class MetricResult(NamedTuple):
|
|||
We deviate from the builtin _asdict() function by including our properties.
|
||||
"""
|
||||
return {
|
||||
**{
|
||||
key: value
|
||||
for key, value in self._asdict().items()
|
||||
},
|
||||
**{key: value for key, value in self._asdict().items()},
|
||||
"accuracy": self.accuracy,
|
||||
"error_rate": self.error_rate,
|
||||
"weights": self.weights._asdict(),
|
||||
|
@ -48,7 +45,10 @@ class MetricResult(NamedTuple):
|
|||
|
||||
|
||||
def bag_accuracy(
|
||||
reference: Counter, compared: Counter, weights: Weights
|
||||
reference: Counter,
|
||||
compared: Counter,
|
||||
weights: Weights,
|
||||
metric: str = "bag_accuracy",
|
||||
) -> MetricResult:
|
||||
"""Calculates the the weighted errors for two bags (Counter).
|
||||
|
||||
|
@ -61,6 +61,7 @@ def bag_accuracy(
|
|||
:param reference: Bag used as reference (ground truth).
|
||||
:param compared: Bag used to compare (ocr).
|
||||
:param weights: Weights/costs for editing operations.
|
||||
:param metric: Name of the (original) metric.
|
||||
:return: NamedTuple representing the results of this metric.
|
||||
"""
|
||||
n_ref = sum(reference.values())
|
||||
|
@ -77,7 +78,7 @@ def bag_accuracy(
|
|||
+ weights.replacements * replacements
|
||||
)
|
||||
return MetricResult(
|
||||
metric=bag_accuracy.__name__,
|
||||
metric=metric,
|
||||
weights=weights,
|
||||
weighted_errors=weighted_errors,
|
||||
reference_elements=n_ref,
|
||||
|
|
|
@ -1,26 +1,19 @@
|
|||
import unicodedata
|
||||
from typing import Union
|
||||
|
||||
import uniseg.wordbreak
|
||||
from uniseg.graphemecluster import grapheme_clusters
|
||||
|
||||
from .extracted_text import ExtractedText
|
||||
|
||||
|
||||
def chars_normalized(s: Union[str, ExtractedText]):
|
||||
def chars_normalized(s: str):
|
||||
"""Normalize characters in string."""
|
||||
if isinstance(s, ExtractedText):
|
||||
s = s.text
|
||||
return list(grapheme_clusters(unicodedata.normalize("NFC", s)))
|
||||
|
||||
|
||||
def words(s: Union[str, ExtractedText]):
|
||||
def words(s: str):
|
||||
"""Extract words from a string"""
|
||||
|
||||
if isinstance(s, ExtractedText):
|
||||
s = s.text
|
||||
|
||||
# Patch uniseg.wordbreak.word_break to deal with our private use characters. See also
|
||||
# Patch uniseg.wordbreak.word_break to deal with our private use characters.
|
||||
# See also
|
||||
# https://www.unicode.org/Public/UCD/latest/ucd/auxiliary/WordBreakProperty.txt
|
||||
old_word_break = uniseg.wordbreak.word_break
|
||||
|
||||
|
@ -54,8 +47,6 @@ def words(s: Union[str, ExtractedText]):
|
|||
yield word
|
||||
|
||||
|
||||
def words_normalized(s: Union[str, ExtractedText]):
|
||||
def words_normalized(s: str):
|
||||
"""Extract words from string and normalize them."""
|
||||
if isinstance(s, ExtractedText):
|
||||
s = s.text
|
||||
return words(unicodedata.normalize("NFC", s))
|
||||
|
|
|
@ -1,7 +1,4 @@
|
|||
from __future__ import division, print_function
|
||||
|
||||
from typing import Iterator
|
||||
from warnings import warn
|
||||
import sys
|
||||
|
||||
from lxml import etree as ET
|
||||
|
@ -13,7 +10,8 @@ from .extracted_text import ExtractedText, normalize_sbb
|
|||
def alto_namespace(tree: ET.ElementTree) -> str:
|
||||
"""Return the ALTO namespace used in the given ElementTree.
|
||||
|
||||
This relies on the assumption that, in any given ALTO file, the root element has the local name "alto". We do not
|
||||
This relies on the assumption that, in any given ALTO file,
|
||||
the root element has the local name "alto". We do not
|
||||
check if the files uses any valid ALTO namespace.
|
||||
"""
|
||||
root_name = ET.QName(tree.getroot().tag)
|
||||
|
@ -47,8 +45,9 @@ def alto_text(tree):
|
|||
def page_namespace(tree):
|
||||
"""Return the PAGE content namespace used in the given ElementTree.
|
||||
|
||||
This relies on the assumption that, in any given PAGE content file, the root element has the local name "PcGts". We
|
||||
do not check if the files uses any valid PAGE namespace.
|
||||
This relies on the assumption that, in any given PAGE content file,
|
||||
the root element has the local name "PcGts".
|
||||
We do not check if the files uses any valid PAGE namespace.
|
||||
"""
|
||||
root_name = ET.QName(tree.getroot().tag)
|
||||
if root_name.localname == "PcGts":
|
||||
|
@ -102,9 +101,13 @@ def extract_texts_from_reading_order_group(group, tree, nsmap, textequiv_level):
|
|||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
for ro_child in ro_children:
|
||||
if ET.QName(ro_child.tag).localname in ["OrderedGroup", "OrderedGroupIndexed", "UnorderedGroup", "UnorderedGroupIndexed"]:
|
||||
if ET.QName(ro_child.tag).localname in [
|
||||
"OrderedGroup",
|
||||
"OrderedGroupIndexed",
|
||||
"UnorderedGroup",
|
||||
"UnorderedGroupIndexed",
|
||||
]:
|
||||
regions.extend(
|
||||
extract_texts_from_reading_order_group(
|
||||
ro_child, tree, nsmap, textequiv_level
|
||||
|
|
|
@ -77,5 +77,6 @@ class OcrdDinglehopperEvaluate(Processor):
|
|||
# Clear cache between files
|
||||
levenshtein_matrix_cache_clear()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
ocrd_dinglehopper()
|
||||
|
|
|
@ -6,7 +6,8 @@ import pytest
|
|||
from lxml import etree as ET
|
||||
from uniseg.graphemecluster import grapheme_clusters
|
||||
|
||||
from .. import seq_align, ExtractedText
|
||||
from ..align import seq_align
|
||||
from ..extracted_text import ExtractedText
|
||||
|
||||
|
||||
def test_text():
|
||||
|
|
|
@ -6,8 +6,8 @@ import pytest
|
|||
from lxml import etree as ET
|
||||
from uniseg.graphemecluster import grapheme_clusters
|
||||
|
||||
from ... import page_text, alto_text
|
||||
from ...metrics import character_accuracy
|
||||
from ...ocr_files import alto_text, page_text
|
||||
|
||||
data_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "../", "data")
|
||||
|
||||
|
|
|
@ -5,9 +5,9 @@ import os
|
|||
import pytest
|
||||
from lxml import etree as ET
|
||||
|
||||
from ... import alto_text, page_text
|
||||
from ...metrics import word_accuracy
|
||||
from ...normalize import words
|
||||
from ...ocr_files import alto_text, page_text
|
||||
|
||||
data_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "../", "data")
|
||||
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
from .util import unzip
|
||||
from .. import align, seq_align, distance
|
||||
from ..align import align, seq_align
|
||||
from ..edit_distance import distance
|
||||
|
||||
|
||||
def test_left_empty():
|
||||
|
|
|
@ -1,8 +1,6 @@
|
|||
from __future__ import division, print_function
|
||||
|
||||
import unicodedata
|
||||
|
||||
from .. import levenshtein, distance
|
||||
from ..edit_distance import levenshtein, distance
|
||||
|
||||
|
||||
def test_levenshtein():
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
import unicodedata
|
||||
|
||||
from .. import seq_editops, editops
|
||||
from ..edit_distance import seq_editops, editops
|
||||
|
||||
|
||||
def test_trivial():
|
||||
|
|
|
@ -1,11 +1,10 @@
|
|||
from __future__ import division, print_function
|
||||
|
||||
import os
|
||||
|
||||
import pytest
|
||||
from lxml import etree as ET
|
||||
|
||||
from .. import align, page_text
|
||||
from ..align import align
|
||||
from ..ocr_files import page_text
|
||||
|
||||
data_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "data")
|
||||
|
||||
|
|
|
@ -5,7 +5,8 @@ import os
|
|||
import pytest
|
||||
from lxml import etree as ET
|
||||
|
||||
from .. import distance, page_text, alto_text
|
||||
from ..edit_distance import distance
|
||||
from ..ocr_files import alto_text, page_text
|
||||
|
||||
data_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "data")
|
||||
|
||||
|
|
|
@ -3,7 +3,7 @@ import os
|
|||
import pytest
|
||||
from lxml import etree as ET
|
||||
|
||||
from .. import page_text
|
||||
from ..ocr_files import page_text
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
|
|
|
@ -1,13 +1,12 @@
|
|||
import os
|
||||
import re
|
||||
|
||||
import lxml.etree as ET
|
||||
import textwrap
|
||||
|
||||
import pytest
|
||||
import lxml.etree as ET
|
||||
|
||||
from .util import working_directory
|
||||
from .. import alto_namespace, alto_text, page_namespace, page_text, plain_text, text
|
||||
from ..ocr_files import alto_namespace, alto_text, page_namespace, page_text, \
|
||||
plain_text, text
|
||||
|
||||
data_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "data")
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue