mirror of
https://github.com/qurator-spk/eynollah.git
synced 2025-10-27 07:44:12 +01:00
organize imports mostly
This commit is contained in:
parent
48d1198d24
commit
d609a532bf
3 changed files with 71 additions and 33 deletions
|
|
@ -8,40 +8,26 @@
|
||||||
document layout analysis (segmentation) with output in PAGE-XML
|
document layout analysis (segmentation) with output in PAGE-XML
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# cannot use importlib.resources until we move to 3.9+ forimportlib.resources.files
|
|
||||||
import sys
|
|
||||||
|
|
||||||
if sys.version_info < (3, 10):
|
|
||||||
import importlib_resources
|
|
||||||
else:
|
|
||||||
import importlib.resources as importlib_resources
|
|
||||||
|
|
||||||
from difflib import SequenceMatcher as sq
|
from difflib import SequenceMatcher as sq
|
||||||
from PIL import Image, ImageDraw, ImageFont
|
from PIL import Image, ImageDraw, ImageFont
|
||||||
import math
|
import math
|
||||||
import os
|
import os
|
||||||
import sys
|
|
||||||
import time
|
import time
|
||||||
from typing import Dict, Union,List, Optional, Tuple
|
from typing import Dict, Union,List, Optional, Tuple
|
||||||
import atexit
|
|
||||||
import warnings
|
import warnings
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from multiprocessing import cpu_count
|
from multiprocessing import cpu_count
|
||||||
import gc
|
import gc
|
||||||
import copy
|
import copy
|
||||||
import json
|
|
||||||
|
|
||||||
from concurrent.futures import ProcessPoolExecutor
|
from concurrent.futures import ProcessPoolExecutor
|
||||||
import xml.etree.ElementTree as ET
|
|
||||||
import cv2
|
import cv2
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import shapely.affinity
|
import shapely.affinity
|
||||||
from scipy.signal import find_peaks
|
from scipy.signal import find_peaks
|
||||||
from scipy.ndimage import gaussian_filter1d
|
from scipy.ndimage import gaussian_filter1d
|
||||||
from numba import cuda
|
|
||||||
from skimage.morphology import skeletonize
|
from skimage.morphology import skeletonize
|
||||||
from ocrd import OcrdPage
|
|
||||||
from ocrd_utils import getLogger, tf_disable_interactive_logs
|
from ocrd_utils import getLogger, tf_disable_interactive_logs
|
||||||
import statistics
|
import statistics
|
||||||
|
|
||||||
|
|
@ -53,10 +39,6 @@ try:
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
except ImportError:
|
except ImportError:
|
||||||
plt = None
|
plt = None
|
||||||
try:
|
|
||||||
from transformers import TrOCRProcessor, VisionEncoderDecoderModel
|
|
||||||
except ImportError:
|
|
||||||
TrOCRProcessor = VisionEncoderDecoderModel = None
|
|
||||||
|
|
||||||
#os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
|
#os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
|
||||||
tf_disable_interactive_logs()
|
tf_disable_interactive_logs()
|
||||||
|
|
@ -290,13 +272,6 @@ class Eynollah:
|
||||||
if self.tr:
|
if self.tr:
|
||||||
loadable.append(('ocr', 'tr'))
|
loadable.append(('ocr', 'tr'))
|
||||||
loadable.append(('ocr_tr_processor', 'tr'))
|
loadable.append(('ocr_tr_processor', 'tr'))
|
||||||
# TODO why here and why only for tr?
|
|
||||||
if torch.cuda.is_available():
|
|
||||||
self.logger.info("Using GPU acceleration")
|
|
||||||
self.device = torch.device("cuda:0")
|
|
||||||
else:
|
|
||||||
self.logger.info("Using CPU processing")
|
|
||||||
self.device = torch.device("cpu")
|
|
||||||
else:
|
else:
|
||||||
loadable.append('ocr')
|
loadable.append('ocr')
|
||||||
loadable.append('num_to_char')
|
loadable.append('num_to_char')
|
||||||
|
|
@ -307,10 +282,16 @@ class Eynollah:
|
||||||
if hasattr(self, 'executor') and getattr(self, 'executor'):
|
if hasattr(self, 'executor') and getattr(self, 'executor'):
|
||||||
self.executor.shutdown()
|
self.executor.shutdown()
|
||||||
self.executor = None
|
self.executor = None
|
||||||
if hasattr(self, 'models') and getattr(self, 'models'):
|
self.model_zoo.shutdown()
|
||||||
for model_name in list(self.models):
|
|
||||||
if self.models[model_name]:
|
@property
|
||||||
del self.models[model_name]
|
def device(self):
|
||||||
|
# TODO why here and why only for tr?
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
self.logger.info("Using GPU acceleration")
|
||||||
|
return torch.device("cuda:0")
|
||||||
|
self.logger.info("Using CPU processing")
|
||||||
|
return torch.device("cpu")
|
||||||
|
|
||||||
def cache_images(self, image_filename=None, image_pil=None, dpi=None):
|
def cache_images(self, image_filename=None, image_pil=None, dpi=None):
|
||||||
ret = {}
|
ret = {}
|
||||||
|
|
|
||||||
|
|
@ -1,3 +1,41 @@
|
||||||
|
# pyright: reportPossiblyUnboundVariable=false
|
||||||
|
|
||||||
|
from logging import getLogger
|
||||||
|
from typing import Optional
|
||||||
|
from pathlib import Path
|
||||||
|
import os
|
||||||
|
import json
|
||||||
|
import gc
|
||||||
|
import sys
|
||||||
|
import math
|
||||||
|
import cv2
|
||||||
|
import time
|
||||||
|
|
||||||
|
from keras.layers import StringLookup
|
||||||
|
|
||||||
|
from eynollah.utils.resize import resize_image
|
||||||
|
from eynollah.utils.utils_ocr import break_curved_line_into_small_pieces_and_then_merge, decode_batch_predictions, fit_text_single_line, get_contours_and_bounding_boxes, get_orientation_moments, preprocess_and_resize_image_for_ocrcnn_model, return_textlines_split_if_needed, rotate_image_with_padding
|
||||||
|
|
||||||
|
from .utils import is_image_filename
|
||||||
|
|
||||||
|
import xml.etree.ElementTree as ET
|
||||||
|
import tensorflow as tf
|
||||||
|
from keras.models import load_model
|
||||||
|
from PIL import Image, ImageDraw, ImageFont
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
|
||||||
|
# cannot use importlib.resources until we move to 3.9+ forimportlib.resources.files
|
||||||
|
if sys.version_info < (3, 10):
|
||||||
|
import importlib_resources
|
||||||
|
else:
|
||||||
|
import importlib.resources as importlib_resources
|
||||||
|
|
||||||
|
try:
|
||||||
|
from transformers import TrOCRProcessor, VisionEncoderDecoderModel
|
||||||
|
except ImportError:
|
||||||
|
TrOCRProcessor = VisionEncoderDecoderModel = None
|
||||||
|
|
||||||
class Eynollah_ocr:
|
class Eynollah_ocr:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
|
@ -25,6 +63,7 @@ class Eynollah_ocr:
|
||||||
else:
|
else:
|
||||||
self.min_conf_value_of_textline_text = 0.3
|
self.min_conf_value_of_textline_text = 0.3
|
||||||
if tr_ocr:
|
if tr_ocr:
|
||||||
|
assert TrOCRProcessor
|
||||||
self.processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-printed")
|
self.processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-printed")
|
||||||
self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
||||||
if self.model_name:
|
if self.model_name:
|
||||||
|
|
|
||||||
|
|
@ -2,7 +2,6 @@ from dataclasses import dataclass
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from types import MappingProxyType
|
|
||||||
from typing import Dict, Literal, Optional, Tuple, List, Union
|
from typing import Dict, Literal, Optional, Tuple, List, Union
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
|
|
||||||
|
|
@ -12,6 +11,8 @@ from transformers import TrOCRProcessor, VisionEncoderDecoderModel
|
||||||
|
|
||||||
from eynollah.patch_encoder import PatchEncoder, Patches
|
from eynollah.patch_encoder import PatchEncoder, Patches
|
||||||
|
|
||||||
|
SomeEynollahModel = Union[VisionEncoderDecoderModel, TrOCRProcessor, Model]
|
||||||
|
|
||||||
|
|
||||||
# Dict mapping model_category to dict mapping variant (default is '') to Path
|
# Dict mapping model_category to dict mapping variant (default is '') to Path
|
||||||
DEFAULT_MODEL_VERSIONS: Dict[str, Dict[str, str]] = {
|
DEFAULT_MODEL_VERSIONS: Dict[str, Dict[str, str]] = {
|
||||||
|
|
@ -134,13 +135,14 @@ class EynollahModelZoo():
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
basedir: str,
|
basedir: str,
|
||||||
model_overrides: List[Tuple[str, str, str]],
|
model_overrides: Optional[List[Tuple[str, str, str]]]=None,
|
||||||
) -> None:
|
) -> None:
|
||||||
self.model_basedir = Path(basedir)
|
self.model_basedir = Path(basedir)
|
||||||
self.logger = logging.getLogger('eynollah.model_zoo')
|
self.logger = logging.getLogger('eynollah.model_zoo')
|
||||||
self.model_versions = deepcopy(DEFAULT_MODEL_VERSIONS)
|
self.model_versions = deepcopy(DEFAULT_MODEL_VERSIONS)
|
||||||
if model_overrides:
|
if model_overrides:
|
||||||
self.override_models(*model_overrides)
|
self.override_models(*model_overrides)
|
||||||
|
self._loaded: Dict[Tuple[str, str], SomeEynollahModel] = {}
|
||||||
|
|
||||||
def override_models(self, *model_overrides: Tuple[str, str, str]):
|
def override_models(self, *model_overrides: Tuple[str, str, str]):
|
||||||
"""
|
"""
|
||||||
|
|
@ -202,7 +204,7 @@ class EynollahModelZoo():
|
||||||
model_category: str,
|
model_category: str,
|
||||||
model_variant: str = '',
|
model_variant: str = '',
|
||||||
model_filename: str = '',
|
model_filename: str = '',
|
||||||
) -> Union[VisionEncoderDecoderModel, TrOCRProcessor, Model]:
|
) -> SomeEynollahModel:
|
||||||
"""
|
"""
|
||||||
Load any model
|
Load any model
|
||||||
"""
|
"""
|
||||||
|
|
@ -223,9 +225,16 @@ class EynollahModelZoo():
|
||||||
self.logger.exception(e)
|
self.logger.exception(e)
|
||||||
model = load_model(model_path, compile=False, custom_objects={
|
model = load_model(model_path, compile=False, custom_objects={
|
||||||
"PatchEncoder": PatchEncoder, "Patches": Patches})
|
"PatchEncoder": PatchEncoder, "Patches": Patches})
|
||||||
|
self._loaded[(model_category, model_variant)] = model
|
||||||
return model # type: ignore
|
return model # type: ignore
|
||||||
|
|
||||||
def _load_ocr_model(self, variant: str) -> Union[VisionEncoderDecoderModel, TrOCRProcessor, Model]:
|
def get_model(self, model_categeory, model_variant) -> SomeEynollahModel:
|
||||||
|
needle = (model_categeory, model_variant)
|
||||||
|
if needle not in self._loaded:
|
||||||
|
raise ValueError('Model/variant "{needle} not previously loaded with "load_model(..)"')
|
||||||
|
return self._loaded[needle]
|
||||||
|
|
||||||
|
def _load_ocr_model(self, variant: str) -> SomeEynollahModel:
|
||||||
"""
|
"""
|
||||||
Load OCR model
|
Load OCR model
|
||||||
"""
|
"""
|
||||||
|
|
@ -258,3 +267,12 @@ class EynollahModelZoo():
|
||||||
'versions': self.model_versions,
|
'versions': self.model_versions,
|
||||||
}, indent=2))
|
}, indent=2))
|
||||||
|
|
||||||
|
def shutdown(self):
|
||||||
|
"""
|
||||||
|
Ensure that a loaded models is not referenced by ``self._loaded`` anymore
|
||||||
|
"""
|
||||||
|
if hasattr(self, '_loaded') and getattr(self, '_loaded'):
|
||||||
|
for needle in self._loaded:
|
||||||
|
if self._loaded[needle]:
|
||||||
|
del self._loaded[needle]
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue