From c6b863b13f31eaa2b0dc68460e75c80230b2a0fe Mon Sep 17 00:00:00 2001 From: kba Date: Tue, 21 Oct 2025 12:05:27 +0200 Subject: [PATCH] typing and asserts --- src/eynollah/eynollah.py | 19 +++++++++---------- src/eynollah/model_zoo.py | 16 +++++++++++----- src/eynollah/plot.py | 4 ++-- src/eynollah/writer.py | 14 ++++++++------ 4 files changed, 30 insertions(+), 23 deletions(-) diff --git a/src/eynollah/eynollah.py b/src/eynollah/eynollah.py index 3582c67..6356198 100644 --- a/src/eynollah/eynollah.py +++ b/src/eynollah/eynollah.py @@ -9,11 +9,10 @@ document layout analysis (segmentation) with output in PAGE-XML """ from difflib import SequenceMatcher as sq -from PIL import Image, ImageDraw, ImageFont import math import os import time -from typing import Dict, Union,List, Optional, Tuple +from typing import Dict, Type, Union,List, Optional, Tuple import warnings from functools import partial from pathlib import Path @@ -32,7 +31,7 @@ from ocrd_utils import getLogger, tf_disable_interactive_logs import statistics try: - import torch + import torch # type: ignore except ImportError: torch = None try: @@ -43,13 +42,8 @@ except ImportError: #os.environ['CUDA_VISIBLE_DEVICES'] = '-1' tf_disable_interactive_logs() import tensorflow as tf -from keras.models import load_model tf.get_logger().setLevel("ERROR") warnings.filterwarnings("ignore") -# use tf1 compatibility for keras backend -from tensorflow.compat.v1.keras.backend import set_session -from tensorflow.keras import layers -from tensorflow.keras.layers import StringLookup from .model_zoo import EynollahModelZoo from .utils.contour import ( @@ -280,6 +274,7 @@ class Eynollah: def __del__(self): if hasattr(self, 'executor') and getattr(self, 'executor'): + assert self.executor self.executor.shutdown() self.executor = None self.model_zoo.shutdown() @@ -287,6 +282,7 @@ class Eynollah: @property def device(self): # TODO why here and why only for tr? + assert torch if torch.cuda.is_available(): self.logger.info("Using GPU acceleration") return torch.device("cuda:0") @@ -689,8 +685,8 @@ class Eynollah: self.img_hight_int = int(self.image.shape[0] * scale) self.img_width_int = int(self.image.shape[1] * scale) - self.scale_y = self.img_hight_int / float(self.image.shape[0]) - self.scale_x = self.img_width_int / float(self.image.shape[1]) + self.scale_y: float = self.img_hight_int / float(self.image.shape[0]) + self.scale_x: float = self.img_width_int / float(self.image.shape[1]) self.image = resize_image(self.image, self.img_hight_int, self.img_width_int) @@ -1755,6 +1751,7 @@ class Eynollah: return [], [], [] self.logger.debug("enter get_slopes_and_deskew_new_light") with share_ndarray(textline_mask_tot) as textline_mask_tot_shared: + assert self.executor results = self.executor.map(partial(do_work_of_slopes_new_light, textline_mask_tot_ea=textline_mask_tot_shared, slope_deskew=slope_deskew, @@ -1771,6 +1768,7 @@ class Eynollah: return [], [], [] self.logger.debug("enter get_slopes_and_deskew_new") with share_ndarray(textline_mask_tot) as textline_mask_tot_shared: + assert self.executor results = self.executor.map(partial(do_work_of_slopes_new, textline_mask_tot_ea=textline_mask_tot_shared, slope_deskew=slope_deskew, @@ -1791,6 +1789,7 @@ class Eynollah: self.logger.debug("enter get_slopes_and_deskew_new_curved") with share_ndarray(textline_mask_tot) as textline_mask_tot_shared: with share_ndarray(mask_texts_only) as mask_texts_only_shared: + assert self.executor results = self.executor.map(partial(do_work_of_slopes_new_curved, textline_mask_tot_ea=textline_mask_tot_shared, mask_texts_only=mask_texts_only_shared, diff --git a/src/eynollah/model_zoo.py b/src/eynollah/model_zoo.py index b92d4f1..7f90bc0 100644 --- a/src/eynollah/model_zoo.py +++ b/src/eynollah/model_zoo.py @@ -2,7 +2,7 @@ from dataclasses import dataclass import json import logging from pathlib import Path -from typing import Dict, Literal, Optional, Tuple, List, Union +from typing import Dict, Literal, Optional, Tuple, List, Type, TypeVar, Union from copy import deepcopy from keras.layers import StringLookup @@ -12,7 +12,7 @@ from transformers import TrOCRProcessor, VisionEncoderDecoderModel from eynollah.patch_encoder import PatchEncoder, Patches SomeEynollahModel = Union[VisionEncoderDecoderModel, TrOCRProcessor, Model, List] - +T = TypeVar('T') # Dict mapping model_category to dict mapping variant (default is '') to Path DEFAULT_MODEL_VERSIONS: Dict[str, Dict[str, str]] = { @@ -149,7 +149,10 @@ class EynollahModelZoo(): self.override_models(*model_overrides) self._loaded: Dict[str, SomeEynollahModel] = {} - def override_models(self, *model_overrides: Tuple[str, str, str]): + def override_models( + self, + *model_overrides: Tuple[str, str, str], + ): """ Override the default model versions """ @@ -235,10 +238,13 @@ class EynollahModelZoo(): self._loaded[model_category] = model return model # type: ignore - def get(self, model_category) -> SomeEynollahModel: + def get(self, model_category: str, model_type: Optional[Type[T]]=None) -> T: if model_category not in self._loaded: raise ValueError(f'Model "{model_category} not previously loaded with "load_model(..)"') - return self._loaded[model_category] + ret = self._loaded[model_category] + if model_type: + assert isinstance(ret, model_type) + return ret # type: ignore # FIXME: convince typing that we're returning generic type def _load_ocr_model(self, variant: str) -> SomeEynollahModel: """ diff --git a/src/eynollah/plot.py b/src/eynollah/plot.py index c026e94..b1b2359 100644 --- a/src/eynollah/plot.py +++ b/src/eynollah/plot.py @@ -40,8 +40,8 @@ class EynollahPlotter: self.image_filename_stem = image_filename_stem # XXX TODO hacky these cannot be set at init time self.image_org = image_org - self.scale_x = scale_x - self.scale_y = scale_y + self.scale_x : float = scale_x + self.scale_y : float = scale_y def save_plot_of_layout_main(self, text_regions_p, image_page): if self.dir_of_layout is not None: diff --git a/src/eynollah/writer.py b/src/eynollah/writer.py index 9c3456a..a0ec077 100644 --- a/src/eynollah/writer.py +++ b/src/eynollah/writer.py @@ -2,7 +2,7 @@ # pylint: disable=import-error from pathlib import Path import os.path -import xml.etree.ElementTree as ET +from typing import Optional from .utils.xml import create_page_xml, xml_reading_order from .utils.counter import EynollahIdCounter @@ -10,7 +10,6 @@ from ocrd_utils import getLogger from ocrd_models.ocrd_page import ( BorderType, CoordsType, - PcGtsType, TextLineType, TextEquivType, TextRegionType, @@ -32,10 +31,10 @@ class EynollahXmlWriter: self.curved_line = curved_line self.textline_light = textline_light self.pcgts = pcgts - self.scale_x = None # XXX set outside __init__ - self.scale_y = None # XXX set outside __init__ - self.height_org = None # XXX set outside __init__ - self.width_org = None # XXX set outside __init__ + self.scale_x: Optional[float] = None # XXX set outside __init__ + self.scale_y: Optional[float] = None # XXX set outside __init__ + self.height_org: Optional[int] = None # XXX set outside __init__ + self.width_org: Optional[int] = None # XXX set outside __init__ @property def image_filename_stem(self): @@ -135,6 +134,7 @@ class EynollahXmlWriter: # create the file structure pcgts = self.pcgts if self.pcgts else create_page_xml(self.image_filename, self.height_org, self.width_org) page = pcgts.get_Page() + assert page page.set_Border(BorderType(Coords=CoordsType(points=self.calculate_page_coords(cont_page)))) counter = EynollahIdCounter() @@ -152,6 +152,7 @@ class EynollahXmlWriter: Coords=CoordsType(points=self.calculate_polygon_coords(region_contour, page_coord, skip_layout_reading_order)) ) + assert textregion.Coords if conf_contours_textregions: textregion.Coords.set_conf(conf_contours_textregions[mm]) page.add_TextRegion(textregion) @@ -168,6 +169,7 @@ class EynollahXmlWriter: id=counter.next_region_id, type_='heading', Coords=CoordsType(points=self.calculate_polygon_coords(region_contour, page_coord)) ) + assert textregion.Coords if conf_contours_textregions_h: textregion.Coords.set_conf(conf_contours_textregions_h[mm]) page.add_TextRegion(textregion)