From d609a532bf1714b9b41c9b5a7454bce0c44c434f Mon Sep 17 00:00:00 2001 From: kba Date: Mon, 20 Oct 2025 19:46:07 +0200 Subject: [PATCH] organize imports mostly --- src/eynollah/eynollah.py | 39 +++++++++--------------------------- src/eynollah/eynollah_ocr.py | 39 ++++++++++++++++++++++++++++++++++++ src/eynollah/model_zoo.py | 26 ++++++++++++++++++++---- 3 files changed, 71 insertions(+), 33 deletions(-) diff --git a/src/eynollah/eynollah.py b/src/eynollah/eynollah.py index 0a7b660..f281ac6 100644 --- a/src/eynollah/eynollah.py +++ b/src/eynollah/eynollah.py @@ -8,40 +8,26 @@ document layout analysis (segmentation) with output in PAGE-XML """ -# cannot use importlib.resources until we move to 3.9+ forimportlib.resources.files -import sys - -if sys.version_info < (3, 10): - import importlib_resources -else: - import importlib.resources as importlib_resources - from difflib import SequenceMatcher as sq from PIL import Image, ImageDraw, ImageFont import math import os -import sys import time from typing import Dict, Union,List, Optional, Tuple -import atexit import warnings from functools import partial from pathlib import Path from multiprocessing import cpu_count import gc import copy -import json from concurrent.futures import ProcessPoolExecutor -import xml.etree.ElementTree as ET import cv2 import numpy as np import shapely.affinity from scipy.signal import find_peaks from scipy.ndimage import gaussian_filter1d -from numba import cuda from skimage.morphology import skeletonize -from ocrd import OcrdPage from ocrd_utils import getLogger, tf_disable_interactive_logs import statistics @@ -53,10 +39,6 @@ try: import matplotlib.pyplot as plt except ImportError: plt = None -try: - from transformers import TrOCRProcessor, VisionEncoderDecoderModel -except ImportError: - TrOCRProcessor = VisionEncoderDecoderModel = None #os.environ['CUDA_VISIBLE_DEVICES'] = '-1' tf_disable_interactive_logs() @@ -290,13 +272,6 @@ class Eynollah: if self.tr: loadable.append(('ocr', 'tr')) 
loadable.append(('ocr_tr_processor', 'tr')) - # TODO why here and why only for tr? - if torch.cuda.is_available(): - self.logger.info("Using GPU acceleration") - self.device = torch.device("cuda:0") - else: - self.logger.info("Using CPU processing") - self.device = torch.device("cpu") else: loadable.append('ocr') loadable.append('num_to_char') @@ -307,10 +282,16 @@ class Eynollah: if hasattr(self, 'executor') and getattr(self, 'executor'): self.executor.shutdown() self.executor = None - if hasattr(self, 'models') and getattr(self, 'models'): - for model_name in list(self.models): - if self.models[model_name]: - del self.models[model_name] + self.model_zoo.shutdown() + + @property + def device(self): + # TODO why here and why only for tr? + if torch.cuda.is_available(): + self.logger.info("Using GPU acceleration") + return torch.device("cuda:0") + self.logger.info("Using CPU processing") + return torch.device("cpu") def cache_images(self, image_filename=None, image_pil=None, dpi=None): ret = {} diff --git a/src/eynollah/eynollah_ocr.py b/src/eynollah/eynollah_ocr.py index 19825c5..6adea55 100644 --- a/src/eynollah/eynollah_ocr.py +++ b/src/eynollah/eynollah_ocr.py @@ -1,3 +1,41 @@ +# pyright: reportPossiblyUnboundVariable=false + +from logging import getLogger +from typing import Optional +from pathlib import Path +import os +import json +import gc +import sys +import math +import cv2 +import time + +from keras.layers import StringLookup + +from eynollah.utils.resize import resize_image +from eynollah.utils.utils_ocr import break_curved_line_into_small_pieces_and_then_merge, decode_batch_predictions, fit_text_single_line, get_contours_and_bounding_boxes, get_orientation_moments, preprocess_and_resize_image_for_ocrcnn_model, return_textlines_split_if_needed, rotate_image_with_padding + +from .utils import is_image_filename + +import xml.etree.ElementTree as ET +import tensorflow as tf +from keras.models import load_model +from PIL import Image, ImageDraw, ImageFont 
+import numpy as np +import torch + +# cannot use importlib.resources until we move to 3.9+ for importlib.resources.files +if sys.version_info < (3, 10): + import importlib_resources +else: + import importlib.resources as importlib_resources + +try: + from transformers import TrOCRProcessor, VisionEncoderDecoderModel +except ImportError: + TrOCRProcessor = VisionEncoderDecoderModel = None + class Eynollah_ocr: def __init__( self, @@ -25,6 +63,7 @@ class Eynollah_ocr: else: self.min_conf_value_of_textline_text = 0.3 if tr_ocr: + assert TrOCRProcessor self.processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-printed") self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") if self.model_name: diff --git a/src/eynollah/model_zoo.py b/src/eynollah/model_zoo.py index b332b4a..ee8b6b0 100644 --- a/src/eynollah/model_zoo.py +++ b/src/eynollah/model_zoo.py @@ -2,7 +2,6 @@ from dataclasses import dataclass import json import logging from pathlib import Path -from types import MappingProxyType from typing import Dict, Literal, Optional, Tuple, List, Union from copy import deepcopy @@ -12,6 +11,8 @@ from transformers import TrOCRProcessor, VisionEncoderDecoderModel from eynollah.patch_encoder import PatchEncoder, Patches +SomeEynollahModel = Union[VisionEncoderDecoderModel, TrOCRProcessor, Model] + # Dict mapping model_category to dict mapping variant (default is '') to Path DEFAULT_MODEL_VERSIONS: Dict[str, Dict[str, str]] = { @@ -134,13 +135,14 @@ class EynollahModelZoo(): def __init__( self, basedir: str, - model_overrides: List[Tuple[str, str, str]], + model_overrides: Optional[List[Tuple[str, str, str]]]=None, ) -> None: self.model_basedir = Path(basedir) self.logger = logging.getLogger('eynollah.model_zoo') self.model_versions = deepcopy(DEFAULT_MODEL_VERSIONS) if model_overrides: self.override_models(*model_overrides) + self._loaded: Dict[Tuple[str, str], SomeEynollahModel] = {} def override_models(self, *model_overrides: Tuple[str, str, 
str]): """ """ @@ -202,7 +204,7 @@ class EynollahModelZoo(): model_category: str, model_variant: str = '', model_filename: str = '', - ) -> Union[VisionEncoderDecoderModel, TrOCRProcessor, Model]: + ) -> SomeEynollahModel: """ Load any model """ @@ -223,9 +225,16 @@ class EynollahModelZoo(): self.logger.exception(e) model = load_model(model_path, compile=False, custom_objects={ "PatchEncoder": PatchEncoder, "Patches": Patches}) + self._loaded[(model_category, model_variant)] = model return model # type: ignore - def _load_ocr_model(self, variant: str) -> Union[VisionEncoderDecoderModel, TrOCRProcessor, Model]: + def get_model(self, model_categeory, model_variant) -> SomeEynollahModel: + needle = (model_categeory, model_variant) + if needle not in self._loaded: + raise ValueError('Model/variant "{needle} not previously loaded with "load_model(..)"') + return self._loaded[needle] + + def _load_ocr_model(self, variant: str) -> SomeEynollahModel: """ Load OCR model """ @@ -258,3 +267,12 @@ class EynollahModelZoo(): 'versions': self.model_versions, }, indent=2)) + def shutdown(self): + """ + Ensure that loaded models are not referenced by ``self._loaded`` anymore + """ + if hasattr(self, '_loaded') and getattr(self, '_loaded'): + for needle in self._loaded: + if self._loaded[needle]: + del self._loaded[needle] +