mirror of
https://github.com/qurator-spk/eynollah.git
synced 2026-05-31 01:59:27 +02:00
ModelZoo: support inference with ONNX/TensorRT…
- comment out ad-hoc conversion/loading of autosized models - refactor predictor backends for model types into separate functions - only attempt inference conversion of cnn-rnn-ocr model if applicable (`ctc_loss` layer still present) - apply VRAM limits across model types (Keras, TF-Serving, ONNX) - apply TF device selection across model types (Keras, TF-Serving) - implement predictor backend for ONNX models: - using onnxruntime - covering CUDA and TensorRT providers - trying to support manual device selection - hiding session management details - converting float16 inputs to float32
This commit is contained in:
parent
f833a516e7
commit
13f2f81c45
1 changed file with 151 additions and 58 deletions
|
|
@ -14,6 +14,19 @@ from .default_specs import DEFAULT_MODEL_SPECS
|
||||||
from .types import AnyModel, T
|
from .types import AnyModel, T
|
||||||
|
|
||||||
|
|
||||||
|
MODEL_VRAM_LIMITS = {
|
||||||
|
"binarization": 868, # due to bs 5
|
||||||
|
"enhancement": 980, # due to bs 3
|
||||||
|
"col_classifier": 210,
|
||||||
|
"page": 618,
|
||||||
|
"textline": 1680, # 954 for bs 1
|
||||||
|
"region_1_2": 1580,
|
||||||
|
"region_fl_np": 1756,
|
||||||
|
"table": 1818,
|
||||||
|
"reading_order": 632,
|
||||||
|
"ocr": 850,
|
||||||
|
}
|
||||||
|
|
||||||
class EynollahModelZoo:
|
class EynollahModelZoo:
|
||||||
"""
|
"""
|
||||||
Wrapper class that handles storage and loading of models for all eynollah runners.
|
Wrapper class that handles storage and loading of models for all eynollah runners.
|
||||||
|
|
@ -73,6 +86,10 @@ class EynollahModelZoo:
|
||||||
if model_path.suffix == '.h5' and Path(model_path.stem).exists():
|
if model_path.suffix == '.h5' and Path(model_path.stem).exists():
|
||||||
# prefer SavedModel over HDF5 format if it exists
|
# prefer SavedModel over HDF5 format if it exists
|
||||||
model_path = Path(model_path.stem)
|
model_path = Path(model_path.stem)
|
||||||
|
if model_path.with_suffix('.onnx').exists():
|
||||||
|
# prefer ONNX over SavedModel format if it exists
|
||||||
|
model_path = model_path.with_suffix('.onnx')
|
||||||
|
|
||||||
return model_path
|
return model_path
|
||||||
|
|
||||||
def load_models(
|
def load_models(
|
||||||
|
|
@ -136,20 +153,34 @@ class EynollahModelZoo:
|
||||||
"""
|
"""
|
||||||
Load any model
|
Load any model
|
||||||
"""
|
"""
|
||||||
os.environ['TF_USE_LEGACY_KERAS'] = '1' # avoid Keras 3 after TF 2.15
|
if model_path_override:
|
||||||
|
self.override_models((model_category, model_variant, model_path_override))
|
||||||
|
model_path = self.model_path(model_category, model_variant)
|
||||||
|
|
||||||
|
if model_path.is_dir() and (model_path / "keras_metadata.pb").exists():
|
||||||
|
# Keras model
|
||||||
|
model = self._load_keras_model(model_category, model_path, device=device)
|
||||||
|
elif model_path.is_dir():
|
||||||
|
# TF-Serving model
|
||||||
|
model = self._load_serving_model(model_category, model_path, device=device)
|
||||||
|
elif model_path.suffix == '.onnx':
|
||||||
|
# ONNX model
|
||||||
|
model = self._load_onnx_model(model_category, model_path, device=device)
|
||||||
|
else:
|
||||||
|
raise ValueError("unknown model type for '%s'" % str(model_path))
|
||||||
|
model._name = model_category
|
||||||
|
return model
|
||||||
|
|
||||||
|
def get(self, model_category: str) -> Union[Predictor, AnyModel]:
|
||||||
|
if model_category not in self._loaded:
|
||||||
|
raise ValueError(f'Model "{model_category}" not previously loaded with "load_model(..)"')
|
||||||
|
return self._loaded[model_category]
|
||||||
|
|
||||||
|
def _configure_tf_device(self, model_category, device=''):
|
||||||
from ocrd_utils import tf_disable_interactive_logs
|
from ocrd_utils import tf_disable_interactive_logs
|
||||||
tf_disable_interactive_logs()
|
tf_disable_interactive_logs()
|
||||||
|
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
from tensorflow.keras.models import load_model
|
|
||||||
from tensorflow.keras.models import Model as KerasModel
|
|
||||||
|
|
||||||
from ..patch_encoder import (
|
|
||||||
PatchEncoder,
|
|
||||||
Patches,
|
|
||||||
wrap_layout_model_patched,
|
|
||||||
wrap_layout_model_resized,
|
|
||||||
)
|
|
||||||
cuda = False
|
cuda = False
|
||||||
try:
|
try:
|
||||||
gpus = tf.config.list_physical_devices('GPU')
|
gpus = tf.config.list_physical_devices('GPU')
|
||||||
|
|
@ -175,18 +206,8 @@ class EynollahModelZoo:
|
||||||
# (for small GPUs); so try hard (calibrated) limits instead:
|
# (for small GPUs); so try hard (calibrated) limits instead:
|
||||||
tf.config.set_logical_device_configuration(
|
tf.config.set_logical_device_configuration(
|
||||||
device,
|
device,
|
||||||
[tf.config.LogicalDeviceConfiguration(memory_limit={
|
[tf.config.LogicalDeviceConfiguration(
|
||||||
"binarization": 868, # due to bs 5
|
memory_limit=MODEL_VRAM_LIMITS[model_category])])
|
||||||
"enhancement": 980, # due to bs 3
|
|
||||||
"col_classifier": 210,
|
|
||||||
"page": 618,
|
|
||||||
"textline": 1680, # 954 for bs 1
|
|
||||||
"region_1_2": 1580,
|
|
||||||
"region_fl_np": 1756,
|
|
||||||
"table": 1818,
|
|
||||||
"reading_order": 632,
|
|
||||||
"ocr": 850,
|
|
||||||
}[model_category])])
|
|
||||||
vendor_name = (
|
vendor_name = (
|
||||||
tf.config.experimental.get_device_details(device)
|
tf.config.experimental.get_device_details(device)
|
||||||
.get('device_name', 'unknown'))
|
.get('device_name', 'unknown'))
|
||||||
|
|
@ -194,52 +215,124 @@ class EynollahModelZoo:
|
||||||
self.logger.info("using GPU %s (%s) for model %s",
|
self.logger.info("using GPU %s (%s) for model %s",
|
||||||
device.name,
|
device.name,
|
||||||
vendor_name,
|
vendor_name,
|
||||||
model_category + (
|
model_category # + (
|
||||||
"_patched" if patched else
|
# "_patched" if patched else
|
||||||
"_resized" if resized else ""))
|
# "_resized" if resized else "")
|
||||||
|
)
|
||||||
except RuntimeError:
|
except RuntimeError:
|
||||||
self.logger.exception("cannot configure GPU devices")
|
self.logger.exception("cannot configure GPU devices")
|
||||||
if not cuda:
|
if not cuda:
|
||||||
self.logger.warning("no GPU device available")
|
self.logger.warning("no GPU device available")
|
||||||
|
|
||||||
if model_path_override:
|
def _load_keras_model(self, model_category, model_path, device=''):
|
||||||
self.override_models((model_category, model_variant, model_path_override))
|
os.environ['TF_USE_LEGACY_KERAS'] = '1' # avoid Keras 3 after TF 2.15
|
||||||
model_path = self.model_path(model_category, model_variant)
|
from ocrd_utils import tf_disable_interactive_logs
|
||||||
try:
|
tf_disable_interactive_logs()
|
||||||
if model_path.is_dir() and not (model_path / "keras_metadata.pb").exists():
|
|
||||||
# short-cut to avoid warning for exported models
|
from tensorflow.keras.models import load_model
|
||||||
raise ValueError()
|
from tensorflow.keras.models import Model as KerasModel
|
||||||
model = load_model(model_path, compile=False)
|
|
||||||
model.make_predict_function()
|
self._configure_tf_device(model_category, device=device)
|
||||||
except (AttributeError, ValueError):
|
|
||||||
model = tf.saved_model.load(model_path)
|
model = load_model(model_path, compile=False)
|
||||||
model.predict_on_batch = model.serve
|
|
||||||
model.input_shape = tuple(model.signatures.get('serving_default').inputs[0].shape)
|
# from ..patch_encoder import (
|
||||||
model._name = model_category
|
# wrap_layout_model_patched,
|
||||||
if resized:
|
# wrap_layout_model_resized,
|
||||||
model = wrap_layout_model_resized(model)
|
# )
|
||||||
model._name = model_category + '_resized'
|
# if resized:
|
||||||
elif patched:
|
# model = wrap_layout_model_resized(model)
|
||||||
model = wrap_layout_model_patched(model)
|
# model._name = model_category + '_resized'
|
||||||
model._name = model_category + '_patched'
|
# elif patched:
|
||||||
else:
|
# model = wrap_layout_model_patched(model)
|
||||||
# increases required VRAM, does not always work
|
# model._name = model_category + '_patched'
|
||||||
# (depending on CUDA/libcudnn/TF version):
|
|
||||||
#model.jit_compile = True
|
|
||||||
pass
|
|
||||||
|
|
||||||
if model_category == 'ocr':
|
if model_category == 'ocr':
|
||||||
model = KerasModel(
|
# cnn-rnn-ocr task model may not be in inference mode, yet
|
||||||
model.get_layer(name="image").input, # type: ignore
|
try:
|
||||||
model.get_layer(name="dense2").output, # type: ignore
|
model.get_layer(name='ctc_loss')
|
||||||
)
|
except ValueError:
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
model = KerasModel(
|
||||||
|
model.get_layer(name="image").input, # type: ignore
|
||||||
|
model.get_layer(name="dense2").output, # type: ignore
|
||||||
|
)
|
||||||
|
|
||||||
|
model.make_predict_function()
|
||||||
|
|
||||||
return model
|
return model
|
||||||
|
|
||||||
def get(self, model_category: str) -> Union[Predictor, AnyModel]:
|
def _load_serving_model(self, model_category, model_path, device=''):
|
||||||
if model_category not in self._loaded:
|
from ocrd_utils import tf_disable_interactive_logs
|
||||||
raise ValueError(f'Model "{model_category}" not previously loaded with "load_model(..)"')
|
tf_disable_interactive_logs()
|
||||||
return self._loaded[model_category]
|
import tensorflow as tf
|
||||||
|
|
||||||
|
self._configure_tf_device(model_category, device=device)
|
||||||
|
model = tf.saved_model.load(model_path)
|
||||||
|
model.predict_on_batch = model.serve
|
||||||
|
model.input_shape = tuple(model.signatures.get('serving_default').inputs[0].shape)
|
||||||
|
|
||||||
|
return model
|
||||||
|
|
||||||
|
def _load_onnx_model(self, model_category, model_path, device=''):
|
||||||
|
import onnxruntime as ort
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
providers = ort.get_available_providers()
|
||||||
|
if device:
|
||||||
|
if ':' in device:
|
||||||
|
for spec in device.split(','):
|
||||||
|
cat, dev = spec.split(':')
|
||||||
|
if fnmatchcase(model_category, cat):
|
||||||
|
device = dev
|
||||||
|
break
|
||||||
|
if device == 'CPU':
|
||||||
|
gpu = -1
|
||||||
|
else:
|
||||||
|
assert device.startswith('GPU')
|
||||||
|
gpu = int(device[3:] or "0")
|
||||||
|
else:
|
||||||
|
gpu = 0 # try first allowable
|
||||||
|
# configure and prioritise
|
||||||
|
if 'CUDAExecutionProvider' in providers:
|
||||||
|
providers.remove('CUDAExecutionProvider')
|
||||||
|
if gpu >= 0:
|
||||||
|
providers = [('CUDAExecutionProvider', {
|
||||||
|
'device_id': gpu,
|
||||||
|
# 'arena_extend_strategy': 'kNextPowerOfTwo',
|
||||||
|
'gpu_mem_limit': MODEL_VRAM_LIMITS[model_category] * 1024 * 1024,
|
||||||
|
# 'cudnn_conv_algo_search': 'EXHAUSTIVE',
|
||||||
|
# 'do_copy_in_default_stream': True,
|
||||||
|
# ...
|
||||||
|
})] + providers
|
||||||
|
if 'TensorrtExecutionProvider' in providers:
|
||||||
|
providers.remove('TensorrtExecutionProvider')
|
||||||
|
if gpu >= 0:
|
||||||
|
providers = [('TensorrtExecutionProvider', {
|
||||||
|
'device_id': gpu,
|
||||||
|
'trt_max_workspace_size': MODEL_VRAM_LIMITS[model_category] * 1024 * 1024,
|
||||||
|
# 'trt_fp16_enable': True,
|
||||||
|
# 'trt_engine_cache_enable': True,
|
||||||
|
# 'trt_timing_cache_enable': True,
|
||||||
|
# ...
|
||||||
|
})] + providers
|
||||||
|
model = ort.InferenceSession(
|
||||||
|
model_path,
|
||||||
|
providers=providers)
|
||||||
|
# FIXME: notify about selected provider/device
|
||||||
|
input_name = model.get_inputs()[0].name
|
||||||
|
output_name = model.get_outputs()[0].name
|
||||||
|
def predict_onnx(inputs):
|
||||||
|
# models expect data_type() == 'tensor(float)', but np.float16 is 'tensor(float16)'
|
||||||
|
# FIXME: do this dynamically (but how to convert .type to np.dtype?)
|
||||||
|
inputs = inputs.astype(np.float32)
|
||||||
|
return model.run(
|
||||||
|
[output_name], {input_name: inputs})[0]
|
||||||
|
model.predict_on_batch = predict_onnx
|
||||||
|
model.input_shape = model.get_inputs()[0].shape
|
||||||
|
|
||||||
|
return model
|
||||||
|
|
||||||
def _load_ocr_model(self, variant: str, device: str = "") -> AnyModel:
|
def _load_ocr_model(self, variant: str, device: str = "") -> AnyModel:
|
||||||
"""
|
"""
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue