organize imports mostly

This commit is contained in:
kba 2025-10-20 19:46:07 +02:00
parent 48d1198d24
commit d609a532bf
3 changed files with 71 additions and 33 deletions

View file

@ -8,40 +8,26 @@
document layout analysis (segmentation) with output in PAGE-XML
"""
# cannot use importlib.resources until we move to 3.9+ for importlib.resources.files
import sys
if sys.version_info < (3, 10):
import importlib_resources
else:
import importlib.resources as importlib_resources
from difflib import SequenceMatcher as sq
from PIL import Image, ImageDraw, ImageFont
import math
import os
import sys
import time
from typing import Dict, Union,List, Optional, Tuple
import atexit
import warnings
from functools import partial
from pathlib import Path
from multiprocessing import cpu_count
import gc
import copy
import json
from concurrent.futures import ProcessPoolExecutor
import xml.etree.ElementTree as ET
import cv2
import numpy as np
import shapely.affinity
from scipy.signal import find_peaks
from scipy.ndimage import gaussian_filter1d
from numba import cuda
from skimage.morphology import skeletonize
from ocrd import OcrdPage
from ocrd_utils import getLogger, tf_disable_interactive_logs
import statistics
@ -53,10 +39,6 @@ try:
import matplotlib.pyplot as plt
except ImportError:
plt = None
try:
from transformers import TrOCRProcessor, VisionEncoderDecoderModel
except ImportError:
TrOCRProcessor = VisionEncoderDecoderModel = None
#os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
tf_disable_interactive_logs()
@ -290,13 +272,6 @@ class Eynollah:
if self.tr:
loadable.append(('ocr', 'tr'))
loadable.append(('ocr_tr_processor', 'tr'))
# TODO why here and why only for tr?
if torch.cuda.is_available():
self.logger.info("Using GPU acceleration")
self.device = torch.device("cuda:0")
else:
self.logger.info("Using CPU processing")
self.device = torch.device("cpu")
else:
loadable.append('ocr')
loadable.append('num_to_char')
@ -307,10 +282,16 @@ class Eynollah:
if hasattr(self, 'executor') and getattr(self, 'executor'):
self.executor.shutdown()
self.executor = None
if hasattr(self, 'models') and getattr(self, 'models'):
for model_name in list(self.models):
if self.models[model_name]:
del self.models[model_name]
self.model_zoo.shutdown()
@property
def device(self):
    """
    Torch device to run on, decided anew on every access.

    Returns ``torch.device("cuda:0")`` when ``torch.cuda.is_available()``
    reports a usable GPU and ``torch.device("cpu")`` otherwise. The choice
    is logged at INFO level each time the property is read.
    """
    # TODO why here and why only for tr?
    if torch.cuda.is_available():
        message, target = "Using GPU acceleration", "cuda:0"
    else:
        message, target = "Using CPU processing", "cpu"
    self.logger.info(message)
    return torch.device(target)
def cache_images(self, image_filename=None, image_pil=None, dpi=None):
ret = {}

View file

@ -1,3 +1,41 @@
# pyright: reportPossiblyUnboundVariable=false
from logging import getLogger
from typing import Optional
from pathlib import Path
import os
import json
import gc
import sys
import math
import cv2
import time
from keras.layers import StringLookup
from eynollah.utils.resize import resize_image
from eynollah.utils.utils_ocr import break_curved_line_into_small_pieces_and_then_merge, decode_batch_predictions, fit_text_single_line, get_contours_and_bounding_boxes, get_orientation_moments, preprocess_and_resize_image_for_ocrcnn_model, return_textlines_split_if_needed, rotate_image_with_padding
from .utils import is_image_filename
import xml.etree.ElementTree as ET
import tensorflow as tf
from keras.models import load_model
from PIL import Image, ImageDraw, ImageFont
import numpy as np
import torch
# cannot use importlib.resources until we move to 3.9+ for importlib.resources.files
if sys.version_info < (3, 10):
import importlib_resources
else:
import importlib.resources as importlib_resources
try:
from transformers import TrOCRProcessor, VisionEncoderDecoderModel
except ImportError:
TrOCRProcessor = VisionEncoderDecoderModel = None
class Eynollah_ocr:
def __init__(
self,
@ -25,6 +63,7 @@ class Eynollah_ocr:
else:
self.min_conf_value_of_textline_text = 0.3
if tr_ocr:
assert TrOCRProcessor
self.processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-printed")
self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
if self.model_name:

View file

@ -2,7 +2,6 @@ from dataclasses import dataclass
import json
import logging
from pathlib import Path
from types import MappingProxyType
from typing import Dict, Literal, Optional, Tuple, List, Union
from copy import deepcopy
@ -12,6 +11,8 @@ from transformers import TrOCRProcessor, VisionEncoderDecoderModel
from eynollah.patch_encoder import PatchEncoder, Patches
SomeEynollahModel = Union[VisionEncoderDecoderModel, TrOCRProcessor, Model]
# Dict mapping model_category to dict mapping variant (default is '') to Path
DEFAULT_MODEL_VERSIONS: Dict[str, Dict[str, str]] = {
@ -134,13 +135,14 @@ class EynollahModelZoo():
# Set up the model zoo: remember the base directory, create the logger, start from a
# private copy of DEFAULT_MODEL_VERSIONS, apply any overrides, and start with an
# empty registry of loaded models.
def __init__(
self,
# basedir is wrapped in a Path and stored as self.model_basedir
basedir: str,
# NOTE(review): this view shows two declarations of model_overrides (apparently the
# old and the new line of the diff); the `if model_overrides:` guard below matches
# the Optional variant with default None — confirm only that one is in the real file.
model_overrides: List[Tuple[str, str, str]],
model_overrides: Optional[List[Tuple[str, str, str]]]=None,
) -> None:
self.model_basedir = Path(basedir)
self.logger = logging.getLogger('eynollah.model_zoo')
# deep copy so that later changes to self.model_versions do not touch the module-level defaults
self.model_versions = deepcopy(DEFAULT_MODEL_VERSIONS)
# None or an empty list means: keep the defaults untouched
if model_overrides:
self.override_models(*model_overrides)
# registry of loaded models, keyed by a pair of strings (written as (model_category, model_variant) elsewhere in this class)
self._loaded: Dict[Tuple[str, str], SomeEynollahModel] = {}
def override_models(self, *model_overrides: Tuple[str, str, str]):
"""
@ -202,7 +204,7 @@ class EynollahModelZoo():
model_category: str,
model_variant: str = '',
model_filename: str = '',
) -> Union[VisionEncoderDecoderModel, TrOCRProcessor, Model]:
) -> SomeEynollahModel:
"""
Load any model
"""
@ -223,9 +225,16 @@ class EynollahModelZoo():
self.logger.exception(e)
model = load_model(model_path, compile=False, custom_objects={
"PatchEncoder": PatchEncoder, "Patches": Patches})
self._loaded[(model_category, model_variant)] = model
return model # type: ignore
def _load_ocr_model(self, variant: str) -> Union[VisionEncoderDecoderModel, TrOCRProcessor, Model]:
def get_model(self, model_categeory, model_variant) -> SomeEynollahModel:
    """
    Return a model that was previously loaded with ``load_model(..)``.

    The pair ``(model_categeory, model_variant)`` is looked up in
    ``self._loaded``.

    Raises:
        ValueError: if no model is registered under that pair.
    """
    needle = (model_categeory, model_variant)
    if needle not in self._loaded:
        # f-string, so that the requested category/variant shows up in the message
        raise ValueError(f'Model/variant "{needle}" not previously loaded with "load_model(..)"')
    return self._loaded[needle]
def _load_ocr_model(self, variant: str) -> SomeEynollahModel:
"""
Load OCR model
"""
@ -258,3 +267,12 @@ class EynollahModelZoo():
'versions': self.model_versions,
}, indent=2))
def shutdown(self):
    """
    Ensure that loaded models are not referenced by ``self._loaded`` anymore

    Every entry of ``self._loaded`` whose value is truthy is removed.
    Does nothing if ``self._loaded`` is missing or empty.
    """
    loaded = getattr(self, '_loaded', None)
    if not loaded:
        return
    # Iterate over a snapshot of the keys: deleting from a dict while
    # iterating over that same dict raises RuntimeError.
    for needle in list(loaded):
        if loaded[needle]:
            del loaded[needle]