🔍 mypy: Use an almost strict mypy configuration, and fix any issues

pull/111/head
Mike Gerber 1 year ago
parent ad316aeabc
commit 483e809691

@ -28,6 +28,8 @@ repos:
hooks:
- additional_dependencies:
- types-setuptools
- types-lxml
- numpy # for numpy plugin
id: mypy
- repo: https://gitlab.com/vojko.pribudic/pre-commit-update

@ -60,9 +60,20 @@ markers = [
[tool.mypy]
plugins = ["numpy.typing.mypy_plugin"]
ignore_missing_imports = true
strict = true
disallow_subclassing_any = false
# ❗ error: Class cannot subclass "Processor" (has type "Any")
disallow_any_generics = false
disallow_untyped_defs = false
disallow_untyped_calls = false
[tool.ruff]
select = ["E", "F", "I"]
ignore = [

@ -7,5 +7,6 @@ ruff
pytest-ruff
mypy
types-lxml
types-setuptools
pytest-mypy

@ -4,8 +4,7 @@ from math import ceil
from typing import Optional
from rapidfuzz.distance import Levenshtein
from .edit_distance import grapheme_clusters
from uniseg.graphemecluster import grapheme_clusters
def align(t1, t2):

@ -1,5 +1,5 @@
import unicodedata
from typing import List, Tuple
from typing import List, Tuple, TypeVar
from multimethod import multimethod
from uniseg.graphemecluster import grapheme_clusters
@ -7,6 +7,8 @@ from uniseg.graphemecluster import grapheme_clusters
from .edit_distance import distance
from .extracted_text import ExtractedText
T = TypeVar("T")
@multimethod
def character_error_rate_n(
@ -34,21 +36,24 @@ def character_error_rate_n(
def _(reference: str, compared: str) -> Tuple[float, int]:
seq1 = list(grapheme_clusters(unicodedata.normalize("NFC", reference)))
seq2 = list(grapheme_clusters(unicodedata.normalize("NFC", compared)))
return character_error_rate_n(seq1, seq2)
cer, n = character_error_rate_n(seq1, seq2)
return cer, n
@character_error_rate_n.register
def _(reference: ExtractedText, compared: ExtractedText) -> Tuple[float, int]:
return character_error_rate_n(
cer, n = character_error_rate_n(
reference.grapheme_clusters, compared.grapheme_clusters
)
return cer, n
def character_error_rate(reference, compared) -> float:
def character_error_rate(reference: T, compared: T) -> float:
"""
Compute character error rate.
:return: character error rate
"""
cer: float
cer, _ = character_error_rate_n(reference, compared)
return cer

@ -1,5 +1,6 @@
import os
from collections import Counter
from typing import List
import click
from jinja2 import Environment, FileSystemLoader
@ -76,7 +77,7 @@ def gen_diff_report(
if o is not None:
o_pos += len(o)
found_differences = dict(Counter(elem for elem in found_differences))
counted_differences = dict(Counter(elem for elem in found_differences))
return (
"""
@ -87,7 +88,7 @@ def gen_diff_report(
""".format(
gtx, ocrx
),
found_differences,
counted_differences,
)
@ -113,7 +114,7 @@ def process(
metrics: bool = True,
differences: bool = False,
textequiv_level: str = "region",
):
) -> None:
"""Check OCR result against GT.
The @click decorators change the signature of the decorated functions, so we keep
@ -122,8 +123,8 @@ def process(
gt_text = extract(gt, textequiv_level=textequiv_level)
ocr_text = extract(ocr, textequiv_level=textequiv_level)
gt_words: list = list(words_normalized(gt_text))
ocr_words: list = list(words_normalized(ocr_text))
gt_words: List[str] = list(words_normalized(gt_text))
ocr_words: List[str] = list(words_normalized(ocr_text))
assert isinstance(gt_text, ExtractedText)
assert isinstance(ocr_text, ExtractedText)

@ -1,5 +1,6 @@
import json
import os
from typing import Dict
import click
from jinja2 import Environment, FileSystemLoader
@ -13,8 +14,8 @@ def process(reports_folder, occurrences_threshold=1):
wer_list = []
cer_sum = 0
wer_sum = 0
diff_c = {}
diff_w = {}
diff_c: Dict[str, int] = {}
diff_w: Dict[str, int] = {}
for report in os.listdir(reports_folder):
if report.endswith(".json"):

@ -9,7 +9,7 @@ from .extracted_text import ExtractedText
@multimethod
def distance(seq1: List[str], seq2: List[str]):
def distance(seq1: List[str], seq2: List[str]) -> int:
"""Compute the Levenshtein edit distance between two lists of grapheme clusters.
This assumes that the grapheme clusters are already normalized.
@ -20,7 +20,7 @@ def distance(seq1: List[str], seq2: List[str]):
@distance.register
def _(s1: str, s2: str):
def _(s1: str, s2: str) -> int:
"""Compute the Levenshtein edit distance between two Unicode strings
Note that this is different from levenshtein() as this function knows about Unicode
@ -33,7 +33,7 @@ def _(s1: str, s2: str):
@distance.register
def _(s1: ExtractedText, s2: ExtractedText):
def _(s1: ExtractedText, s2: ExtractedText) -> int:
return Levenshtein.distance(s1.grapheme_clusters, s2.grapheme_clusters)

@ -4,7 +4,7 @@ import re
import unicodedata
from contextlib import suppress
from itertools import repeat
from typing import List, Optional
from typing import Any, Dict, List, Optional
import attr
import numpy as np
@ -173,10 +173,11 @@ class ExtractedText:
normalization = attr.ib(converter=Normalization, default=Normalization.NFC_SBB)
@property
def text(self):
def text(self) -> str:
if self._text is not None:
return self._text
else:
assert self.joiner is not None and self.segments is not None
return self.joiner.join(s.text for s in self.segments)
@functools.cached_property
@ -186,6 +187,7 @@ class ExtractedText:
This property is cached.
"""
assert self.joiner is not None
if len(self.joiner) > 0:
joiner_grapheme_cluster = list(grapheme_clusters(self.joiner))
assert len(joiner_grapheme_cluster) == 1 # see joiner's check above
@ -203,6 +205,7 @@ class ExtractedText:
else:
# TODO Test with text extracted at glyph level (joiner == "")
clusters = []
assert self.segments is not None
for seg in self.segments:
clusters += seg.grapheme_clusters + self._joiner_grapheme_cluster
clusters = clusters[:-1]
@ -218,6 +221,7 @@ class ExtractedText:
else:
# Recurse
segment_id_for_pos = []
assert self.joiner is not None and self.segments is not None
for s in self.segments:
seg_ids = [s.segment_id_for_pos(i) for i in range(len(s.text))]
segment_id_for_pos.extend(seg_ids)
@ -280,7 +284,7 @@ def invert_dict(d):
return {v: k for k, v in d.items()}
def get_textequiv_unicode(text_segment, nsmap) -> str:
def get_textequiv_unicode(text_segment: Any, nsmap: Dict[str, str]) -> str:
"""Get the TextEquiv/Unicode text of the given PAGE text element."""
segment_id = text_segment.attrib["id"]
textequivs = text_segment.findall("./page:TextEquiv", namespaces=nsmap)
@ -304,7 +308,7 @@ def get_first_textequiv(textequivs, segment_id):
if np.any(~nan_mask):
if np.any(nan_mask):
log.warning("TextEquiv without index in %s.", segment_id)
index = np.nanargmin(indices)
index = int(np.nanargmin(indices))
else:
# try ordering by conf
confidences = np.array([get_attr(te, "conf") for te in textequivs], dtype=float)
@ -313,7 +317,7 @@ def get_first_textequiv(textequivs, segment_id):
"No index attributes, use 'conf' attribute to sort TextEquiv in %s.",
segment_id,
)
index = np.nanargmax(confidences)
index = int(np.nanargmax(confidences))
else:
# fallback to first entry in case of neither index or conf present
log.warning("No index attributes, use first TextEquiv in %s.", segment_id)
@ -321,7 +325,7 @@ def get_first_textequiv(textequivs, segment_id):
return textequivs[index]
def get_attr(te, attr_name) -> float:
def get_attr(te: Any, attr_name: str) -> float:
"""Extract the attribute for the given name.
Note: currently only handles numeric values!

@ -1,6 +1,6 @@
import os
import sys
from typing import Iterator
from typing import Dict, Iterator, Optional
import chardet
from lxml import etree as ET
@ -10,11 +10,11 @@ from uniseg.graphemecluster import grapheme_clusters
from .extracted_text import ExtractedText, normalize_sbb
def alto_namespace(tree: ET.ElementTree) -> str:
def alto_namespace(tree: ET._ElementTree) -> Optional[str]:
"""Return the ALTO namespace used in the given ElementTree.
This relies on the assumption that, in any given ALTO file, the root element has the
local name "alto". We do not check if the files uses any valid ALTO namespace.
local name "alto". We do not check if the file uses any valid ALTO namespace.
"""
root_name = ET.QName(tree.getroot().tag)
if root_name.localname == "alto":
@ -23,8 +23,15 @@ def alto_namespace(tree: ET.ElementTree) -> str:
raise ValueError("Not an ALTO tree")
def alto_extract_lines(tree: ET.ElementTree) -> Iterator[ExtractedText]:
nsmap = {"alto": alto_namespace(tree)}
def alto_nsmap(tree: ET._ElementTree) -> Dict[str, str]:
alto_ns = alto_namespace(tree)
if alto_ns is None:
raise ValueError("Could not determine ALTO namespace")
return {"alto": alto_ns}
def alto_extract_lines(tree: ET._ElementTree) -> Iterator[ExtractedText]:
nsmap = alto_nsmap(tree)
for line in tree.iterfind(".//alto:TextLine", namespaces=nsmap):
line_id = line.attrib.get("ID")
line_text = " ".join(
@ -37,7 +44,7 @@ def alto_extract_lines(tree: ET.ElementTree) -> Iterator[ExtractedText]:
# FIXME hardcoded SBB normalization
def alto_extract(tree: ET.ElementTree) -> ExtractedText:
def alto_extract(tree: ET._ElementTree) -> ExtractedText:
"""Extract text from the given ALTO ElementTree."""
return ExtractedText(None, list(alto_extract_lines(tree)), "\n", None, None)
@ -98,7 +105,7 @@ def extract_texts_from_reading_order_group(group, tree, nsmap, textequiv_level):
if ET.QName(group.tag).localname in ["OrderedGroup", "OrderedGroupIndexed"]:
ro_children = list(group)
ro_children = filter(lambda child: "index" in child.attrib.keys(), ro_children)
ro_children = [child for child in ro_children if "index" in child.attrib.keys()]
ro_children = sorted(ro_children, key=lambda child: int(child.attrib["index"]))
elif ET.QName(group.tag).localname in ["UnorderedGroup", "UnorderedGroupIndexed"]:
ro_children = list(group)

@ -1,5 +1,5 @@
import unicodedata
from typing import Iterable, Tuple
from typing import Generator, Iterable, Tuple, TypeVar
import uniseg.wordbreak
from multimethod import multimethod
@ -7,6 +7,8 @@ from rapidfuzz.distance import Levenshtein
from .extracted_text import ExtractedText
T = TypeVar("T")
# Did we patch uniseg.wordbreak.word_break already?
word_break_patched = False
@ -32,7 +34,7 @@ def patch_word_break():
@multimethod
def words(s: str):
def words(s: str) -> Generator[str, None, None]:
"""Extract words from a string"""
global word_break_patched
@ -61,34 +63,36 @@ def words(s: str):
@words.register
def _(s: ExtractedText):
return words(s.text)
def _(s: ExtractedText) -> Generator[str, None, None]:
yield from words(s.text)
@multimethod
def words_normalized(s: str):
return words(unicodedata.normalize("NFC", s))
def words_normalized(s: str) -> Generator[str, None, None]:
yield from words(unicodedata.normalize("NFC", s))
@words_normalized.register
def _(s: ExtractedText):
return words_normalized(s.text)
def _(s: ExtractedText) -> Generator[str, None, None]:
yield from words_normalized(s.text)
@multimethod
def word_error_rate_n(reference: str, compared: str) -> Tuple[float, int]:
reference_seq = list(words_normalized(reference))
compared_seq = list(words_normalized(compared))
return word_error_rate_n(reference_seq, compared_seq)
wer, n = word_error_rate_n(reference_seq, compared_seq)
return wer, n
@word_error_rate_n.register
def _(reference: ExtractedText, compared: ExtractedText) -> Tuple[float, int]:
return word_error_rate_n(reference.text, compared.text)
wer, n = word_error_rate_n(reference.text, compared.text)
return wer, n
@word_error_rate_n.register
def _(reference: Iterable, compared: Iterable) -> Tuple[float, int]:
def _(reference: Iterable[T], compared: Iterable[T]) -> Tuple[float, int]:
reference_seq = list(reference)
compared_seq = list(compared)
@ -102,6 +106,7 @@ def _(reference: Iterable, compared: Iterable) -> Tuple[float, int]:
return d / n, n
def word_error_rate(reference, compared) -> float:
def word_error_rate(reference: T, compared: T) -> float:
wer: float
wer, _ = word_error_rate_n(reference, compared)
return wer

Loading…
Cancel
Save