mirror of
https://github.com/qurator-spk/eynollah.git
synced 2026-06-16 09:59:13 +02:00
- ModelZoo: drop `num_to_char` and `characters` model types,
also drop `_load_characters()` and `_load_num_to_char()` loaders
- `ModelZoo.load_models()`: use Predictor for `ocr` models, too
- `ModelZoo.load_model()`: delegate runtime/inference conversion of
OCR models to `eynollah.training.models.cnn_rnn_ocr_model4inference`
- `training.models`: add (purely functional) Keras layer `CTCDecoder`
for inference on top of the softmax output, using a TF backend
function instead of the (broken) `keras.backend.ctc_decode()`, while
switching to beam search (instead of greedy decoding) and also returning
the probability of the decoded path
- `training.models.cnn_rnn_ocr_model()` w/ `inference=True`:
* add kwarg `characters_txt_file` for file path of character set
* configure secondary tensor path on OCR graph for binarized input
(additional input `image_bin`, averaging softmax outputs)
* use new `CTCDecoder` layer and inverse `StringLookup` layer to
decode from softmax output to tf.string; so inference models
now have 2 inputs (RGB, binarized) and 2 outputs (text, prob)
* since arrays with `dtype=object` cannot be handled by SharedMemory (as
needed by the Predictor queues), also replace the tf.string outputs by
tf.uint8 arrays
* use this for `training convert` for OCR models w/ `--rebuild`
- `training.models.cnn_rnn_ocr_model4inference`:
* new function which does the same but loads an existing OCR model
in training configuration (i.e. without prior `inference=True`)
* use this for `training convert` for OCR models w/o `--rebuild`
396 lines
15 KiB
Python
396 lines
15 KiB
Python
import os
|
||
import json
|
||
import logging
|
||
from copy import deepcopy
|
||
from pathlib import Path
|
||
from fnmatch import fnmatchcase
|
||
from typing import Dict, List, Optional, Tuple, Type, Union
|
||
|
||
from tabulate import tabulate
|
||
|
||
from ..predictor import Predictor
|
||
from .specs import EynollahModelSpecSet
|
||
from .default_specs import DEFAULT_MODEL_SPECS
|
||
from .types import AnyModel, T
|
||
|
||
|
||
# Hard GPU memory caps per model category, in MiB.
# TF receives them as the logical-device ``memory_limit``; onnxruntime
# receives them converted to bytes (``gpu_mem_limit`` /
# ``trt_max_workspace_size``). The values are calibrated per model, and some
# of them depend on the batch size (bs) that model is run with.
MODEL_VRAM_LIMITS = dict(
    binarization=868,  # due to bs 5
    enhancement=980,  # due to bs 3
    col_classifier=210,
    page=618,
    textline=1680,  # 954 for bs 1
    region_1_2=1580,
    region_fl_np=1756,
    table=1818,
    reading_order=632,
    ocr=850,
)
|
||
class EynollahModelZoo:
    """
    Wrapper class that handles storage and loading of models for all eynollah runners.

    Model files are located via the model specs (a deep copy of
    ``DEFAULT_MODEL_SPECS``, optionally with per category/variant filename
    overrides), relative to ``model_basedir``. Models loaded through
    :meth:`load_models` are kept in ``self._loaded``, keyed by category, and
    can be retrieved with :meth:`get`.
    """

    model_basedir: Path
    specs: EynollahModelSpecSet

    def __init__(
        self,
        basedir: str,
        model_overrides: Optional[List[Tuple[str, str, str]]] = None,
    ) -> None:
        """
        Args:
            basedir: directory holding the model files. It is resolved to an
                absolute path, and a warning is logged if it does not exist.
            model_overrides: optional ``(category, variant, filename)`` triples,
                applied via :meth:`override_models`.
        """
        self.model_basedir = Path(basedir).resolve()
        self.logger = logging.getLogger('eynollah.model_zoo')
        if not self.model_basedir.exists():
            self.logger.warning(f"Model basedir does not exist: {basedir}. Set eynollah --model-basedir to the correct directory.")
        self.specs = deepcopy(DEFAULT_MODEL_SPECS)
        self._overrides = []
        # Set up the cache before applying overrides, so that shutdown()
        # sees a consistent state even if an override fails.
        self._loaded: Dict[str, Union[Predictor, AnyModel]] = {}
        if model_overrides:
            self.override_models(*model_overrides)

    @property
    def model_overrides(self):
        """All ``(category, variant, filename)`` overrides applied so far."""
        return self._overrides

    def override_models(
        self,
        *model_overrides: Tuple[str, str, str],
    ):
        """
        Override the default model versions.

        Each override is a ``(category, variant, filename)`` triple. The
        filename is resolved to an absolute path and replaces the filename of
        the matching model spec.
        """
        for model_category, model_variant, model_filename in model_overrides:
            spec = self.specs.get(model_category, model_variant)
            self.logger.warning("Overriding filename for model spec %s to %s", spec, model_filename)
            spec.filename = str(Path(model_filename).resolve())
        self._overrides += model_overrides

    def model_path(
        self,
        model_category: str,
        model_variant: str = '',
        absolute: bool = True,
    ) -> Path:
        """
        Translate a (model_category, model_variant) pair into the Path of the model.

        A relative spec filename is joined onto ``model_basedir`` if
        ``absolute``. For an ``.h5`` spec, a SavedModel directory of the same
        name without suffix (next to the ``.h5`` file) is preferred if it
        exists. An ``.onnx`` file of the same name is preferred over both.
        """
        spec = self.specs.get(model_category, model_variant)
        if spec.category in ('characters', 'num_to_char'):
            # auxiliary files stored inside the OCR model directory
            # NOTE(review): these categories may no longer exist in the specs
            return self.model_path('ocr') / spec.filename
        model_path = Path(spec.filename)
        if absolute and not model_path.is_absolute():
            model_path = self.model_basedir / model_path
        if model_path.suffix == '.h5':
            # Prefer SavedModel over HDF5 format if it exists. Look next to
            # the .h5 file, not relative to the current working directory.
            saved_model_path = model_path.with_suffix('')
            if saved_model_path.exists():
                model_path = saved_model_path
        if model_path.with_suffix('.onnx').exists():
            # prefer ONNX over SavedModel format if it exists
            model_path = model_path.with_suffix('.onnx')
        return model_path

    def load_models(
        self,
        *all_load_args: Union[str, Tuple[str], Tuple[str, str], Tuple[str, str, str]],
        device: str = '',
    ) -> Dict:
        """
        Load all requested models and return the dictionary of loaded models (by category).

        Each argument is one of:

        - ``category``
        - ``(category,)``
        - ``(category, variant)``
        - ``(category, variant, model_path)``: the spec filename is
          overridden with ``model_path``

        A category suffix ``_resized`` or ``_patched`` is stripped, and the
        corresponding flag is passed on to the loader.

        The ``ocr`` variant ``tr`` and ``trocr_processor`` are loaded
        in-process. All other models are loaded in a :class:`Predictor`.
        """
        ret = {}  # cannot use self._loaded here, yet – first spawn all predictors
        for load_args in all_load_args:
            load_kwargs = dict(device=device)
            if isinstance(load_args, str):
                load_args = (load_args,)
            if len(load_args) > 2:
                # for calls to self.model_path
                self.override_models(load_args)
                # for calls to Predictor.load_model
                model_category, model_variant, model_path = load_args
                load_kwargs["model_variant"] = model_variant
                load_kwargs["model_path_override"] = model_path
            elif len(load_args) == 2:
                model_category, model_variant = load_args
                load_kwargs["model_variant"] = model_variant
            else:
                (model_category,) = load_args
                model_variant = ""

            if model_category.endswith('_resized'):
                model_category = model_category[:-len('_resized')]
                load_kwargs["resized"] = True
            elif model_category.endswith('_patched'):
                model_category = model_category[:-len('_patched')]
                load_kwargs["patched"] = True

            if model_category == 'ocr' and model_variant == 'tr':
                model = self._load_ocr_model(variant=model_variant, device=device)
            elif model_category == 'trocr_processor':
                from transformers import TrOCRProcessor
                model_path = self.model_path(model_category, model_variant)
                model = TrOCRProcessor.from_pretrained(model_path)
            else:
                model = Predictor(self.logger, self)
                model.load_model(model_category, **load_kwargs)

            ret[model_category] = model
        self._loaded.update(ret)
        return self._loaded

    def load_model(
        self,
        model_category: str,
        model_variant: str = '',
        model_path_override: Optional[str] = None,
        patched: bool = False,
        resized: bool = False,
        device: str = '',
    ) -> AnyModel:
        """
        Load a single model (Keras, TF SavedModel or ONNX) in this process.

        Args:
            model_category: spec category of the model
            model_variant: spec variant of the model
            model_path_override: if given, override the spec's filename first
            patched: currently unused
            resized: currently unused
            device: device specification (see :meth:`_select_device`)

        Returns:
            the loaded model, with ``_name`` set to ``model_category``

        Raises:
            ValueError: if the model path is neither a directory nor an ``.onnx`` file
        """
        if model_path_override:
            self.override_models((model_category, model_variant, model_path_override))
        model_path = self.model_path(model_category, model_variant)

        if model_path.is_dir() and (model_path / "keras_metadata.pb").exists():
            # Keras model
            model = self._load_keras_model(model_category, model_path, device=device)
        elif model_path.is_dir():
            # TF-Serving model
            model = self._load_serving_model(model_category, model_path, device=device)
        elif model_path.suffix == '.onnx':
            # ONNX model
            model = self._load_onnx_model(model_category, model_path, device=device)
        else:
            raise ValueError("unknown model type for '%s'" % str(model_path))
        model._name = model_category
        return model

    def get(self, model_category: str) -> Union[Predictor, AnyModel]:
        """
        Return a model previously loaded with :meth:`load_models`.

        Raises:
            ValueError: if no model of that category has been loaded
        """
        if model_category not in self._loaded:
            raise ValueError(f'Model "{model_category}" not previously loaded with "load_models(..)"')
        return self._loaded[model_category]

    @staticmethod
    def _select_device(model_category: str, device: str) -> str:
        """
        Resolve a device specification for a single model category.

        ``device`` is one of:

        - empty (no preference)
        - a single device name such as ``CPU``, ``GPU`` or ``GPU1``
        - a comma-separated list of ``pattern:device`` pairs, where each
          pattern is matched against ``model_category`` with
          :func:`fnmatch.fnmatchcase` and the first match wins

        Returns:
            the selected device name, or ``''`` if a pattern list was given
            but no pattern matched (i.e. no preference for this model)

        Raises:
            ValueError: if an entry of a pattern list is not ``pattern:device``
        """
        if ':' not in device:
            return device
        for spec in device.split(','):
            cat, dev = spec.split(':')
            if fnmatchcase(model_category, cat):
                return dev
        return ''

    def _configure_tf_device(self, model_category, device=''):
        """
        Restrict TF to the requested GPU (or to none) and cap its memory for this model.

        With no device preference, the first GPU (if any) is used. Each
        visible GPU gets a logical device limited to
        ``MODEL_VRAM_LIMITS[model_category]``.

        Raises:
            ValueError: if the device is neither ``CPU`` nor ``GPU[n]``
        """
        from ocrd_utils import tf_disable_interactive_logs
        tf_disable_interactive_logs()
        import tensorflow as tf

        device = self._select_device(model_category, device)
        cuda = False
        try:
            gpus = tf.config.list_physical_devices('GPU')
            if device == 'CPU':
                gpus = []
            elif device:
                if not device.startswith('GPU'):
                    raise ValueError(f"invalid device {device!r} for model {model_category}")
                # plain "GPU" means the first GPU, as in the other loaders
                gpus = [gpus[int(device[3:] or 0)]]
            else:
                gpus = gpus[:1]  # TF will always use first allowable
            tf.config.set_visible_devices(gpus, 'GPU')
            for gpu in gpus:
                # Dynamic growth (set_memory_growth) never frees memory, to
                # avoid fragmentation, so the VRAM requirements end up much
                # larger than feasible for small GPUs. Use hard (calibrated)
                # limits instead.
                tf.config.set_logical_device_configuration(
                    gpu,
                    [tf.config.LogicalDeviceConfiguration(
                        memory_limit=MODEL_VRAM_LIMITS[model_category])])
                vendor_name = (
                    tf.config.experimental.get_device_details(gpu)
                    .get('device_name', 'unknown'))
                cuda = True
                self.logger.info("using GPU %s (%s) for model %s",
                                 gpu.name,
                                 vendor_name,
                                 model_category)
        except RuntimeError:
            self.logger.exception("cannot configure GPU devices")
        if not cuda:
            self.logger.warning("no GPU device available")

    def _load_keras_model(self, model_category, model_path, device=''):
        """
        Load a Keras model directory (without compiling it) and prepare it for prediction.

        For the ``ocr`` category, the model is converted into its inference
        configuration via ``cnn_rnn_ocr_model4inference``.
        """
        os.environ['TF_USE_LEGACY_KERAS'] = '1'  # avoid Keras 3 after TF 2.15
        from ocrd_utils import tf_disable_interactive_logs
        tf_disable_interactive_logs()

        from tensorflow.keras.models import load_model
        from tensorflow.keras.models import Model as KerasModel

        from ..training.models import cnn_rnn_ocr_model4inference

        self._configure_tf_device(model_category, device=device)

        model = load_model(model_path, compile=False)
        assert isinstance(model, KerasModel)

        # from ..patch_encoder import (
        #     wrap_layout_model_patched,
        #     wrap_layout_model_resized,
        # )
        # if resized:
        #     model = wrap_layout_model_resized(model)
        #     model._name = model_category + '_resized'
        # elif patched:
        #     model = wrap_layout_model_patched(model)
        #     model._name = model_category + '_patched'

        if model_category == 'ocr':
            # cnn-rnn-ocr task model may not be in inference mode, yet
            model = cnn_rnn_ocr_model4inference(model, model_path)

        model.make_predict_function()

        return model

    def _load_serving_model(self, model_category, model_path, device=''):
        """
        Load a TF SavedModel directory and expose a Keras-like prediction interface.

        ``predict_on_batch`` is bound to the model's ``serve`` function, and
        ``input_shape`` is taken from the first input of the
        ``serving_default`` signature.
        """
        from ocrd_utils import tf_disable_interactive_logs
        tf_disable_interactive_logs()
        import tensorflow as tf

        self._configure_tf_device(model_category, device=device)
        model = tf.saved_model.load(str(model_path))
        model.predict_on_batch = model.serve
        model.input_shape = tuple(model.signatures.get('serving_default').inputs[0].shape)

        return model

    def _load_onnx_model(self, model_category, model_path, device=''):
        """
        Load an ONNX model into an onnxruntime session with a Keras-like prediction interface.

        The CUDA and TensorRT execution providers (if available) are moved to
        the front and configured for the selected GPU with the memory limit
        from ``MODEL_VRAM_LIMITS``, or removed entirely when ``CPU`` is
        selected.

        ``predict_on_batch`` feeds the first model input (cast to float32) and
        returns the first output.

        Raises:
            ValueError: if the device is neither ``CPU`` nor ``GPU[n]``
        """
        import onnxruntime as ort
        import numpy as np

        providers = ort.get_available_providers()
        device = self._select_device(model_category, device)
        if device == 'CPU':
            gpu = -1
        elif device:
            if not device.startswith('GPU'):
                raise ValueError(f"invalid device {device!r} for model {model_category}")
            gpu = int(device[3:] or "0")
        else:
            gpu = 0  # try first allowable
        # configure and prioritise
        if 'CUDAExecutionProvider' in providers:
            providers.remove('CUDAExecutionProvider')
            if gpu >= 0:
                providers = [('CUDAExecutionProvider', {
                    'device_id': gpu,
                    # 'arena_extend_strategy': 'kNextPowerOfTwo',
                    'gpu_mem_limit': MODEL_VRAM_LIMITS[model_category] * 1024 * 1024,
                    # 'cudnn_conv_algo_search': 'EXHAUSTIVE',
                    # 'do_copy_in_default_stream': True,
                })] + providers
        if 'TensorrtExecutionProvider' in providers:
            providers.remove('TensorrtExecutionProvider')
            if gpu >= 0:
                providers = [('TensorrtExecutionProvider', {
                    'device_id': gpu,
                    'trt_max_workspace_size': MODEL_VRAM_LIMITS[model_category] * 1024 * 1024,
                    # 'trt_fp16_enable': True,
                    # 'trt_engine_cache_enable': True,
                    # 'trt_timing_cache_enable': True,
                })] + providers
        model = ort.InferenceSession(
            str(model_path),
            providers=providers)
        # FIXME: notify about selected provider/device
        input_name = model.get_inputs()[0].name
        output_name = model.get_outputs()[0].name

        def predict_onnx(inputs):
            # models expect data_type() == 'tensor(float)', but np.float16 is 'tensor(float16)'
            # FIXME: do this dynamically (but how to convert .type to np.dtype?)
            inputs = inputs.astype(np.float32)
            return model.run(
                [output_name], {input_name: inputs})[0]

        model.predict_on_batch = predict_onnx
        model.input_shape = model.get_inputs()[0].shape

        return model

    def _load_ocr_model(self, variant: str, device: str = "") -> AnyModel:
        """
        Load an OCR model.

        The ``tr`` variant is loaded as a transformers
        ``VisionEncoderDecoderModel`` and moved to the selected GPU. By
        default this is the first CUDA device, if available; on failure it
        falls back to CPU.

        Any other variant is loaded via :meth:`load_model`.
        """
        if variant != 'tr':
            return self.load_model('ocr', model_variant=variant, device=device)

        from transformers import VisionEncoderDecoderModel
        import torch

        model_dir = self.model_path('ocr', variant)
        model = VisionEncoderDecoderModel.from_pretrained(model_dir)
        assert isinstance(model, VisionEncoderDecoderModel)
        device = self._select_device('ocr', device)
        if not device and torch.cuda.is_available():
            device = 'GPU'  # try
        device0 = torch.device('cpu')
        if device.startswith('GPU'):
            try:
                device0 = torch.device('cuda', int(device[3:] or 0))
                name = torch.cuda.get_device_name(device0)
                self.logger.info("using GPU %s (%s) for model ocr:tr", device0, name)
            except Exception:
                self.logger.exception("cannot configure GPU device")
                device0 = torch.device('cpu')
        if device0.type == 'cuda':
            model.to(device0)
        else:
            self.logger.warning("no GPU device available")
        return model

    def __str__(self):
        """Return a GitHub-flavoured table of all model specs and their install status."""
        rows = []
        for spec in sorted(self.specs.specs, key=lambda x: x.dist_url):
            path = self.model_path(spec.category, spec.variant)
            installed = (f'Yes, at {path}' if path.exists()
                         else f'No, download {spec.dist_url}')
            rows.append([spec.type, spec.category, spec.variant, spec.help, installed])
        # One header per data column; tabulate silently truncates surplus
        # headers, which mislabels columns.
        return tabulate(
            rows,
            headers=[
                'Type',
                'Category',
                'Variant',
                'Help',
                'Installed',
            ],
            tablefmt='github',
        )

    def shutdown(self):
        """
        Shut down all loaded predictors and drop every model from ``self._loaded``.

        Safe to call on a partially initialised instance (``_loaded`` missing or empty).
        """
        loaded = getattr(self, '_loaded', None)
        if not loaded:
            return
        for needle in list(loaded):
            model = loaded[needle]
            if isinstance(model, Predictor):
                model.shutdown()
            del loaded[needle]