diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 3ea6e96..8c25236 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -28,6 +28,8 @@ repos: hooks: - additional_dependencies: - types-setuptools + - types-lxml + - numpy # for numpy plugin id: mypy - repo: https://gitlab.com/vojko.pribudic/pre-commit-update diff --git a/pyproject.toml b/pyproject.toml index ce32d56..05075e7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -60,9 +60,20 @@ markers = [ [tool.mypy] +plugins = ["numpy.typing.mypy_plugin"] + ignore_missing_imports = true +strict = true + +disallow_subclassing_any = false +# ❗ error: Class cannot subclass "Processor" (has type "Any") +disallow_any_generics = false +disallow_untyped_defs = false +disallow_untyped_calls = false + + [tool.ruff] select = ["E", "F", "I"] ignore = [ diff --git a/requirements-dev.txt b/requirements-dev.txt index de6003d..16ae880 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -7,5 +7,6 @@ ruff pytest-ruff mypy +types-lxml types-setuptools pytest-mypy diff --git a/src/dinglehopper/align.py b/src/dinglehopper/align.py index c5f12f7..5d1f290 100644 --- a/src/dinglehopper/align.py +++ b/src/dinglehopper/align.py @@ -4,8 +4,7 @@ from math import ceil from typing import Optional from rapidfuzz.distance import Levenshtein - -from .edit_distance import grapheme_clusters +from uniseg.graphemecluster import grapheme_clusters def align(t1, t2): diff --git a/src/dinglehopper/character_error_rate.py b/src/dinglehopper/character_error_rate.py index 35d3b07..88a88f8 100644 --- a/src/dinglehopper/character_error_rate.py +++ b/src/dinglehopper/character_error_rate.py @@ -1,5 +1,5 @@ import unicodedata -from typing import List, Tuple +from typing import List, Tuple, TypeVar from multimethod import multimethod from uniseg.graphemecluster import grapheme_clusters @@ -7,6 +7,8 @@ from uniseg.graphemecluster import grapheme_clusters from .edit_distance import distance from .extracted_text import ExtractedText +T = TypeVar("T") + @multimethod def character_error_rate_n( @@ -34,21 +36,24 @@ def character_error_rate_n( def _(reference: str, compared: str) -> Tuple[float, int]: seq1 = list(grapheme_clusters(unicodedata.normalize("NFC", reference))) seq2 = list(grapheme_clusters(unicodedata.normalize("NFC", compared))) - return character_error_rate_n(seq1, seq2) + cer, n = character_error_rate_n(seq1, seq2) + return cer, n @character_error_rate_n.register def _(reference: ExtractedText, compared: ExtractedText) -> Tuple[float, int]: - return character_error_rate_n( + cer, n = character_error_rate_n( reference.grapheme_clusters, compared.grapheme_clusters ) + return cer, n -def character_error_rate(reference, compared) -> float: +def character_error_rate(reference: T, compared: T) -> float: """ Compute character error rate. :return: character error rate """ + cer: float cer, _ = character_error_rate_n(reference, compared) return cer diff --git a/src/dinglehopper/cli.py b/src/dinglehopper/cli.py index 5d2000a..a58a2af 100644 --- a/src/dinglehopper/cli.py +++ b/src/dinglehopper/cli.py @@ -1,5 +1,6 @@ import os from collections import Counter +from typing import List import click from jinja2 import Environment, FileSystemLoader @@ -76,7 +77,7 @@ def gen_diff_report( if o is not None: o_pos += len(o) - found_differences = dict(Counter(elem for elem in found_differences)) + counted_differences = dict(Counter(elem for elem in found_differences)) return ( """ @@ -87,7 +88,7 @@ def gen_diff_report( """.format( gtx, ocrx ), - found_differences, + counted_differences, ) @@ -113,7 +114,7 @@ def process( metrics: bool = True, differences: bool = False, textequiv_level: str = "region", -): +) -> None: """Check OCR result against GT. The @click decorators change the signature of the decorated functions, so we keep @@ -122,8 +123,8 @@ def process( gt_text = extract(gt, textequiv_level=textequiv_level) ocr_text = extract(ocr, textequiv_level=textequiv_level) - gt_words: list = list(words_normalized(gt_text)) - ocr_words: list = list(words_normalized(ocr_text)) + gt_words: List[str] = list(words_normalized(gt_text)) + ocr_words: List[str] = list(words_normalized(ocr_text)) assert isinstance(gt_text, ExtractedText) assert isinstance(ocr_text, ExtractedText) diff --git a/src/dinglehopper/cli_summarize.py b/src/dinglehopper/cli_summarize.py index e0c20cb..c49911b 100644 --- a/src/dinglehopper/cli_summarize.py +++ b/src/dinglehopper/cli_summarize.py @@ -1,5 +1,6 @@ import json import os +from typing import Dict import click from jinja2 import Environment, FileSystemLoader @@ -13,8 +14,8 @@ def process(reports_folder, occurrences_threshold=1): wer_list = [] cer_sum = 0 wer_sum = 0 - diff_c = {} - diff_w = {} + diff_c: Dict[str, int] = {} + diff_w: Dict[str, int] = {} for report in os.listdir(reports_folder): if report.endswith(".json"): diff --git a/src/dinglehopper/edit_distance.py b/src/dinglehopper/edit_distance.py index ac4a847..ec564ae 100644 --- a/src/dinglehopper/edit_distance.py +++ b/src/dinglehopper/edit_distance.py @@ -9,7 +9,7 @@ from .extracted_text import ExtractedText @multimethod -def distance(seq1: List[str], seq2: List[str]): +def distance(seq1: List[str], seq2: List[str]) -> int: """Compute the Levenshtein edit distance between two lists of grapheme clusters. This assumes that the grapheme clusters are already normalized. @@ -20,7 +20,7 @@ def distance(seq1: List[str], seq2: List[str]): @distance.register -def _(s1: str, s2: str): +def _(s1: str, s2: str) -> int: """Compute the Levenshtein edit distance between two Unicode strings Note that this is different from levenshtein() as this function knows about Unicode @@ -33,7 +33,7 @@ def _(s1: str, s2: str): @distance.register -def _(s1: ExtractedText, s2: ExtractedText): +def _(s1: ExtractedText, s2: ExtractedText) -> int: return Levenshtein.distance(s1.grapheme_clusters, s2.grapheme_clusters) diff --git a/src/dinglehopper/extracted_text.py b/src/dinglehopper/extracted_text.py index af54d7c..e4b0915 100644 --- a/src/dinglehopper/extracted_text.py +++ b/src/dinglehopper/extracted_text.py @@ -4,7 +4,7 @@ import re import unicodedata from contextlib import suppress from itertools import repeat -from typing import List, Optional +from typing import Any, Dict, List, Optional import attr import numpy as np @@ -173,10 +173,11 @@ class ExtractedText: normalization = attr.ib(converter=Normalization, default=Normalization.NFC_SBB) @property - def text(self): + def text(self) -> str: if self._text is not None: return self._text else: + assert self.joiner is not None and self.segments is not None return self.joiner.join(s.text for s in self.segments) @functools.cached_property @@ -186,6 +187,7 @@ class ExtractedText: This property is cached. """ + assert self.joiner is not None if len(self.joiner) > 0: joiner_grapheme_cluster = list(grapheme_clusters(self.joiner)) assert len(joiner_grapheme_cluster) == 1 # see joiner's check above @@ -203,6 +205,7 @@ class ExtractedText: else: # TODO Test with text extracted at glyph level (joiner == "") clusters = [] + assert self.segments is not None for seg in self.segments: clusters += seg.grapheme_clusters + self._joiner_grapheme_cluster clusters = clusters[:-1] @@ -218,6 +221,7 @@ class ExtractedText: else: # Recurse segment_id_for_pos = [] + assert self.joiner is not None and self.segments is not None for s in self.segments: seg_ids = [s.segment_id_for_pos(i) for i in range(len(s.text))] segment_id_for_pos.extend(seg_ids) @@ -280,7 +284,7 @@ def invert_dict(d): return {v: k for k, v in d.items()} -def get_textequiv_unicode(text_segment, nsmap) -> str: +def get_textequiv_unicode(text_segment: Any, nsmap: Dict[str, str]) -> str: """Get the TextEquiv/Unicode text of the given PAGE text element.""" segment_id = text_segment.attrib["id"] textequivs = text_segment.findall("./page:TextEquiv", namespaces=nsmap) @@ -304,7 +308,7 @@ def get_first_textequiv(textequivs, segment_id): if np.any(~nan_mask): if np.any(nan_mask): log.warning("TextEquiv without index in %s.", segment_id) - index = np.nanargmin(indices) + index = int(np.nanargmin(indices)) else: # try ordering by conf confidences = np.array([get_attr(te, "conf") for te in textequivs], dtype=float) @@ -313,7 +317,7 @@ def get_first_textequiv(textequivs, segment_id): "No index attributes, use 'conf' attribute to sort TextEquiv in %s.", segment_id, ) - index = np.nanargmax(confidences) + index = int(np.nanargmax(confidences)) else: # fallback to first entry in case of neither index or conf present log.warning("No index attributes, use first TextEquiv in %s.", segment_id) @@ -321,7 +325,7 @@ def get_first_textequiv(textequivs, segment_id): return textequivs[index] -def get_attr(te, attr_name) -> float: +def get_attr(te: Any, attr_name: str) -> float: """Extract the attribute for the given name. Note: currently only handles numeric values! diff --git a/src/dinglehopper/ocr_files.py b/src/dinglehopper/ocr_files.py index be66719..f9bd977 100644 --- a/src/dinglehopper/ocr_files.py +++ b/src/dinglehopper/ocr_files.py @@ -1,6 +1,6 @@ import os import sys -from typing import Iterator +from typing import Dict, Iterator, Optional import chardet from lxml import etree as ET @@ -10,11 +10,11 @@ from uniseg.graphemecluster import grapheme_clusters from .extracted_text import ExtractedText, normalize_sbb -def alto_namespace(tree: ET.ElementTree) -> str: +def alto_namespace(tree: ET._ElementTree) -> Optional[str]: """Return the ALTO namespace used in the given ElementTree. This relies on the assumption that, in any given ALTO file, the root element has the - local name "alto". We do not check if the files uses any valid ALTO namespace. + local name "alto". We do not check if the file uses any valid ALTO namespace. """ root_name = ET.QName(tree.getroot().tag) if root_name.localname == "alto": @@ -23,8 +23,15 @@ def alto_namespace(tree: ET.ElementTree) -> str: raise ValueError("Not an ALTO tree") -def alto_extract_lines(tree: ET.ElementTree) -> Iterator[ExtractedText]: - nsmap = {"alto": alto_namespace(tree)} +def alto_nsmap(tree: ET._ElementTree) -> Dict[str, str]: + alto_ns = alto_namespace(tree) + if alto_ns is None: + raise ValueError("Could not determine ALTO namespace") + return {"alto": alto_ns} + + +def alto_extract_lines(tree: ET._ElementTree) -> Iterator[ExtractedText]: + nsmap = alto_nsmap(tree) for line in tree.iterfind(".//alto:TextLine", namespaces=nsmap): line_id = line.attrib.get("ID") line_text = " ".join( @@ -37,7 +44,7 @@ def alto_extract_lines(tree: ET.ElementTree) -> Iterator[ExtractedText]: # FIXME hardcoded SBB normalization -def alto_extract(tree: ET.ElementTree) -> ExtractedText: +def alto_extract(tree: ET._ElementTree) -> ExtractedText: """Extract text from the given ALTO ElementTree.""" return ExtractedText(None, list(alto_extract_lines(tree)), "\n", None, None) @@ -98,7 +105,7 @@ def extract_texts_from_reading_order_group(group, tree, nsmap, textequiv_level): if ET.QName(group.tag).localname in ["OrderedGroup", "OrderedGroupIndexed"]: ro_children = list(group) - ro_children = filter(lambda child: "index" in child.attrib.keys(), ro_children) + ro_children = [child for child in ro_children if "index" in child.attrib.keys()] ro_children = sorted(ro_children, key=lambda child: int(child.attrib["index"])) elif ET.QName(group.tag).localname in ["UnorderedGroup", "UnorderedGroupIndexed"]: ro_children = list(group) diff --git a/src/dinglehopper/word_error_rate.py b/src/dinglehopper/word_error_rate.py index afb4fe0..b6e0a3a 100644 --- a/src/dinglehopper/word_error_rate.py +++ b/src/dinglehopper/word_error_rate.py @@ -1,5 +1,5 @@ import unicodedata -from typing import Iterable, Tuple +from typing import Generator, Iterable, Tuple, TypeVar import uniseg.wordbreak from multimethod import multimethod @@ -7,6 +7,8 @@ from rapidfuzz.distance import Levenshtein from .extracted_text import ExtractedText +T = TypeVar("T") + # Did we patch uniseg.wordbreak.word_break already? word_break_patched = False @@ -32,7 +34,7 @@ def patch_word_break(): @multimethod -def words(s: str): +def words(s: str) -> Generator[str, None, None]: """Extract words from a string""" global word_break_patched @@ -61,34 +63,36 @@ def words(s: str): @words.register -def _(s: ExtractedText): - return words(s.text) +def _(s: ExtractedText) -> Generator[str, None, None]: + yield from words(s.text) @multimethod -def words_normalized(s: str): - return words(unicodedata.normalize("NFC", s)) +def words_normalized(s: str) -> Generator[str, None, None]: + yield from words(unicodedata.normalize("NFC", s)) @words_normalized.register -def _(s: ExtractedText): - return words_normalized(s.text) +def _(s: ExtractedText) -> Generator[str, None, None]: + yield from words_normalized(s.text) @multimethod def word_error_rate_n(reference: str, compared: str) -> Tuple[float, int]: reference_seq = list(words_normalized(reference)) compared_seq = list(words_normalized(compared)) - return word_error_rate_n(reference_seq, compared_seq) + wer, n = word_error_rate_n(reference_seq, compared_seq) + return wer, n @word_error_rate_n.register def _(reference: ExtractedText, compared: ExtractedText) -> Tuple[float, int]: - return word_error_rate_n(reference.text, compared.text) + wer, n = word_error_rate_n(reference.text, compared.text) + return wer, n @word_error_rate_n.register -def _(reference: Iterable, compared: Iterable) -> Tuple[float, int]: +def _(reference: Iterable[T], compared: Iterable[T]) -> Tuple[float, int]: reference_seq = list(reference) compared_seq = list(compared) @@ -102,6 +106,7 @@ def _(reference: Iterable, compared: Iterable) -> Tuple[float, int]: return d / n, n -def word_error_rate(reference, compared) -> float: +def word_error_rate(reference: T, compared: T) -> float: + wer: float wer, _ = word_error_rate_n(reference, compared) return wer