organize imports mostly

This commit is contained in:
kba 2025-10-20 19:46:07 +02:00
parent 48d1198d24
commit d609a532bf
3 changed files with 71 additions and 33 deletions

View file

@ -8,40 +8,26 @@
document layout analysis (segmentation) with output in PAGE-XML document layout analysis (segmentation) with output in PAGE-XML
""" """
# cannot use importlib.resources until we move to 3.9+ forimportlib.resources.files
import sys
if sys.version_info < (3, 10):
import importlib_resources
else:
import importlib.resources as importlib_resources
from difflib import SequenceMatcher as sq from difflib import SequenceMatcher as sq
from PIL import Image, ImageDraw, ImageFont from PIL import Image, ImageDraw, ImageFont
import math import math
import os import os
import sys
import time import time
from typing import Dict, Union,List, Optional, Tuple from typing import Dict, Union,List, Optional, Tuple
import atexit
import warnings import warnings
from functools import partial from functools import partial
from pathlib import Path from pathlib import Path
from multiprocessing import cpu_count from multiprocessing import cpu_count
import gc import gc
import copy import copy
import json
from concurrent.futures import ProcessPoolExecutor from concurrent.futures import ProcessPoolExecutor
import xml.etree.ElementTree as ET
import cv2 import cv2
import numpy as np import numpy as np
import shapely.affinity import shapely.affinity
from scipy.signal import find_peaks from scipy.signal import find_peaks
from scipy.ndimage import gaussian_filter1d from scipy.ndimage import gaussian_filter1d
from numba import cuda
from skimage.morphology import skeletonize from skimage.morphology import skeletonize
from ocrd import OcrdPage
from ocrd_utils import getLogger, tf_disable_interactive_logs from ocrd_utils import getLogger, tf_disable_interactive_logs
import statistics import statistics
@ -53,10 +39,6 @@ try:
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
except ImportError: except ImportError:
plt = None plt = None
try:
from transformers import TrOCRProcessor, VisionEncoderDecoderModel
except ImportError:
TrOCRProcessor = VisionEncoderDecoderModel = None
#os.environ['CUDA_VISIBLE_DEVICES'] = '-1' #os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
tf_disable_interactive_logs() tf_disable_interactive_logs()
@ -290,13 +272,6 @@ class Eynollah:
if self.tr: if self.tr:
loadable.append(('ocr', 'tr')) loadable.append(('ocr', 'tr'))
loadable.append(('ocr_tr_processor', 'tr')) loadable.append(('ocr_tr_processor', 'tr'))
# TODO why here and why only for tr?
if torch.cuda.is_available():
self.logger.info("Using GPU acceleration")
self.device = torch.device("cuda:0")
else:
self.logger.info("Using CPU processing")
self.device = torch.device("cpu")
else: else:
loadable.append('ocr') loadable.append('ocr')
loadable.append('num_to_char') loadable.append('num_to_char')
@ -307,10 +282,16 @@ class Eynollah:
if hasattr(self, 'executor') and getattr(self, 'executor'): if hasattr(self, 'executor') and getattr(self, 'executor'):
self.executor.shutdown() self.executor.shutdown()
self.executor = None self.executor = None
if hasattr(self, 'models') and getattr(self, 'models'): self.model_zoo.shutdown()
for model_name in list(self.models):
if self.models[model_name]: @property
del self.models[model_name] def device(self):
# TODO why here and why only for tr?
if torch.cuda.is_available():
self.logger.info("Using GPU acceleration")
return torch.device("cuda:0")
self.logger.info("Using CPU processing")
return torch.device("cpu")
def cache_images(self, image_filename=None, image_pil=None, dpi=None): def cache_images(self, image_filename=None, image_pil=None, dpi=None):
ret = {} ret = {}

View file

@ -1,3 +1,41 @@
# pyright: reportPossiblyUnboundVariable=false
from logging import getLogger
from typing import Optional
from pathlib import Path
import os
import json
import gc
import sys
import math
import cv2
import time
from keras.layers import StringLookup
from eynollah.utils.resize import resize_image
from eynollah.utils.utils_ocr import break_curved_line_into_small_pieces_and_then_merge, decode_batch_predictions, fit_text_single_line, get_contours_and_bounding_boxes, get_orientation_moments, preprocess_and_resize_image_for_ocrcnn_model, return_textlines_split_if_needed, rotate_image_with_padding
from .utils import is_image_filename
import xml.etree.ElementTree as ET
import tensorflow as tf
from keras.models import load_model
from PIL import Image, ImageDraw, ImageFont
import numpy as np
import torch
# cannot use importlib.resources until we move to 3.9+ forimportlib.resources.files
if sys.version_info < (3, 10):
import importlib_resources
else:
import importlib.resources as importlib_resources
try:
from transformers import TrOCRProcessor, VisionEncoderDecoderModel
except ImportError:
TrOCRProcessor = VisionEncoderDecoderModel = None
class Eynollah_ocr: class Eynollah_ocr:
def __init__( def __init__(
self, self,
@ -25,6 +63,7 @@ class Eynollah_ocr:
else: else:
self.min_conf_value_of_textline_text = 0.3 self.min_conf_value_of_textline_text = 0.3
if tr_ocr: if tr_ocr:
assert TrOCRProcessor
self.processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-printed") self.processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-printed")
self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
if self.model_name: if self.model_name:

View file

@ -2,7 +2,6 @@ from dataclasses import dataclass
import json import json
import logging import logging
from pathlib import Path from pathlib import Path
from types import MappingProxyType
from typing import Dict, Literal, Optional, Tuple, List, Union from typing import Dict, Literal, Optional, Tuple, List, Union
from copy import deepcopy from copy import deepcopy
@ -12,6 +11,8 @@ from transformers import TrOCRProcessor, VisionEncoderDecoderModel
from eynollah.patch_encoder import PatchEncoder, Patches from eynollah.patch_encoder import PatchEncoder, Patches
SomeEynollahModel = Union[VisionEncoderDecoderModel, TrOCRProcessor, Model]
# Dict mapping model_category to dict mapping variant (default is '') to Path # Dict mapping model_category to dict mapping variant (default is '') to Path
DEFAULT_MODEL_VERSIONS: Dict[str, Dict[str, str]] = { DEFAULT_MODEL_VERSIONS: Dict[str, Dict[str, str]] = {
@ -134,13 +135,14 @@ class EynollahModelZoo():
def __init__( def __init__(
self, self,
basedir: str, basedir: str,
model_overrides: List[Tuple[str, str, str]], model_overrides: Optional[List[Tuple[str, str, str]]]=None,
) -> None: ) -> None:
self.model_basedir = Path(basedir) self.model_basedir = Path(basedir)
self.logger = logging.getLogger('eynollah.model_zoo') self.logger = logging.getLogger('eynollah.model_zoo')
self.model_versions = deepcopy(DEFAULT_MODEL_VERSIONS) self.model_versions = deepcopy(DEFAULT_MODEL_VERSIONS)
if model_overrides: if model_overrides:
self.override_models(*model_overrides) self.override_models(*model_overrides)
self._loaded: Dict[Tuple[str, str], SomeEynollahModel] = {}
def override_models(self, *model_overrides: Tuple[str, str, str]): def override_models(self, *model_overrides: Tuple[str, str, str]):
""" """
@ -202,7 +204,7 @@ class EynollahModelZoo():
model_category: str, model_category: str,
model_variant: str = '', model_variant: str = '',
model_filename: str = '', model_filename: str = '',
) -> Union[VisionEncoderDecoderModel, TrOCRProcessor, Model]: ) -> SomeEynollahModel:
""" """
Load any model Load any model
""" """
@ -223,9 +225,16 @@ class EynollahModelZoo():
self.logger.exception(e) self.logger.exception(e)
model = load_model(model_path, compile=False, custom_objects={ model = load_model(model_path, compile=False, custom_objects={
"PatchEncoder": PatchEncoder, "Patches": Patches}) "PatchEncoder": PatchEncoder, "Patches": Patches})
self._loaded[(model_category, model_variant)] = model
return model # type: ignore return model # type: ignore
def _load_ocr_model(self, variant: str) -> Union[VisionEncoderDecoderModel, TrOCRProcessor, Model]: def get_model(self, model_categeory, model_variant) -> SomeEynollahModel:
needle = (model_categeory, model_variant)
if needle not in self._loaded:
raise ValueError('Model/variant "{needle} not previously loaded with "load_model(..)"')
return self._loaded[needle]
def _load_ocr_model(self, variant: str) -> SomeEynollahModel:
""" """
Load OCR model Load OCR model
""" """
@ -258,3 +267,12 @@ class EynollahModelZoo():
'versions': self.model_versions, 'versions': self.model_versions,
}, indent=2)) }, indent=2))
def shutdown(self):
"""
Ensure that a loaded models is not referenced by ``self._loaded`` anymore
"""
if hasattr(self, '_loaded') and getattr(self, '_loaded'):
for needle in self._loaded:
if self._loaded[needle]:
del self._loaded[needle]