🔍 mypy: Use an almost strict mypy configuration, and fix any issues
commit 483e809691 (parent ad316aeabc)
11 changed files with 77 additions and 41 deletions
.pre-commit-config.yaml

@@ -28,6 +28,8 @@ repos:
     hooks:
       - additional_dependencies:
           - types-setuptools
+          - types-lxml
+          - numpy  # for numpy plugin
         id: mypy
 
   - repo: https://gitlab.com/vojko.pribudic/pre-commit-update
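Note: the mypy pre-commit hook runs in its own isolated environment, so type stubs and plugin packages have to be declared as `additional_dependencies`; without them, mypy treats third-party modules as untyped or cannot find stubs at all. A small sketch (mine, not part of the commit) of code whose checking depends on these extras, with types-lxml supplying the lxml annotations used below:

```python
# Sketch: with types-lxml installed in the hook's environment, mypy can check
# lxml-typed signatures like this; without it, lxml.etree is effectively Any.
from lxml import etree as ET


def root_localname(tree: ET._ElementTree) -> str:
    # QName parses "{namespace}tag"; .localname is the tag without its namespace
    return ET.QName(tree.getroot().tag).localname
```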
pyproject.toml

@@ -60,9 +60,20 @@ markers = [
 ]
 
+[tool.mypy]
+plugins = ["numpy.typing.mypy_plugin"]
+
+ignore_missing_imports = true
+
+strict = true
+
+disallow_subclassing_any = false
+# ❗ error: Class cannot subclass "Processor" (has type "Any")
+disallow_any_generics = false
+disallow_untyped_defs = false
+disallow_untyped_calls = false
+
 
 [tool.ruff]
 select = ["E", "F", "I"]
 ignore = [
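This is the "almost" in almost strict: `strict = true` switches on the whole strict bundle, and the four `disallow_* = false` lines carve out the checks the codebase cannot satisfy yet (for example, subclassing the OCR-D `Processor`, which mypy sees as `Any`). A sketch of what this profile accepts and rejects, as I read the settings; the function names are made up:

```python
def untyped(a, b):  # accepted: disallow_untyped_defs = false
    return a + b


def total(x: int, y: int) -> int:
    result: int = untyped(x, y)  # accepted: disallow_untyped_calls = false
    return result


def broken(x: int) -> str:
    return x  # rejected: incompatible return value type (got "int", expected "str")
```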
requirements-dev.txt

@@ -7,5 +7,6 @@ ruff
 pytest-ruff
 
 mypy
+types-lxml
 types-setuptools
 pytest-mypy
align.py

@@ -4,8 +4,7 @@ from math import ceil
 from typing import Optional
 
 from rapidfuzz.distance import Levenshtein
 
-from .edit_distance import grapheme_clusters
-
+from uniseg.graphemecluster import grapheme_clusters
 
 def align(t1, t2):
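`grapheme_clusters` is now imported straight from uniseg instead of being re-exported through `.edit_distance`. For reference, a usage sketch (output as I understand uniseg's segmentation):

```python
from uniseg.graphemecluster import grapheme_clusters

# "o" plus a combining diaeresis is one user-perceived character, so it
# comes back as a single grapheme cluster rather than two code points:
print(list(grapheme_clusters("Scho\u0308n")))  # ['S', 'c', 'h', 'ö', 'n']
```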
character_error_rate.py

@@ -1,5 +1,5 @@
 import unicodedata
-from typing import List, Tuple
+from typing import List, Tuple, TypeVar
 
 from multimethod import multimethod
 from uniseg.graphemecluster import grapheme_clusters
@@ -7,6 +7,8 @@ from uniseg.graphemecluster import grapheme_clusters
 from .edit_distance import distance
 from .extracted_text import ExtractedText
 
+T = TypeVar("T")
+
 
 @multimethod
 def character_error_rate_n(
@@ -34,21 +36,24 @@ def character_error_rate_n(
 def _(reference: str, compared: str) -> Tuple[float, int]:
     seq1 = list(grapheme_clusters(unicodedata.normalize("NFC", reference)))
     seq2 = list(grapheme_clusters(unicodedata.normalize("NFC", compared)))
-    return character_error_rate_n(seq1, seq2)
+    cer, n = character_error_rate_n(seq1, seq2)
+    return cer, n
 
 
 @character_error_rate_n.register
 def _(reference: ExtractedText, compared: ExtractedText) -> Tuple[float, int]:
-    return character_error_rate_n(
+    cer, n = character_error_rate_n(
         reference.grapheme_clusters, compared.grapheme_clusters
     )
+    return cer, n
 
 
-def character_error_rate(reference, compared) -> float:
+def character_error_rate(reference: T, compared: T) -> float:
     """
     Compute character error rate.
 
     :return: character error rate
     """
+    cer: float
     cer, _ = character_error_rate_n(reference, compared)
     return cer
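The unpack-then-return pattern looks redundant, but as I read it, the recursive call into the multimethod does not give mypy a precise return type under strict mode; binding the result to `cer, n` (and annotating `cer: float` in the plain wrapper) pins the values down to the declared `Tuple[float, int]`. The public behavior is unchanged; a hedged usage example, assuming the package imports as `dinglehopper`:

```python
from dinglehopper.character_error_rate import character_error_rate

# CER = grapheme-cluster edit distance / number of clusters in the reference.
# "Monatsschrift" vs. OCR output "Monatsschnft" is 2 edits over 13 clusters:
print(character_error_rate("Monatsschrift", "Monatsschnft"))  # ~0.154
```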
cli.py

@@ -1,5 +1,6 @@
 import os
 from collections import Counter
+from typing import List
 
 import click
 from jinja2 import Environment, FileSystemLoader
@@ -76,7 +77,7 @@ def gen_diff_report(
         if o is not None:
             o_pos += len(o)
 
-    found_differences = dict(Counter(elem for elem in found_differences))
+    counted_differences = dict(Counter(elem for elem in found_differences))
 
     return (
         """
@@ -87,7 +88,7 @@ def gen_diff_report(
         """.format(
             gtx, ocrx
         ),
-        found_differences,
+        counted_differences,
     )
 
 
@@ -113,7 +114,7 @@ def process(
     metrics: bool = True,
     differences: bool = False,
     textequiv_level: str = "region",
-):
+) -> None:
     """Check OCR result against GT.
 
     The @click decorators change the signature of the decorated functions, so we keep
@@ -122,8 +123,8 @@ def process(
 
     gt_text = extract(gt, textequiv_level=textequiv_level)
    ocr_text = extract(ocr, textequiv_level=textequiv_level)
-    gt_words: list = list(words_normalized(gt_text))
-    ocr_words: list = list(words_normalized(ocr_text))
+    gt_words: List[str] = list(words_normalized(gt_text))
+    ocr_words: List[str] = list(words_normalized(ocr_text))
 
     assert isinstance(gt_text, ExtractedText)
     assert isinstance(ocr_text, ExtractedText)
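The `found_differences` to `counted_differences` rename is not cosmetic: the name was first bound to a sequence and then re-bound to a dict, and strict mypy rejects re-binding a variable to a different type. A minimal reproduction, with hypothetical values:

```python
from collections import Counter
from typing import Dict, List

found_differences: List[str] = ["fl/ft", "ü/ii", "fl/ft"]

# found_differences = dict(Counter(found_differences))
#   error: Incompatible types in assignment
#   (expression has type "Dict[str, int]", variable has type "List[str]")

counted_differences: Dict[str, int] = dict(Counter(found_differences))
print(counted_differences)  # {'fl/ft': 2, 'ü/ii': 1}
```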
cli_summarize.py

@@ -1,5 +1,6 @@
 import json
 import os
+from typing import Dict
 
 import click
 from jinja2 import Environment, FileSystemLoader
@@ -13,8 +14,8 @@ def process(reports_folder, occurrences_threshold=1):
     wer_list = []
     cer_sum = 0
     wer_sum = 0
-    diff_c = {}
-    diff_w = {}
+    diff_c: Dict[str, int] = {}
+    diff_w: Dict[str, int] = {}
 
     for report in os.listdir(reports_folder):
         if report.endswith(".json"):
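An empty `{}` literal gives mypy nothing to infer element types from, so under strict settings the variable needs an explicit annotation. In isolation:

```python
from typing import Dict

diff_c: Dict[str, int] = {}  # without the annotation: "Need type annotation for diff_c"
diff_c["ſ"] = diff_c.get("ſ", 0) + 1  # reads and writes now check as str -> int
```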
edit_distance.py

@@ -9,7 +9,7 @@ from .extracted_text import ExtractedText
 
 
 @multimethod
-def distance(seq1: List[str], seq2: List[str]):
+def distance(seq1: List[str], seq2: List[str]) -> int:
     """Compute the Levenshtein edit distance between two lists of grapheme clusters.
 
     This assumes that the grapheme clusters are already normalized.
@@ -20,7 +20,7 @@ def distance(seq1: List[str], seq2: List[str]):
 
 
 @distance.register
-def _(s1: str, s2: str):
+def _(s1: str, s2: str) -> int:
     """Compute the Levenshtein edit distance between two Unicode strings
 
     Note that this is different from levenshtein() as this function knows about Unicode
@@ -33,7 +33,7 @@ def _(s1: str, s2: str):
 
 
 @distance.register
-def _(s1: ExtractedText, s2: ExtractedText):
+def _(s1: ExtractedText, s2: ExtractedText) -> int:
     return Levenshtein.distance(s1.grapheme_clusters, s2.grapheme_clusters)
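All three `distance` overloads now declare `-> int`, which matches what rapidfuzz returns. rapidfuzz's `Levenshtein.distance` accepts strings as well as sequences of hashables, which is what makes the grapheme-cluster variant above work; a quick check:

```python
from rapidfuzz.distance import Levenshtein

assert Levenshtein.distance("flour", "flower") == 2  # one substitution, one insertion
assert Levenshtein.distance(["f", "l", "o"], ["f", "l"]) == 1  # sequences work too
```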
extracted_text.py

@@ -4,7 +4,7 @@ import re
 import unicodedata
 from contextlib import suppress
 from itertools import repeat
-from typing import List, Optional
+from typing import Any, Dict, List, Optional
 
 import attr
 import numpy as np
@@ -173,10 +173,11 @@ class ExtractedText:
     normalization = attr.ib(converter=Normalization, default=Normalization.NFC_SBB)
 
     @property
-    def text(self):
+    def text(self) -> str:
         if self._text is not None:
             return self._text
         else:
+            assert self.joiner is not None and self.segments is not None
             return self.joiner.join(s.text for s in self.segments)
 
     @functools.cached_property
@@ -186,6 +187,7 @@ class ExtractedText:
         This property is cached.
         """
 
+        assert self.joiner is not None
         if len(self.joiner) > 0:
             joiner_grapheme_cluster = list(grapheme_clusters(self.joiner))
             assert len(joiner_grapheme_cluster) == 1  # see joiner's check above
@@ -203,6 +205,7 @@ class ExtractedText:
         else:
             # TODO Test with text extracted at glyph level (joiner == "")
             clusters = []
+            assert self.segments is not None
             for seg in self.segments:
                 clusters += seg.grapheme_clusters + self._joiner_grapheme_cluster
             clusters = clusters[:-1]
@@ -218,6 +221,7 @@ class ExtractedText:
         else:
             # Recurse
             segment_id_for_pos = []
+            assert self.joiner is not None and self.segments is not None
             for s in self.segments:
                 seg_ids = [s.segment_id_for_pos(i) for i in range(len(s.text))]
                 segment_id_for_pos.extend(seg_ids)
@@ -280,7 +284,7 @@ def invert_dict(d):
     return {v: k for k, v in d.items()}
 
 
-def get_textequiv_unicode(text_segment, nsmap) -> str:
+def get_textequiv_unicode(text_segment: Any, nsmap: Dict[str, str]) -> str:
     """Get the TextEquiv/Unicode text of the given PAGE text element."""
     segment_id = text_segment.attrib["id"]
     textequivs = text_segment.findall("./page:TextEquiv", namespaces=nsmap)
@@ -304,7 +308,7 @@ def get_first_textequiv(textequivs, segment_id):
     if np.any(~nan_mask):
         if np.any(nan_mask):
             log.warning("TextEquiv without index in %s.", segment_id)
-        index = np.nanargmin(indices)
+        index = int(np.nanargmin(indices))
     else:
         # try ordering by conf
         confidences = np.array([get_attr(te, "conf") for te in textequivs], dtype=float)
@@ -313,7 +317,7 @@ def get_first_textequiv(textequivs, segment_id):
             "No index attributes, use 'conf' attribute to sort TextEquiv in %s.",
             segment_id,
         )
-        index = np.nanargmax(confidences)
+        index = int(np.nanargmax(confidences))
     else:
         # fallback to first entry in case of neither index or conf present
         log.warning("No index attributes, use first TextEquiv in %s.", segment_id)
@@ -321,7 +325,7 @@ def get_first_textequiv(textequivs, segment_id):
     return textequivs[index]
 
 
-def get_attr(te, attr_name) -> float:
+def get_attr(te: Any, attr_name: str) -> float:
     """Extract the attribute for the given name.
 
     Note: currently only handles numeric values!
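Two fixes repeat through this file: `assert x is not None`, which narrows an `Optional` attribute to its concrete type on that code path (and fails loudly if the assumption breaks), and `int(...)` around numpy's argmin/argmax, whose `np.intp` result mypy does not accept where a plain `int` is declared. Both in isolation:

```python
from typing import List, Optional

import numpy as np


def join_segments(joiner: Optional[str], texts: List[str]) -> str:
    assert joiner is not None  # narrows Optional[str] to str for mypy
    return joiner.join(texts)


indices = np.array([2.0, np.nan, 1.0])
index: int = int(np.nanargmin(indices))  # np.intp -> int; nanargmin skips the NaN
print(join_segments(" ", ["a", "b"]), index)  # a b 2
```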
ocr_files.py

@@ -1,6 +1,6 @@
 import os
 import sys
-from typing import Iterator
+from typing import Dict, Iterator, Optional
 
 import chardet
 from lxml import etree as ET
@@ -10,11 +10,11 @@ from uniseg.graphemecluster import grapheme_clusters
 from .extracted_text import ExtractedText, normalize_sbb
 
 
-def alto_namespace(tree: ET.ElementTree) -> str:
+def alto_namespace(tree: ET._ElementTree) -> Optional[str]:
     """Return the ALTO namespace used in the given ElementTree.
 
     This relies on the assumption that, in any given ALTO file, the root element has the
-    local name "alto". We do not check if the files uses any valid ALTO namespace.
+    local name "alto". We do not check if the file uses any valid ALTO namespace.
     """
     root_name = ET.QName(tree.getroot().tag)
     if root_name.localname == "alto":
@@ -23,8 +23,15 @@ def alto_namespace(tree: ET.ElementTree) -> str:
     raise ValueError("Not an ALTO tree")
 
 
-def alto_extract_lines(tree: ET.ElementTree) -> Iterator[ExtractedText]:
-    nsmap = {"alto": alto_namespace(tree)}
+def alto_nsmap(tree: ET._ElementTree) -> Dict[str, str]:
+    alto_ns = alto_namespace(tree)
+    if alto_ns is None:
+        raise ValueError("Could not determine ALTO namespace")
+    return {"alto": alto_ns}
+
+
+def alto_extract_lines(tree: ET._ElementTree) -> Iterator[ExtractedText]:
+    nsmap = alto_nsmap(tree)
     for line in tree.iterfind(".//alto:TextLine", namespaces=nsmap):
         line_id = line.attrib.get("ID")
         line_text = " ".join(
@@ -37,7 +44,7 @@ def alto_extract_lines(tree: ET.ElementTree) -> Iterator[ExtractedText]:
     # FIXME hardcoded SBB normalization
 
 
-def alto_extract(tree: ET.ElementTree) -> ExtractedText:
+def alto_extract(tree: ET._ElementTree) -> ExtractedText:
     """Extract text from the given ALTO ElementTree."""
     return ExtractedText(None, list(alto_extract_lines(tree)), "\n", None, None)
@@ -98,7 +105,7 @@ def extract_texts_from_reading_order_group(group, tree, nsmap, textequiv_level):
     if ET.QName(group.tag).localname in ["OrderedGroup", "OrderedGroupIndexed"]:
         ro_children = list(group)
 
-        ro_children = filter(lambda child: "index" in child.attrib.keys(), ro_children)
+        ro_children = [child for child in ro_children if "index" in child.attrib.keys()]
         ro_children = sorted(ro_children, key=lambda child: int(child.attrib["index"]))
     elif ET.QName(group.tag).localname in ["UnorderedGroup", "UnorderedGroupIndexed"]:
         ro_children = list(group)
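Since `alto_namespace` may now return `None`, the new `alto_nsmap` helper turns that `Optional[str]` into either a concrete `Dict[str, str]` or a `ValueError` before the nsmap reaches lxml. The filter-to-list-comprehension change serves typing too: a comprehension keeps `ro_children` a list, whereas re-binding it to a `filter` object would change its type mid-function. The helper pattern on its own:

```python
from typing import Dict, Optional


def nsmap_or_raise(namespace: Optional[str]) -> Dict[str, str]:
    # Hypothetical stand-in for alto_nsmap: reject None early so callers
    # always receive a well-typed mapping.
    if namespace is None:
        raise ValueError("Could not determine ALTO namespace")
    return {"alto": namespace}  # on this path mypy knows namespace is str
```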
word_error_rate.py

@@ -1,5 +1,5 @@
 import unicodedata
-from typing import Iterable, Tuple
+from typing import Generator, Iterable, Tuple, TypeVar
 
 import uniseg.wordbreak
 from multimethod import multimethod
@@ -7,6 +7,8 @@ from rapidfuzz.distance import Levenshtein
 
 from .extracted_text import ExtractedText
 
+T = TypeVar("T")
+
 # Did we patch uniseg.wordbreak.word_break already?
 word_break_patched = False
@@ -32,7 +34,7 @@ def patch_word_break():
 
 
 @multimethod
-def words(s: str):
+def words(s: str) -> Generator[str, None, None]:
     """Extract words from a string"""
 
     global word_break_patched
@@ -61,34 +63,36 @@ def words(s: str):
 
 
 @words.register
-def _(s: ExtractedText):
-    return words(s.text)
+def _(s: ExtractedText) -> Generator[str, None, None]:
+    yield from words(s.text)
 
 
 @multimethod
-def words_normalized(s: str):
-    return words(unicodedata.normalize("NFC", s))
+def words_normalized(s: str) -> Generator[str, None, None]:
+    yield from words(unicodedata.normalize("NFC", s))
 
 
 @words_normalized.register
-def _(s: ExtractedText):
-    return words_normalized(s.text)
+def _(s: ExtractedText) -> Generator[str, None, None]:
+    yield from words_normalized(s.text)
 
 
 @multimethod
 def word_error_rate_n(reference: str, compared: str) -> Tuple[float, int]:
     reference_seq = list(words_normalized(reference))
     compared_seq = list(words_normalized(compared))
-    return word_error_rate_n(reference_seq, compared_seq)
+    wer, n = word_error_rate_n(reference_seq, compared_seq)
+    return wer, n
 
 
 @word_error_rate_n.register
 def _(reference: ExtractedText, compared: ExtractedText) -> Tuple[float, int]:
-    return word_error_rate_n(reference.text, compared.text)
+    wer, n = word_error_rate_n(reference.text, compared.text)
+    return wer, n
 
 
 @word_error_rate_n.register
-def _(reference: Iterable, compared: Iterable) -> Tuple[float, int]:
+def _(reference: Iterable[T], compared: Iterable[T]) -> Tuple[float, int]:
     reference_seq = list(reference)
     compared_seq = list(compared)
@@ -102,6 +106,7 @@ def _(reference: Iterable, compared: Iterable) -> Tuple[float, int]:
     return d / n, n
 
 
-def word_error_rate(reference, compared) -> float:
+def word_error_rate(reference: T, compared: T) -> float:
+    wer: float
     wer, _ = word_error_rate_n(reference, compared)
     return wer
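The `words`/`words_normalized` wrappers switch from `return` to `yield from` alongside their new `Generator[str, None, None]` annotations: delegating with `yield from` makes each wrapper a generator itself, so the annotation holds without returning the loosely typed result of another multimethod call, and iteration stays lazy. A standalone sketch, with `str.split` standing in for the real uniseg-based segmentation:

```python
from typing import Generator


def words(s: str) -> Generator[str, None, None]:
    yield from s.split()  # placeholder for uniseg word breaking


def words_normalized(s: str) -> Generator[str, None, None]:
    yield from words(s.casefold())  # casefold stands in for NFC normalization


print(list(words_normalized("Ein kleiner Test")))  # ['ein', 'kleiner', 'test']
```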