eynollah/src/eynollah/model_zoo/model_zoo.py
Robert Sachunsky c79b73dcc8 cnn-rnn-ocr: move CTC decoder and string decoder to inference model…
- ModelZoo: drop `num_to_char` and `characters` model types,
  also drop `_load_characters()` and `_load_num_to_char()` loaders
- `ModelZoo.load_models()`: use Predictor for `ocr` models, too
- `ModelZoo.load_model()`: delegate runtime/inference conversion of
  OCR models to `eynollah.training.models.cnn_rnn_ocr_model4inference`
- `training.models`: add (purely functional) Keras layer `CTCDecoder`
  for inference on top of softmax output, but using TF backend
  function instead of (broken) `keras.backend.ctc_decode()`, while
  switching to beam search (instead of greedy) and also returning
  decoded path probability
- `training.models.cnn_rnn_ocr_model()` w/ `inference=True`:
  * add kwarg `characters_txt_file` for file path of character set
  * configure secondary tensor path on OCR graph for binarized input
    (additional input `image_bin`, averaging softmax outputs)
  * use new `CTCDecoder` layer and inverse `StringLookup` layer to
    decode from softmax output to tf.string; so inference models
    now have 2 inputs (RGB, binarized) and 2 outputs (text, prob)
  * since `np.dtype=object` cannot be handled by SharedMemory (as
    needed by Predictor queues), also replace tf.string by tf.uint8
    arrays
  * use this for `training convert` for OCR models w/ `--rebuild`
- `training.models.cnn_rnn_ocr_model4inference`:
  * new function which does the same but loads an existing OCR model
    in training configuration (i.e. without prior `inference=True`)
  * use this for `training convert` for OCR models w/o `--rebuild`
2026-06-02 20:26:42 +02:00

396 lines
15 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import os
import json
import logging
from copy import deepcopy
from pathlib import Path
from fnmatch import fnmatchcase
from typing import Dict, List, Optional, Tuple, Type, Union
from tabulate import tabulate
from ..predictor import Predictor
from .specs import EynollahModelSpecSet
from .default_specs import DEFAULT_MODEL_SPECS
from .types import AnyModel, T
# Per-category GPU memory budget, used as a hard limit when placing a model
# on a GPU (the ONNX loader multiplies it by 1024 * 1024, i.e. the values
# are in MiB).
MODEL_VRAM_LIMITS = dict(
    binarization=868,  # due to bs 5
    enhancement=980,  # due to bs 3
    col_classifier=210,
    page=618,
    textline=1680,  # 954 for bs 1
    region_1_2=1580,
    region_fl_np=1756,
    table=1818,
    reading_order=632,
    ocr=850,
)
class EynollahModelZoo:
"""
Wrapper class that handles storage and loading of models for all eynollah runners.
"""
model_basedir: Path
specs: EynollahModelSpecSet
def __init__(
    self,
    basedir: str,
    model_overrides: Optional[List[Tuple[str, str, str]]] = None,
) -> None:
    """
    Set up a model zoo rooted at ``basedir``.

    :param basedir: directory against which relative model filenames are
        resolved (stored, resolved, as ``self.model_basedir``); a missing
        directory only triggers a warning
    :param model_overrides: optional ``(category, variant, filename)``
        triples applied on top of the default model specs
    """
    resolved_basedir = Path(basedir).resolve()
    self.model_basedir = resolved_basedir
    self.logger = logging.getLogger('eynollah.model_zoo')
    if not resolved_basedir.exists():
        self.logger.warning(f"Model basedir does not exist: {basedir}. Set eynollah --model-basedir to the correct directory.")
    # private deep copy, so that overrides never touch the shared defaults
    self.specs = deepcopy(DEFAULT_MODEL_SPECS)
    self._overrides = []
    self.override_models(*(model_overrides or ()))
    # models registered by load_models(), keyed by category
    self._loaded: Dict[str, Union[Predictor, AnyModel]] = {}
@property
def model_overrides(self):
    """
    The ``(category, variant, filename)`` override triples applied so far.

    This is the live list kept by the zoo, in application order (not a copy).
    """
    applied_overrides = self._overrides
    return applied_overrides
def override_models(
    self,
    *model_overrides: Tuple[str, str, str],
):
    """
    Override the default model versions.

    Each override is a ``(category, variant, filename)`` triple. The spec
    registered for ``(category, variant)`` gets its filename replaced by
    the resolved absolute form of ``filename``. The triple is then appended
    to ``self._overrides`` (exposed as ``model_overrides``).
    """
    for override in model_overrides:
        model_category, model_variant, model_filename = override
        spec = self.specs.get(model_category, model_variant)
        self.logger.warning("Overriding filename for model spec %s to %s", spec, model_filename)
        # reuse the spec looked up above instead of a second specs.get()
        spec.filename = str(Path(model_filename).resolve())
        # record each override as soon as it is in effect, so the record
        # matches the specs even if a later lookup fails
        self._overrides.append(override)
def model_path(
    self,
    model_category: str,
    model_variant: str = '',
    absolute: bool = True,
) -> Path:
    """
    Translate a ``(model_category, model_variant)`` pair into a path.

    Relative spec filenames are anchored at ``self.model_basedir`` unless
    ``absolute`` is false. Among alternative serializations of the same
    model, an existing ``.onnx`` file is preferred over an existing
    SavedModel directory (the ``.h5`` name without suffix), which is
    preferred over the ``.h5`` file itself.

    :param model_category: spec category, e.g. ``'ocr'``
    :param model_variant: spec variant (``''`` for the default one)
    :param absolute: whether to anchor relative filenames at the model basedir
    :returns: the model path; its existence is not guaranteed
    """
    spec = self.specs.get(model_category, model_variant)
    if spec.category in ('characters', 'num_to_char'):
        # auxiliary files are looked up inside the default OCR model path
        return self.model_path('ocr') / spec.filename
    filename = Path(spec.filename)
    if absolute and not filename.is_absolute():
        model_path = Path(self.model_basedir) / filename
    else:
        model_path = filename
    if model_path.suffix == '.h5':
        # prefer SavedModel over HDF5 format if it exists; it must be
        # looked up next to the .h5 file (not via the bare stem in the CWD)
        saved_model_path = model_path.with_suffix('')
        if saved_model_path.exists():
            model_path = saved_model_path
    onnx_path = model_path.with_suffix('.onnx')
    if onnx_path.exists():
        # prefer ONNX over SavedModel format if it exists
        model_path = onnx_path
    return model_path
def load_models(
    self,
    *all_load_args: Union[str, Tuple[str], Tuple[str, str], Tuple[str, str, str]],
    device: str = '',
) -> Dict:
    """
    Load several models and register them under their category.

    Each positional argument is a category name, or a tuple
    ``(category,)``, ``(category, variant)`` or
    ``(category, variant, model_path)``; the 3-tuple form first overrides
    the spec filename. A category ending in ``_resized`` or ``_patched`` is
    loaded as the plain category with the corresponding flag set, and is
    registered under the plain name.

    :param device: device specification, passed on to each loader
    :returns: ``self._loaded``, i.e. all models registered so far, by category
    :raises ValueError: if a tuple argument does not have 1 to 3 elements
    """
    # cannot use self._loaded here, yet: first spawn all predictors
    ret = {}
    for load_args in all_load_args:
        if isinstance(load_args, str):
            load_args = (load_args,)
        if not 1 <= len(load_args) <= 3:
            raise ValueError(
                "expected (category[, variant[, model_path]]), got %r" % (load_args,))
        model_category = load_args[0]
        model_variant = load_args[1] if len(load_args) > 1 else ""
        load_kwargs = dict(device=device)
        if len(load_args) > 1:
            load_kwargs["model_variant"] = model_variant
        if len(load_args) > 2:
            # for calls to self.model_path
            self.override_models(load_args)
            # for calls to Predictor.load_model
            load_kwargs["model_path_override"] = load_args[2]
        if model_category.endswith('_resized'):
            model_category = model_category[:-len('_resized')]
            load_kwargs["resized"] = True
        elif model_category.endswith('_patched'):
            model_category = model_category[:-len('_patched')]
            load_kwargs["patched"] = True
        if model_category == 'ocr' and model_variant == 'tr':
            model = self._load_ocr_model(variant=model_variant, device=device)
        elif model_category == 'trocr_processor':
            from transformers import TrOCRProcessor
            processor_path = self.model_path(model_category, model_variant)
            model = TrOCRProcessor.from_pretrained(processor_path)
        else:
            model = Predictor(self.logger, self)
            model.load_model(model_category, **load_kwargs)
        ret[model_category] = model
    self._loaded.update(ret)
    return self._loaded
def load_model(
    self,
    model_category: str,
    model_variant: str = '',
    model_path_override: Optional[str] = None,
    patched: bool = False,
    resized: bool = False,
    device: str = '',
) -> AnyModel:
    """
    Load a single model, dispatching on its on-disk format.

    - directory containing ``keras_metadata.pb``: Keras model
    - HDF5 file (``.h5``): Keras model
    - any other directory: TF-Serving SavedModel
    - ``.onnx`` file: ONNX Runtime session

    :param model_path_override: if given, override the spec filename first
    :param patched: accepted for API compatibility; not used here
    :param resized: accepted for API compatibility; not used here
    :param device: device specification, passed on to the format loader
    :returns: the loaded model, with ``_name`` set to ``model_category``
    :raises ValueError: if the model path does not exist or its format is unknown
    """
    if model_path_override:
        self.override_models((model_category, model_variant, model_path_override))
    model_path = self.model_path(model_category, model_variant)
    if not model_path.exists():
        # typically models not downloaded or a wrong model basedir
        raise ValueError("model for '%s' not found at '%s'" % (model_category, str(model_path)))
    if model_path.is_dir() and (model_path / "keras_metadata.pb").exists():
        # Keras model
        model = self._load_keras_model(model_category, model_path, device=device)
    elif model_path.is_dir():
        # TF-Serving model
        model = self._load_serving_model(model_category, model_path, device=device)
    elif model_path.suffix == '.onnx':
        # ONNX model
        model = self._load_onnx_model(model_category, model_path, device=device)
    elif model_path.suffix == '.h5':
        # HDF5 Keras model (model_path falls back to it if no SavedModel exists)
        model = self._load_keras_model(model_category, model_path, device=device)
    else:
        raise ValueError("unknown model type for '%s'" % str(model_path))
    model._name = model_category
    return model
def get(self, model_category: str) -> Union[Predictor, AnyModel]:
    """
    Return the model registered under ``model_category``.

    :raises ValueError: if nothing has been registered under that category
    """
    registry = self._loaded
    if model_category in registry:
        return registry[model_category]
    raise ValueError(f'Model "{model_category}" not previously loaded with "load_model(..)"')
def _configure_tf_device(self, model_category, device=''):
    """
    Make TensorFlow use (at most) the GPU selected for ``model_category``.

    ``device`` is one of:

    - ``''``: the first GPU, if any
    - ``'CPU'``: no GPU at all
    - ``'GPU'`` or ``'GPU<n>'``: the first GPU, or GPU number ``n`` (0-based)
    - a comma-separated list of ``<pattern>:<device>`` entries; the first
      ``<pattern>`` matching ``model_category`` (per ``fnmatchcase``)
      selects the ``<device>`` to use

    The selected GPU is given a hard memory limit from ``MODEL_VRAM_LIMITS``
    if that has an entry for the category. A ``RuntimeError`` from TF while
    configuring devices is logged, not raised.

    :raises ValueError: if no ``<pattern>`` matches ``model_category``, or
        the selected device is neither ``CPU`` nor ``GPU...``
    """
    from ocrd_utils import tf_disable_interactive_logs
    tf_disable_interactive_logs()
    import tensorflow as tf
    if device and ':' in device:
        for spec in device.split(','):
            cat, dev = spec.split(':')
            if fnmatchcase(model_category, cat):
                device = dev
                break
        else:
            raise ValueError("no device specified for model %s in '%s'" % (model_category, device))
    if device and device != 'CPU' and not device.startswith('GPU'):
        raise ValueError("invalid device '%s' for model %s" % (device, model_category))
    vram_limit = MODEL_VRAM_LIMITS.get(model_category)
    cuda = False
    try:
        gpus = tf.config.list_physical_devices('GPU')
        if device == 'CPU':
            gpus = []
        elif device:
            # a bare 'GPU' means the first one (the ONNX and torch loaders
            # treat it the same way)
            gpus = [gpus[int(device[3:] or 0)]]
        else:
            gpus = gpus[:1]  # TF will always use first allowable
        tf.config.set_visible_devices(gpus, 'GPU')
        for gpu in gpus:
            if vram_limit is None:
                self.logger.warning("no VRAM limit known for model %s", model_category)
            else:
                # tf.config.experimental.set_memory_growth(gpu, True)
                # dynamic growth never frees memory (to avoid fragmentation),
                # so the VRAM requirements end up much larger than feasible
                # (for small GPUs); so try hard (calibrated) limits instead:
                tf.config.set_logical_device_configuration(
                    gpu,
                    [tf.config.LogicalDeviceConfiguration(memory_limit=vram_limit)])
            vendor_name = (
                tf.config.experimental.get_device_details(gpu)
                .get('device_name', 'unknown'))
            cuda = True
            self.logger.info("using GPU %s (%s) for model %s",
                             gpu.name, vendor_name, model_category)
    except RuntimeError:
        # e.g. TF refuses to reconfigure devices once it has initialized them
        self.logger.exception("cannot configure GPU devices")
    if not cuda:
        self.logger.warning("no GPU device available")
def _load_keras_model(self, model_category, model_path, device=''):
    """
    Load the Keras model at ``model_path`` for prediction.

    :param model_category: spec category; used for device selection and,
        for ``'ocr'``, to additionally convert the model for inference via
        ``cnn_rnn_ocr_model4inference``
    :param model_path: path passed to ``tf.keras.models.load_model``
    :param device: device specification, passed to ``_configure_tf_device``
    :returns: the loaded Keras model, with its predict function already built
    """
    # only effective if TensorFlow has not been imported yet in this process
    os.environ['TF_USE_LEGACY_KERAS'] = '1' # avoid Keras 3 after TF 2.15
    from ocrd_utils import tf_disable_interactive_logs
    tf_disable_interactive_logs()
    from tensorflow.keras.models import load_model
    from tensorflow.keras.models import Model as KerasModel
    from ..training.models import cnn_rnn_ocr_model4inference
    # select/limit the GPU before the model is loaded onto any device
    self._configure_tf_device(model_category, device=device)
    # compile=False: load without compiling (no loss/optimizer needed to predict)
    model = load_model(model_path, compile=False)
    assert isinstance(model, KerasModel)
    # from ..patch_encoder import (
    #     wrap_layout_model_patched,
    #     wrap_layout_model_resized,
    # )
    # if resized:
    #     model = wrap_layout_model_resized(model)
    #     model._name = model_category + '_resized'
    # elif patched:
    #     model = wrap_layout_model_patched(model)
    #     model._name = model_category + '_patched'
    if model_category == 'ocr':
        # cnn-rnn-ocr task model may not be in inference mode, yet
        model = cnn_rnn_ocr_model4inference(model, model_path)
    # build the predict function now instead of lazily on first use
    model.make_predict_function()
    return model
def _load_serving_model(self, model_category, model_path, device=''):
    """
    Load a TF-Serving SavedModel and give it a Keras-like inference surface.

    The returned object gets ``predict_on_batch`` (its ``serve`` endpoint)
    and ``input_shape`` (the shape of the first input of the
    ``serving_default`` signature, as a tuple).
    """
    from ocrd_utils import tf_disable_interactive_logs
    tf_disable_interactive_logs()
    import tensorflow as tf
    # devices are configured before anything is loaded
    self._configure_tf_device(model_category, device=device)
    serving_model = tf.saved_model.load(model_path)
    serving_model.predict_on_batch = serving_model.serve
    default_signature = serving_model.signatures.get('serving_default')
    first_input = default_signature.inputs[0]
    serving_model.input_shape = tuple(first_input.shape)
    return serving_model
def _load_onnx_model(self, model_category, model_path, device=''):
    """
    Load an ONNX model into an ONNX Runtime session with a Keras-like surface.

    ``device`` uses the same syntax as for the TensorFlow loaders:

    - empty: the first GPU
    - ``'CPU'``: no GPU
    - ``'GPU'``/``'GPU<n>'``: the first GPU, or GPU number ``n``
    - comma-separated ``<pattern>:<device>`` entries matched against
      ``model_category``

    On a GPU, the TensorRT and CUDA execution providers (if available) are
    moved to the front, in that order, with a memory limit from
    ``MODEL_VRAM_LIMITS`` when the category has one. For ``'CPU'`` they are
    removed.

    The returned session gets ``predict_on_batch``, which feeds the first
    model input (cast to float32) and returns the first output, and
    ``input_shape``.

    :raises ValueError: if no ``<pattern>`` matches ``model_category``, or
        the selected device is neither ``CPU`` nor ``GPU...``
    """
    import onnxruntime as ort
    import numpy as np
    providers = ort.get_available_providers()
    if device and ':' in device:
        for spec in device.split(','):
            cat, dev = spec.split(':')
            if fnmatchcase(model_category, cat):
                device = dev
                break
        else:
            raise ValueError("no device specified for model %s in '%s'" % (model_category, device))
    if not device:
        gpu = 0  # try first allowable
    elif device == 'CPU':
        gpu = -1
    elif device.startswith('GPU'):
        gpu = int(device[3:] or "0")
    else:
        raise ValueError("invalid device '%s' for model %s" % (device, model_category))
    vram_limit = MODEL_VRAM_LIMITS.get(model_category)
    # configure and prioritise: TensorRT before CUDA before everything else
    if 'CUDAExecutionProvider' in providers:
        providers.remove('CUDAExecutionProvider')
        if gpu >= 0:
            cuda_options = {
                'device_id': gpu,
                # 'arena_extend_strategy': 'kNextPowerOfTwo',
                # 'cudnn_conv_algo_search': 'EXHAUSTIVE',
                # 'do_copy_in_default_stream': True,
            }
            if vram_limit is not None:
                cuda_options['gpu_mem_limit'] = vram_limit * 1024 * 1024
            providers = [('CUDAExecutionProvider', cuda_options)] + providers
    if 'TensorrtExecutionProvider' in providers:
        providers.remove('TensorrtExecutionProvider')
        if gpu >= 0:
            trt_options = {
                'device_id': gpu,
                # 'trt_fp16_enable': True,
                # 'trt_engine_cache_enable': True,
                # 'trt_timing_cache_enable': True,
            }
            if vram_limit is not None:
                trt_options['trt_max_workspace_size'] = vram_limit * 1024 * 1024
            providers = [('TensorrtExecutionProvider', trt_options)] + providers
    # older onnxruntime releases only accept str (or bytes), not Path objects
    model = ort.InferenceSession(
        str(model_path),
        providers=providers)
    self.logger.info("using ONNX Runtime providers %s for model %s",
                     model.get_providers(), model_category)
    input_name = model.get_inputs()[0].name
    output_name = model.get_outputs()[0].name

    def predict_onnx(inputs):
        # models expect data_type() == 'tensor(float)', but np.float16 is 'tensor(float16)'
        # FIXME: do this dynamically (but how to convert .type to np.dtype?)
        # asarray avoids a copy when the batch already is float32
        inputs = np.asarray(inputs, dtype=np.float32)
        return model.run(
            [output_name], {input_name: inputs})[0]

    model.predict_on_batch = predict_onnx
    model.input_shape = model.get_inputs()[0].shape
    return model
def _load_ocr_model(self, variant: str, device: str = "") -> AnyModel:
    """
    Load an OCR model.

    The ``'tr'`` variant is loaded as a ``transformers``
    ``VisionEncoderDecoderModel`` and moved to a CUDA device if one is
    selected and usable. Every other variant is delegated to ``load_model``.

    :param variant: OCR model variant
    :param device: empty (try a GPU if torch reports CUDA as available),
        ``'CPU'``, ``'GPU'``/``'GPU<n>'``, or comma-separated
        ``<pattern>:<device>`` entries matched against ``'ocr'``
    :returns: the loaded model
    """
    if variant != 'tr':
        return self.load_model('ocr', model_variant=variant, device=device)
    from transformers import VisionEncoderDecoderModel
    import torch
    model_dir = self.model_path('ocr', variant)
    model = VisionEncoderDecoderModel.from_pretrained(model_dir)
    assert isinstance(model, VisionEncoderDecoderModel)
    device0 = torch.device('cpu')
    if not device and torch.cuda.is_available():
        device = 'GPU'  # try
    if device and ':' in device:
        for spec in device.split(','):
            cat, dev = spec.split(':')
            if fnmatchcase('ocr', cat):
                device = dev
                break
    if device and device.startswith('GPU'):
        try:
            device0 = torch.device('cuda', int(device[3:] or 0))
            name = torch.cuda.get_device_name(device0)
            self.logger.info("using GPU %s (%s) for model ocr:tr", device0, name)
        except Exception:
            # best effort: fall back to CPU, but never swallow
            # KeyboardInterrupt/SystemExit as the former bare except did
            self.logger.exception("cannot configure GPU device")
            device0 = torch.device('cpu')
    if device0.type == 'cuda':
        model.to(device0)
    else:
        self.logger.warning("no GPU device available")
    return model
def __str__(self):
    """
    Render all model specs as a GitHub-style table, sorted by download URL.

    Columns: type, category, variant, help text, and installation status
    (the resolved path if it exists, else the download URL).
    """
    rows = []
    for spec in sorted(self.specs.specs, key=lambda x: x.dist_url):
        # resolve once per spec instead of twice
        path = self.model_path(spec.category, spec.variant)
        if path.exists():
            installed = f'Yes, at {path}'
        else:
            installed = f'No, download {spec.dist_url}'
        rows.append([spec.type, spec.category, spec.variant, spec.help, installed])
    return tabulate(
        rows,
        # exactly one header per column: tabulate silently drops surplus
        # headers, which would put the status under a wrong label
        headers=[
            'Type',
            'Category',
            'Variant',
            'Help',
            'Installed',
        ],
        tablefmt='github',
    )
def shutdown(self):
"""
Ensure that a loaded models is not referenced by ``self._loaded`` anymore
"""
if hasattr(self, '_loaded') and getattr(self, '_loaded'):
for needle in list(self._loaded.keys()):
if isinstance(self._loaded[needle], Predictor):
self._loaded[needle].shutdown()