🔍 mypy: Use an almost strict mypy configuration, and fix any issues

pull/111/head
Mike Gerber 4 months ago
parent ad316aeabc
commit 483e809691

.pre-commit-config.yaml

@@ -28,6 +28,8 @@ repos:
     hooks:
       - additional_dependencies:
           - types-setuptools
+          - types-lxml
+          - numpy # for numpy plugin
         id: mypy
   - repo: https://gitlab.com/vojko.pribudic/pre-commit-update
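
Note: pre-commit runs mypy in its own isolated virtualenv, so type stubs
(types-setuptools, types-lxml) and numpy (needed only so the
numpy.typing.mypy_plugin configured in pyproject.toml can load) must be
listed as additional_dependencies of the hook; the project's own requirements
files are not visible there. Running pre-commit run mypy --all-files locally
should reproduce the hook's result.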

pyproject.toml

@@ -60,9 +60,20 @@ markers = [
 [tool.mypy]
+plugins = ["numpy.typing.mypy_plugin"]
+
 ignore_missing_imports = true
+
+strict = true
+
+disallow_subclassing_any = false
+# ❗ error: Class cannot subclass "Processor" (has type "Any")
+disallow_any_generics = false
+disallow_untyped_defs = false
+disallow_untyped_calls = false
 
 [tool.ruff]
 select = ["E", "F", "I"]
 ignore = [
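
Note: strict = true switches on mypy's entire strict bundle; the four
disallow_* = false lines then carve out the checks the code base cannot pass
yet. disallow_subclassing_any is the interesting one: a class imported from a
package without type stubs is typed as Any, and strict mypy refuses to
subclass it, as the recorded error for ocrd's Processor shows. A minimal
sketch of the problem (the subclass name is illustrative, not from this
commit):

    from ocrd import Processor  # ocrd ships no stubs, so mypy sees "Any"

    class MyEvaluateProcessor(Processor):  # hypothetical subclass name
        # with disallow_subclassing_any = true, mypy reports:
        # error: Class cannot subclass "Processor" (has type "Any")
        ...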

requirements-dev.txt

@@ -7,5 +7,6 @@ ruff
 pytest-ruff
 mypy
+types-lxml
 types-setuptools
 pytest-mypy

align.py

@@ -4,8 +4,7 @@ from math import ceil
 from typing import Optional
 from rapidfuzz.distance import Levenshtein
+from uniseg.graphemecluster import grapheme_clusters
-from .edit_distance import grapheme_clusters
 
 def align(t1, t2):
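
Note: the import swap is the standard fix for strict mode's
no_implicit_reexport: a name that a module merely imports is not part of its
public interface. A minimal sketch, assuming edit_distance.py imports
grapheme_clusters only for its own use:

    # edit_distance.py
    from uniseg.graphemecluster import grapheme_clusters  # not re-exported

    # align.py, before this commit; rejected under strict mode:
    from .edit_distance import grapheme_clusters
    # error: Module "edit_distance" does not explicitly export
    # attribute "grapheme_clusters"  [attr-defined]

    # the fix: import from the defining module
    from uniseg.graphemecluster import grapheme_clusters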

character_error_rate.py

@@ -1,5 +1,5 @@
 import unicodedata
-from typing import List, Tuple
+from typing import List, Tuple, TypeVar
 
 from multimethod import multimethod
 from uniseg.graphemecluster import grapheme_clusters
@@ -7,6 +7,8 @@ from uniseg.graphemecluster import grapheme_clusters
 from .edit_distance import distance
 from .extracted_text import ExtractedText
 
+T = TypeVar("T")
+
 @multimethod
 def character_error_rate_n(
@@ -34,21 +36,24 @@ def character_error_rate_n(
 def _(reference: str, compared: str) -> Tuple[float, int]:
     seq1 = list(grapheme_clusters(unicodedata.normalize("NFC", reference)))
     seq2 = list(grapheme_clusters(unicodedata.normalize("NFC", compared)))
-    return character_error_rate_n(seq1, seq2)
+    cer, n = character_error_rate_n(seq1, seq2)
+    return cer, n
 
 @character_error_rate_n.register
 def _(reference: ExtractedText, compared: ExtractedText) -> Tuple[float, int]:
-    return character_error_rate_n(
+    cer, n = character_error_rate_n(
         reference.grapheme_clusters, compared.grapheme_clusters
     )
+    return cer, n
 
-def character_error_rate(reference, compared) -> float:
+def character_error_rate(reference: T, compared: T) -> float:
     """
     Compute character error rate.
 
     :return: character error rate
     """
+    cer: float
     cer, _ = character_error_rate_n(reference, compared)
     return cer
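
Note: the unpack-then-return dance and the bare cer: float declaration work
around warn_return_any (enabled by strict): a multimethod call is typed as
Any, and returning that Any from a function declared -> Tuple[float, int] is
an error, while a tuple rebuilt from unpacked components is Tuple[Any, Any]
and passes. A minimal sketch, with untyped() standing in for any call mypy
sees as Any:

    from typing import Any, Tuple

    def untyped() -> Any:  # stand-in for a multimethod dispatch
        return 0.25, 4

    def bad() -> Tuple[float, int]:
        return untyped()  # error: Returning Any from function declared
                          # to return "Tuple[float, int]"  [no-any-return]

    def good() -> Tuple[float, int]:
        cer, n = untyped()  # cer and n are Any ...
        return cer, n       # ... but (cer, n) is Tuple[Any, Any]: accepted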

cli.py

@@ -1,5 +1,6 @@
 import os
 from collections import Counter
+from typing import List
 
 import click
 from jinja2 import Environment, FileSystemLoader
@@ -76,7 +77,7 @@ def gen_diff_report(
         if o is not None:
             o_pos += len(o)
 
-    found_differences = dict(Counter(elem for elem in found_differences))
+    counted_differences = dict(Counter(elem for elem in found_differences))
 
     return (
         """
@@ -87,7 +88,7 @@ def gen_diff_report(
         """.format(
             gtx, ocrx
         ),
-        found_differences,
+        counted_differences,
     )
@@ -113,7 +114,7 @@ def process(
     metrics: bool = True,
     differences: bool = False,
     textequiv_level: str = "region",
-):
+) -> None:
     """Check OCR result against GT.
 
     The @click decorators change the signature of the decorated functions, so we keep
@@ -122,8 +123,8 @@ def process(
     gt_text = extract(gt, textequiv_level=textequiv_level)
     ocr_text = extract(ocr, textequiv_level=textequiv_level)
 
-    gt_words: list = list(words_normalized(gt_text))
-    ocr_words: list = list(words_normalized(ocr_text))
+    gt_words: List[str] = list(words_normalized(gt_text))
+    ocr_words: List[str] = list(words_normalized(ocr_text))
 
     assert isinstance(gt_text, ExtractedText)
     assert isinstance(ocr_text, ExtractedText)
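
Note: the rename from found_differences to counted_differences is presumably
a typing fix, not cosmetics: mypy pins a variable to its first inferred type,
so the sequence of differences cannot be rebound to the dict that Counter
produces. A minimal sketch:

    from collections import Counter

    found_differences = ["a", "b", "a"]  # inferred as List[str]

    # rebinding the same name is rejected:
    # found_differences = dict(Counter(found_differences))
    # error: Incompatible types in assignment (expression has type
    # "Dict[str, int]", variable has type "List[str]")  [assignment]

    counted_differences = dict(Counter(found_differences))  # new name, fine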

cli_summarize.py

@@ -1,5 +1,6 @@
 import json
 import os
+from typing import Dict
 
 import click
 from jinja2 import Environment, FileSystemLoader
@@ -13,8 +14,8 @@ def process(reports_folder, occurrences_threshold=1):
     wer_list = []
     cer_sum = 0
     wer_sum = 0
-    diff_c = {}
-    diff_w = {}
+    diff_c: Dict[str, int] = {}
+    diff_w: Dict[str, int] = {}
 
     for report in os.listdir(reports_folder):
         if report.endswith(".json"):
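
Note: an empty {} literal gives mypy nothing to infer from, so strict mode
demands an annotation for diff_c and diff_w. A minimal sketch:

    from typing import Dict

    # diff = {}  # error: Need type annotation for "diff"
    #            # (hint: "diff: Dict[<type>, <type>] = ...")  [var-annotated]

    diff: Dict[str, int] = {}  # annotated, so later updates type-check
    diff["ä"] = diff.get("ä", 0) + 1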

edit_distance.py

@@ -9,7 +9,7 @@ from .extracted_text import ExtractedText
 @multimethod
-def distance(seq1: List[str], seq2: List[str]):
+def distance(seq1: List[str], seq2: List[str]) -> int:
     """Compute the Levenshtein edit distance between two lists of grapheme clusters.
 
     This assumes that the grapheme clusters are already normalized.
@@ -20,7 +20,7 @@ def distance(seq1: List[str], seq2: List[str]):
 @distance.register
-def _(s1: str, s2: str):
+def _(s1: str, s2: str) -> int:
     """Compute the Levenshtein edit distance between two Unicode strings
 
     Note that this is different from levenshtein() as this function knows about Unicode
@@ -33,7 +33,7 @@ def _(s1: str, s2: str):
 @distance.register
-def _(s1: ExtractedText, s2: ExtractedText):
+def _(s1: ExtractedText, s2: ExtractedText) -> int:
     return Levenshtein.distance(s1.grapheme_clusters, s2.grapheme_clusters)
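
Note: multimethod reads the parameter annotations for runtime dispatch, but
the return annotation is purely for the type checker: with -> int on every
implementation, callers get a typed result no matter which overload runs.
Hypothetical usage (import path assumed):

    from dinglehopper.edit_distance import distance

    d1: int = distance("Schlyñ", "Schlym̃")               # str overload
    d2: int = distance(["f", "o", "o"], ["f", "u", "o"])  # List[str] overload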

extracted_text.py

@@ -4,7 +4,7 @@ import re
 import unicodedata
 from contextlib import suppress
 from itertools import repeat
-from typing import List, Optional
+from typing import Any, Dict, List, Optional
 
 import attr
 import numpy as np
@@ -173,10 +173,11 @@ class ExtractedText:
     normalization = attr.ib(converter=Normalization, default=Normalization.NFC_SBB)
 
     @property
-    def text(self):
+    def text(self) -> str:
         if self._text is not None:
             return self._text
         else:
+            assert self.joiner is not None and self.segments is not None
             return self.joiner.join(s.text for s in self.segments)
 
     @functools.cached_property
@@ -186,6 +187,7 @@ class ExtractedText:
         This property is cached.
         """
+        assert self.joiner is not None
         if len(self.joiner) > 0:
             joiner_grapheme_cluster = list(grapheme_clusters(self.joiner))
             assert len(joiner_grapheme_cluster) == 1  # see joiner's check above
@@ -203,6 +205,7 @@ class ExtractedText:
         else:
             # TODO Test with text extracted at glyph level (joiner == "")
             clusters = []
+            assert self.segments is not None
             for seg in self.segments:
                 clusters += seg.grapheme_clusters + self._joiner_grapheme_cluster
             clusters = clusters[:-1]
@@ -218,6 +221,7 @@ class ExtractedText:
         else:
             # Recurse
             segment_id_for_pos = []
+            assert self.joiner is not None and self.segments is not None
             for s in self.segments:
                 seg_ids = [s.segment_id_for_pos(i) for i in range(len(s.text))]
                 segment_id_for_pos.extend(seg_ids)
@@ -280,7 +284,7 @@ def invert_dict(d):
     return {v: k for k, v in d.items()}
 
-def get_textequiv_unicode(text_segment, nsmap) -> str:
+def get_textequiv_unicode(text_segment: Any, nsmap: Dict[str, str]) -> str:
     """Get the TextEquiv/Unicode text of the given PAGE text element."""
     segment_id = text_segment.attrib["id"]
     textequivs = text_segment.findall("./page:TextEquiv", namespaces=nsmap)
@@ -304,7 +308,7 @@ def get_first_textequiv(textequivs, segment_id):
     if np.any(~nan_mask):
         if np.any(nan_mask):
             log.warning("TextEquiv without index in %s.", segment_id)
-        index = np.nanargmin(indices)
+        index = int(np.nanargmin(indices))
     else:
         # try ordering by conf
         confidences = np.array([get_attr(te, "conf") for te in textequivs], dtype=float)
@@ -313,7 +317,7 @@ def get_first_textequiv(textequivs, segment_id):
                 "No index attributes, use 'conf' attribute to sort TextEquiv in %s.",
                 segment_id,
             )
-            index = np.nanargmax(confidences)
+            index = int(np.nanargmax(confidences))
         else:
             # fallback to first entry in case of neither index or conf present
             log.warning("No index attributes, use first TextEquiv in %s.", segment_id)
@@ -321,7 +325,7 @@ def get_first_textequiv(textequivs, segment_id):
     return textequivs[index]
 
-def get_attr(te, attr_name) -> float:
+def get_attr(te: Any, attr_name: str) -> float:
     """Extract the attribute for the given name.
 
     Note: currently only handles numeric values!
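
Note: the new asserts are the usual mypy narrowing idiom. joiner and segments
are Optional attributes, and strict mode rejects e.g. self.joiner.join(...)
while self.joiner might be None; an assert narrows the type for the rest of
the block and documents the invariant that either _text or segments is set.
The int(...) wrappers serve a similar purpose: np.nanargmin/np.nanargmax
return numpy integer types, and forcing a plain int keeps index consistent
with the index = 0 fallback branch. A minimal sketch of the narrowing:

    from typing import List, Optional

    def join_segments(joiner: Optional[str], segments: Optional[List[str]]) -> str:
        # joiner.join(...) alone would fail:
        # error: Item "None" of "Optional[str]" has no attribute "join"
        assert joiner is not None and segments is not None  # narrows both
        return joiner.join(segments)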

ocr_files.py

@@ -1,6 +1,6 @@
 import os
 import sys
-from typing import Iterator
+from typing import Dict, Iterator, Optional
 
 import chardet
 from lxml import etree as ET
@@ -10,11 +10,11 @@ from uniseg.graphemecluster import grapheme_clusters
 from .extracted_text import ExtractedText, normalize_sbb
 
-def alto_namespace(tree: ET.ElementTree) -> str:
+def alto_namespace(tree: ET._ElementTree) -> Optional[str]:
     """Return the ALTO namespace used in the given ElementTree.
 
     This relies on the assumption that, in any given ALTO file, the root element has the
-    local name "alto". We do not check if the files uses any valid ALTO namespace.
+    local name "alto". We do not check if the file uses any valid ALTO namespace.
     """
     root_name = ET.QName(tree.getroot().tag)
     if root_name.localname == "alto":
@@ -23,8 +23,15 @@ def alto_namespace(tree: ET.ElementTree) -> str:
     raise ValueError("Not an ALTO tree")
 
-def alto_extract_lines(tree: ET.ElementTree) -> Iterator[ExtractedText]:
-    nsmap = {"alto": alto_namespace(tree)}
+def alto_nsmap(tree: ET._ElementTree) -> Dict[str, str]:
+    alto_ns = alto_namespace(tree)
+    if alto_ns is None:
+        raise ValueError("Could not determine ALTO namespace")
+    return {"alto": alto_ns}
+
+
+def alto_extract_lines(tree: ET._ElementTree) -> Iterator[ExtractedText]:
+    nsmap = alto_nsmap(tree)
     for line in tree.iterfind(".//alto:TextLine", namespaces=nsmap):
         line_id = line.attrib.get("ID")
         line_text = " ".join(
@@ -37,7 +44,7 @@ def alto_extract_lines(tree: ET.ElementTree) -> Iterator[ExtractedText]:
     # FIXME hardcoded SBB normalization
 
-def alto_extract(tree: ET.ElementTree) -> ExtractedText:
+def alto_extract(tree: ET._ElementTree) -> ExtractedText:
     """Extract text from the given ALTO ElementTree."""
     return ExtractedText(None, list(alto_extract_lines(tree)), "\n", None, None)
@@ -98,7 +105,7 @@ def extract_texts_from_reading_order_group(group, tree, nsmap, textequiv_level):
     if ET.QName(group.tag).localname in ["OrderedGroup", "OrderedGroupIndexed"]:
         ro_children = list(group)
 
-        ro_children = filter(lambda child: "index" in child.attrib.keys(), ro_children)
+        ro_children = [child for child in ro_children if "index" in child.attrib.keys()]
         ro_children = sorted(ro_children, key=lambda child: int(child.attrib["index"]))
     elif ET.QName(group.tag).localname in ["UnorderedGroup", "UnorderedGroupIndexed"]:
         ro_children = list(group)
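
Note: three separate strict-mode issues meet in this file. lxml's
ET.ElementTree is a factory function, not a class, so with types-lxml
installed the signatures use the stubs' ET._ElementTree class instead.
ET.QName(...).namespace is Optional[str] under those stubs, which is why
alto_namespace() now returns Optional[str] and the new alto_nsmap() helper
raises on None so it can return a clean Dict[str, str] for the namespaces=
parameters. And the filter(...) call became a list comprehension because
rebinding ro_children from a list to a filter iterator is an incompatible
re-assignment for mypy. A sketch of the annotation point (file name is
illustrative):

    from lxml import etree as ET

    # ET.ElementTree cannot be used as a type; the stubs' class can:
    def root_localname(tree: ET._ElementTree) -> str:
        return ET.QName(tree.getroot().tag).localname

    tree = ET.parse("example.alto.xml")  # typed as ET._ElementTree
    print(root_localname(tree))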

word_error_rate.py

@@ -1,5 +1,5 @@
 import unicodedata
-from typing import Iterable, Tuple
+from typing import Generator, Iterable, Tuple, TypeVar
 
 import uniseg.wordbreak
 from multimethod import multimethod
@@ -7,6 +7,8 @@ from rapidfuzz.distance import Levenshtein
 from .extracted_text import ExtractedText
 
+T = TypeVar("T")
+
 # Did we patch uniseg.wordbreak.word_break already?
 word_break_patched = False
@@ -32,7 +34,7 @@ def patch_word_break():
 @multimethod
-def words(s: str):
+def words(s: str) -> Generator[str, None, None]:
     """Extract words from a string"""
 
     global word_break_patched
@@ -61,34 +63,36 @@ def words(s: str):
 @words.register
-def _(s: ExtractedText):
-    return words(s.text)
+def _(s: ExtractedText) -> Generator[str, None, None]:
+    yield from words(s.text)
 
 @multimethod
-def words_normalized(s: str):
-    return words(unicodedata.normalize("NFC", s))
+def words_normalized(s: str) -> Generator[str, None, None]:
+    yield from words(unicodedata.normalize("NFC", s))
 
 @words_normalized.register
-def _(s: ExtractedText):
-    return words_normalized(s.text)
+def _(s: ExtractedText) -> Generator[str, None, None]:
+    yield from words_normalized(s.text)
 
 @multimethod
 def word_error_rate_n(reference: str, compared: str) -> Tuple[float, int]:
     reference_seq = list(words_normalized(reference))
     compared_seq = list(words_normalized(compared))
-    return word_error_rate_n(reference_seq, compared_seq)
+    wer, n = word_error_rate_n(reference_seq, compared_seq)
+    return wer, n
 
 @word_error_rate_n.register
 def _(reference: ExtractedText, compared: ExtractedText) -> Tuple[float, int]:
-    return word_error_rate_n(reference.text, compared.text)
+    wer, n = word_error_rate_n(reference.text, compared.text)
+    return wer, n
 
 @word_error_rate_n.register
-def _(reference: Iterable, compared: Iterable) -> Tuple[float, int]:
+def _(reference: Iterable[T], compared: Iterable[T]) -> Tuple[float, int]:
     reference_seq = list(reference)
     compared_seq = list(compared)
@@ -102,6 +106,7 @@ def _(reference: Iterable, compared: Iterable) -> Tuple[float, int]:
     return d / n, n
 
-def word_error_rate(reference, compared) -> float:
+def word_error_rate(reference: T, compared: T) -> float:
+    wer: float
     wer, _ = word_error_rate_n(reference, compared)
     return wer
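
Note: switching return words(...) to yield from words(...) makes each helper
a real generator, so the -> Generator[str, None, None] annotation is checked
rather than merely asserted, and it also avoids warn_return_any, since the
delegated multimethod call is typed as Any. A minimal sketch, with
untyped_words() standing in for a call mypy sees as Any:

    from typing import Any, Generator

    def untyped_words() -> Any:  # stand-in for a multimethod dispatch
        yield from ["foo", "bar"]

    def bad() -> Generator[str, None, None]:
        return untyped_words()  # error: Returning Any from function declared
                                # to return "Generator[str, None, None]"

    def good() -> Generator[str, None, None]:
        yield from untyped_words()  # good() is itself a generator: accepted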
