typing and asserts

This commit is contained in:
kba 2025-10-21 12:05:27 +02:00
parent 44b75eb36f
commit c6b863b13f
4 changed files with 30 additions and 23 deletions

View file

@ -9,11 +9,10 @@ document layout analysis (segmentation) with output in PAGE-XML
"""
from difflib import SequenceMatcher as sq
from PIL import Image, ImageDraw, ImageFont
import math
import os
import time
from typing import Dict, Union,List, Optional, Tuple
from typing import Dict, Type, Union,List, Optional, Tuple
import warnings
from functools import partial
from pathlib import Path
@ -32,7 +31,7 @@ from ocrd_utils import getLogger, tf_disable_interactive_logs
import statistics
try:
import torch
import torch # type: ignore
except ImportError:
torch = None
try:
@ -43,13 +42,8 @@ except ImportError:
#os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
tf_disable_interactive_logs()
import tensorflow as tf
from keras.models import load_model
tf.get_logger().setLevel("ERROR")
warnings.filterwarnings("ignore")
# use tf1 compatibility for keras backend
from tensorflow.compat.v1.keras.backend import set_session
from tensorflow.keras import layers
from tensorflow.keras.layers import StringLookup
from .model_zoo import EynollahModelZoo
from .utils.contour import (
@ -280,6 +274,7 @@ class Eynollah:
def __del__(self):
    """
    Release worker resources when the instance is garbage-collected.

    Shuts down ``self.executor`` (if it is set and truthy) and clears the
    attribute, then shuts down ``self.model_zoo``.  Both attributes are
    looked up defensively, because ``__del__`` also runs on instances whose
    ``__init__`` raised before the attributes were assigned.
    """
    # getattr with a default covers both "never assigned" and "already cleared"
    executor = getattr(self, 'executor', None)
    if executor:
        executor.shutdown()
        self.executor = None
    # model_zoo may be missing on a partially constructed instance
    model_zoo = getattr(self, 'model_zoo', None)
    if model_zoo is not None:
        model_zoo.shutdown()
@ -287,6 +282,7 @@ class Eynollah:
@property
def device(self):
# TODO why here and why only for tr?
assert torch
if torch.cuda.is_available():
self.logger.info("Using GPU acceleration")
return torch.device("cuda:0")
@ -689,8 +685,8 @@ class Eynollah:
self.img_hight_int = int(self.image.shape[0] * scale)
self.img_width_int = int(self.image.shape[1] * scale)
self.scale_y = self.img_hight_int / float(self.image.shape[0])
self.scale_x = self.img_width_int / float(self.image.shape[1])
self.scale_y: float = self.img_hight_int / float(self.image.shape[0])
self.scale_x: float = self.img_width_int / float(self.image.shape[1])
self.image = resize_image(self.image, self.img_hight_int, self.img_width_int)
@ -1755,6 +1751,7 @@ class Eynollah:
return [], [], []
self.logger.debug("enter get_slopes_and_deskew_new_light")
with share_ndarray(textline_mask_tot) as textline_mask_tot_shared:
assert self.executor
results = self.executor.map(partial(do_work_of_slopes_new_light,
textline_mask_tot_ea=textline_mask_tot_shared,
slope_deskew=slope_deskew,
@ -1771,6 +1768,7 @@ class Eynollah:
return [], [], []
self.logger.debug("enter get_slopes_and_deskew_new")
with share_ndarray(textline_mask_tot) as textline_mask_tot_shared:
assert self.executor
results = self.executor.map(partial(do_work_of_slopes_new,
textline_mask_tot_ea=textline_mask_tot_shared,
slope_deskew=slope_deskew,
@ -1791,6 +1789,7 @@ class Eynollah:
self.logger.debug("enter get_slopes_and_deskew_new_curved")
with share_ndarray(textline_mask_tot) as textline_mask_tot_shared:
with share_ndarray(mask_texts_only) as mask_texts_only_shared:
assert self.executor
results = self.executor.map(partial(do_work_of_slopes_new_curved,
textline_mask_tot_ea=textline_mask_tot_shared,
mask_texts_only=mask_texts_only_shared,

View file

@ -2,7 +2,7 @@ from dataclasses import dataclass
import json
import logging
from pathlib import Path
from typing import Dict, Literal, Optional, Tuple, List, Union
from typing import Dict, Literal, Optional, Tuple, List, Type, TypeVar, Union
from copy import deepcopy
from keras.layers import StringLookup
@ -12,7 +12,7 @@ from transformers import TrOCRProcessor, VisionEncoderDecoderModel
from eynollah.patch_encoder import PatchEncoder, Patches
SomeEynollahModel = Union[VisionEncoderDecoderModel, TrOCRProcessor, Model, List]
T = TypeVar('T')
# Dict mapping model_category to dict mapping variant (default is '') to Path
DEFAULT_MODEL_VERSIONS: Dict[str, Dict[str, str]] = {
@ -149,7 +149,10 @@ class EynollahModelZoo():
self.override_models(*model_overrides)
self._loaded: Dict[str, SomeEynollahModel] = {}
def override_models(self, *model_overrides: Tuple[str, str, str]):
def override_models(
self,
*model_overrides: Tuple[str, str, str],
):
"""
Override the default model versions
"""
@ -235,10 +238,13 @@ class EynollahModelZoo():
self._loaded[model_category] = model
return model # type: ignore
def get(self, model_category: str, model_type: "Optional[Type[T]]" = None) -> "T":
    """
    Return a model that was previously stored in ``self._loaded``.

    Args:
        model_category: key under which the model was stored.
        model_type: optional expected class of the model; when given, the
            stored object is checked against it with ``isinstance``.

    Returns:
        The object stored in ``self._loaded`` for ``model_category``.

    Raises:
        ValueError: if ``model_category`` is not present in ``self._loaded``.
        AssertionError: if ``model_type`` is given and the stored object is
            not an instance of it.
    """
    if model_category not in self._loaded:
        raise ValueError(f'Model "{model_category}" not previously loaded with "load_model(..)"')
    ret = self._loaded[model_category]
    if model_type is not None:
        # narrows the type for static checkers and catches mismatched lookups early
        assert isinstance(ret, model_type), \
            f'Model "{model_category}" is of type {type(ret)!r}, expected {model_type!r}'
    return ret  # type: ignore # FIXME: convince typing that we're returning generic type
def _load_ocr_model(self, variant: str) -> SomeEynollahModel:
"""

View file

@ -40,8 +40,8 @@ class EynollahPlotter:
self.image_filename_stem = image_filename_stem
# XXX TODO hacky these cannot be set at init time
self.image_org = image_org
self.scale_x = scale_x
self.scale_y = scale_y
self.scale_x : float = scale_x
self.scale_y : float = scale_y
def save_plot_of_layout_main(self, text_regions_p, image_page):
if self.dir_of_layout is not None:

View file

@ -2,7 +2,7 @@
# pylint: disable=import-error
from pathlib import Path
import os.path
import xml.etree.ElementTree as ET
from typing import Optional
from .utils.xml import create_page_xml, xml_reading_order
from .utils.counter import EynollahIdCounter
@ -10,7 +10,6 @@ from ocrd_utils import getLogger
from ocrd_models.ocrd_page import (
BorderType,
CoordsType,
PcGtsType,
TextLineType,
TextEquivType,
TextRegionType,
@ -32,10 +31,10 @@ class EynollahXmlWriter:
self.curved_line = curved_line
self.textline_light = textline_light
self.pcgts = pcgts
self.scale_x = None # XXX set outside __init__
self.scale_y = None # XXX set outside __init__
self.height_org = None # XXX set outside __init__
self.width_org = None # XXX set outside __init__
self.scale_x: Optional[float] = None # XXX set outside __init__
self.scale_y: Optional[float] = None # XXX set outside __init__
self.height_org: Optional[int] = None # XXX set outside __init__
self.width_org: Optional[int] = None # XXX set outside __init__
@property
def image_filename_stem(self):
@ -135,6 +134,7 @@ class EynollahXmlWriter:
# create the file structure
pcgts = self.pcgts if self.pcgts else create_page_xml(self.image_filename, self.height_org, self.width_org)
page = pcgts.get_Page()
assert page
page.set_Border(BorderType(Coords=CoordsType(points=self.calculate_page_coords(cont_page))))
counter = EynollahIdCounter()
@ -152,6 +152,7 @@ class EynollahXmlWriter:
Coords=CoordsType(points=self.calculate_polygon_coords(region_contour, page_coord,
skip_layout_reading_order))
)
assert textregion.Coords
if conf_contours_textregions:
textregion.Coords.set_conf(conf_contours_textregions[mm])
page.add_TextRegion(textregion)
@ -168,6 +169,7 @@ class EynollahXmlWriter:
id=counter.next_region_id, type_='heading',
Coords=CoordsType(points=self.calculate_polygon_coords(region_contour, page_coord))
)
assert textregion.Coords
if conf_contours_textregions_h:
textregion.Coords.set_conf(conf_contours_textregions_h[mm])
page.add_TextRegion(textregion)