mirror of
https://github.com/qurator-spk/eynollah.git
synced 2025-10-26 23:34:13 +01:00
organize imports mostly
This commit is contained in:
parent
48d1198d24
commit
d609a532bf
3 changed files with 71 additions and 33 deletions
|
|
@ -8,40 +8,26 @@
|
|||
document layout analysis (segmentation) with output in PAGE-XML
|
||||
"""
|
||||
|
||||
# cannot use importlib.resources until we move to 3.9+ forimportlib.resources.files
|
||||
import sys
|
||||
|
||||
if sys.version_info < (3, 10):
|
||||
import importlib_resources
|
||||
else:
|
||||
import importlib.resources as importlib_resources
|
||||
|
||||
from difflib import SequenceMatcher as sq
|
||||
from PIL import Image, ImageDraw, ImageFont
|
||||
import math
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
from typing import Dict, Union,List, Optional, Tuple
|
||||
import atexit
|
||||
import warnings
|
||||
from functools import partial
|
||||
from pathlib import Path
|
||||
from multiprocessing import cpu_count
|
||||
import gc
|
||||
import copy
|
||||
import json
|
||||
|
||||
from concurrent.futures import ProcessPoolExecutor
|
||||
import xml.etree.ElementTree as ET
|
||||
import cv2
|
||||
import numpy as np
|
||||
import shapely.affinity
|
||||
from scipy.signal import find_peaks
|
||||
from scipy.ndimage import gaussian_filter1d
|
||||
from numba import cuda
|
||||
from skimage.morphology import skeletonize
|
||||
from ocrd import OcrdPage
|
||||
from ocrd_utils import getLogger, tf_disable_interactive_logs
|
||||
import statistics
|
||||
|
||||
|
|
@ -53,10 +39,6 @@ try:
|
|||
import matplotlib.pyplot as plt
|
||||
except ImportError:
|
||||
plt = None
|
||||
try:
|
||||
from transformers import TrOCRProcessor, VisionEncoderDecoderModel
|
||||
except ImportError:
|
||||
TrOCRProcessor = VisionEncoderDecoderModel = None
|
||||
|
||||
#os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
|
||||
tf_disable_interactive_logs()
|
||||
|
|
@ -290,13 +272,6 @@ class Eynollah:
|
|||
if self.tr:
|
||||
loadable.append(('ocr', 'tr'))
|
||||
loadable.append(('ocr_tr_processor', 'tr'))
|
||||
# TODO why here and why only for tr?
|
||||
if torch.cuda.is_available():
|
||||
self.logger.info("Using GPU acceleration")
|
||||
self.device = torch.device("cuda:0")
|
||||
else:
|
||||
self.logger.info("Using CPU processing")
|
||||
self.device = torch.device("cpu")
|
||||
else:
|
||||
loadable.append('ocr')
|
||||
loadable.append('num_to_char')
|
||||
|
|
@ -307,10 +282,16 @@ class Eynollah:
|
|||
if hasattr(self, 'executor') and getattr(self, 'executor'):
|
||||
self.executor.shutdown()
|
||||
self.executor = None
|
||||
if hasattr(self, 'models') and getattr(self, 'models'):
|
||||
for model_name in list(self.models):
|
||||
if self.models[model_name]:
|
||||
del self.models[model_name]
|
||||
self.model_zoo.shutdown()
|
||||
|
||||
@property
|
||||
def device(self):
|
||||
# TODO why here and why only for tr?
|
||||
if torch.cuda.is_available():
|
||||
self.logger.info("Using GPU acceleration")
|
||||
return torch.device("cuda:0")
|
||||
self.logger.info("Using CPU processing")
|
||||
return torch.device("cpu")
|
||||
|
||||
def cache_images(self, image_filename=None, image_pil=None, dpi=None):
|
||||
ret = {}
|
||||
|
|
|
|||
|
|
@ -1,3 +1,41 @@
|
|||
# pyright: reportPossiblyUnboundVariable=false
|
||||
|
||||
from logging import getLogger
|
||||
from typing import Optional
|
||||
from pathlib import Path
|
||||
import os
|
||||
import json
|
||||
import gc
|
||||
import sys
|
||||
import math
|
||||
import cv2
|
||||
import time
|
||||
|
||||
from keras.layers import StringLookup
|
||||
|
||||
from eynollah.utils.resize import resize_image
|
||||
from eynollah.utils.utils_ocr import break_curved_line_into_small_pieces_and_then_merge, decode_batch_predictions, fit_text_single_line, get_contours_and_bounding_boxes, get_orientation_moments, preprocess_and_resize_image_for_ocrcnn_model, return_textlines_split_if_needed, rotate_image_with_padding
|
||||
|
||||
from .utils import is_image_filename
|
||||
|
||||
import xml.etree.ElementTree as ET
|
||||
import tensorflow as tf
|
||||
from keras.models import load_model
|
||||
from PIL import Image, ImageDraw, ImageFont
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
# cannot use importlib.resources until we move to 3.9+ forimportlib.resources.files
|
||||
if sys.version_info < (3, 10):
|
||||
import importlib_resources
|
||||
else:
|
||||
import importlib.resources as importlib_resources
|
||||
|
||||
try:
|
||||
from transformers import TrOCRProcessor, VisionEncoderDecoderModel
|
||||
except ImportError:
|
||||
TrOCRProcessor = VisionEncoderDecoderModel = None
|
||||
|
||||
class Eynollah_ocr:
|
||||
def __init__(
|
||||
self,
|
||||
|
|
@ -25,6 +63,7 @@ class Eynollah_ocr:
|
|||
else:
|
||||
self.min_conf_value_of_textline_text = 0.3
|
||||
if tr_ocr:
|
||||
assert TrOCRProcessor
|
||||
self.processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-printed")
|
||||
self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
||||
if self.model_name:
|
||||
|
|
|
|||
|
|
@ -2,7 +2,6 @@ from dataclasses import dataclass
|
|||
import json
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from types import MappingProxyType
|
||||
from typing import Dict, Literal, Optional, Tuple, List, Union
|
||||
from copy import deepcopy
|
||||
|
||||
|
|
@ -12,6 +11,8 @@ from transformers import TrOCRProcessor, VisionEncoderDecoderModel
|
|||
|
||||
from eynollah.patch_encoder import PatchEncoder, Patches
|
||||
|
||||
SomeEynollahModel = Union[VisionEncoderDecoderModel, TrOCRProcessor, Model]
|
||||
|
||||
|
||||
# Dict mapping model_category to dict mapping variant (default is '') to Path
|
||||
DEFAULT_MODEL_VERSIONS: Dict[str, Dict[str, str]] = {
|
||||
|
|
@ -134,13 +135,14 @@ class EynollahModelZoo():
|
|||
def __init__(
|
||||
self,
|
||||
basedir: str,
|
||||
model_overrides: List[Tuple[str, str, str]],
|
||||
model_overrides: Optional[List[Tuple[str, str, str]]]=None,
|
||||
) -> None:
|
||||
self.model_basedir = Path(basedir)
|
||||
self.logger = logging.getLogger('eynollah.model_zoo')
|
||||
self.model_versions = deepcopy(DEFAULT_MODEL_VERSIONS)
|
||||
if model_overrides:
|
||||
self.override_models(*model_overrides)
|
||||
self._loaded: Dict[Tuple[str, str], SomeEynollahModel] = {}
|
||||
|
||||
def override_models(self, *model_overrides: Tuple[str, str, str]):
|
||||
"""
|
||||
|
|
@ -202,7 +204,7 @@ class EynollahModelZoo():
|
|||
model_category: str,
|
||||
model_variant: str = '',
|
||||
model_filename: str = '',
|
||||
) -> Union[VisionEncoderDecoderModel, TrOCRProcessor, Model]:
|
||||
) -> SomeEynollahModel:
|
||||
"""
|
||||
Load any model
|
||||
"""
|
||||
|
|
@ -223,9 +225,16 @@ class EynollahModelZoo():
|
|||
self.logger.exception(e)
|
||||
model = load_model(model_path, compile=False, custom_objects={
|
||||
"PatchEncoder": PatchEncoder, "Patches": Patches})
|
||||
self._loaded[(model_category, model_variant)] = model
|
||||
return model # type: ignore
|
||||
|
||||
def _load_ocr_model(self, variant: str) -> Union[VisionEncoderDecoderModel, TrOCRProcessor, Model]:
|
||||
def get_model(self, model_categeory, model_variant) -> SomeEynollahModel:
|
||||
needle = (model_categeory, model_variant)
|
||||
if needle not in self._loaded:
|
||||
raise ValueError('Model/variant "{needle} not previously loaded with "load_model(..)"')
|
||||
return self._loaded[needle]
|
||||
|
||||
def _load_ocr_model(self, variant: str) -> SomeEynollahModel:
|
||||
"""
|
||||
Load OCR model
|
||||
"""
|
||||
|
|
@ -258,3 +267,12 @@ class EynollahModelZoo():
|
|||
'versions': self.model_versions,
|
||||
}, indent=2))
|
||||
|
||||
def shutdown(self):
|
||||
"""
|
||||
Ensure that a loaded models is not referenced by ``self._loaded`` anymore
|
||||
"""
|
||||
if hasattr(self, '_loaded') and getattr(self, '_loaded'):
|
||||
for needle in self._loaded:
|
||||
if self._loaded[needle]:
|
||||
del self._loaded[needle]
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue