mirror of
https://github.com/qurator-spk/eynollah.git
synced 2026-05-26 07:39:22 +02:00
model_zoo: fix clash between Predictor and direct (OCR) use-cases…
- `load_models()`: uniformly handle arg types - `load_model()`: move handling of non-model categories to `load_models()` - `load_model()`: move SavedModel preference over HDF5 to `model_path()` - `_load_ocr_model()`: add user-selected device handling and reporting for Torch (as for TF) - `_load_ocr_model()`: move (TF-based) CNN-RNN case to `load_model()` (including Keras layer mapping) - `shutdown()`: only apply `shutdown()` to Predictor model types
This commit is contained in:
parent
98e6fbbcbb
commit
ded668a256
1 changed files with 87 additions and 57 deletions
|
|
@ -70,6 +70,9 @@ class EynollahModelZoo:
|
||||||
model_path = Path(self.model_basedir).joinpath(spec.filename)
|
model_path = Path(self.model_basedir).joinpath(spec.filename)
|
||||||
else:
|
else:
|
||||||
model_path = Path(spec.filename)
|
model_path = Path(spec.filename)
|
||||||
|
if model_path.suffix == '.h5' and Path(model_path.stem).exists():
|
||||||
|
# prefer SavedModel over HDF5 format if it exists
|
||||||
|
model_path = Path(model_path.stem)
|
||||||
return model_path
|
return model_path
|
||||||
|
|
||||||
def load_models(
|
def load_models(
|
||||||
|
|
@ -82,20 +85,42 @@ class EynollahModelZoo:
|
||||||
"""
|
"""
|
||||||
ret = {} # cannot use self._loaded here, yet – first spawn all predictors
|
ret = {} # cannot use self._loaded here, yet – first spawn all predictors
|
||||||
for load_args in all_load_args:
|
for load_args in all_load_args:
|
||||||
|
load_kwargs = dict(device=device)
|
||||||
if isinstance(load_args, str):
|
if isinstance(load_args, str):
|
||||||
model_category = load_args
|
model_category, model_variant = load_args, ""
|
||||||
load_args = [model_category]
|
elif len(load_args) > 2:
|
||||||
|
# for calls to self.model_path
|
||||||
|
self.override_models(load_args)
|
||||||
|
# for calls to Predictor.load_model
|
||||||
|
model_category, model_variant, model_path = load_args
|
||||||
|
load_kwargs["model_variant"] = model_variant
|
||||||
|
load_kwargs["model_path_override"] = model_path
|
||||||
else:
|
else:
|
||||||
model_category = load_args[0]
|
model_category, model_variant = load_args
|
||||||
load_kwargs = {}
|
load_kwargs["model_variant"] = model_variant
|
||||||
|
|
||||||
if model_category.endswith('_resized'):
|
if model_category.endswith('_resized'):
|
||||||
load_args[0] = model_category[:-8]
|
model_category = model_category[:-8]
|
||||||
load_kwargs["resized"] = True
|
load_kwargs["resized"] = True
|
||||||
elif model_category.endswith('_patched'):
|
elif model_category.endswith('_patched'):
|
||||||
load_args[0] = model_category[:-8]
|
model_category = model_category[:-8]
|
||||||
load_kwargs["patched"] = True
|
load_kwargs["patched"] = True
|
||||||
ret[model_category] = Predictor(self.logger, self)
|
|
||||||
ret[model_category].load_model(*load_args, **load_kwargs, device=device)
|
if model_category == 'ocr':
|
||||||
|
model = self._load_ocr_model(variant=model_variant, device=device)
|
||||||
|
elif model_category == 'num_to_char':
|
||||||
|
model = self._load_num_to_char()
|
||||||
|
elif model_category == 'characters':
|
||||||
|
model = self._load_characters()
|
||||||
|
elif model_category == 'trocr_processor':
|
||||||
|
from transformers import TrOCRProcessor
|
||||||
|
model_path = self.model_path(model_category, model_variant)
|
||||||
|
model = TrOCRProcessor.from_pretrained(model_path)
|
||||||
|
else:
|
||||||
|
model = Predictor(self.logger, self)
|
||||||
|
model.load_model(model_category, **load_kwargs)
|
||||||
|
|
||||||
|
ret[model_category] = model
|
||||||
self._loaded.update(ret)
|
self._loaded.update(ret)
|
||||||
return self._loaded
|
return self._loaded
|
||||||
|
|
||||||
|
|
@ -117,6 +142,7 @@ class EynollahModelZoo:
|
||||||
|
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
from tensorflow.keras.models import load_model
|
from tensorflow.keras.models import load_model
|
||||||
|
from tensorflow.keras.models import Model as KerasModel
|
||||||
|
|
||||||
from ..patch_encoder import (
|
from ..patch_encoder import (
|
||||||
PatchEncoder,
|
PatchEncoder,
|
||||||
|
|
@ -162,19 +188,6 @@ class EynollahModelZoo:
|
||||||
if model_path_override:
|
if model_path_override:
|
||||||
self.override_models((model_category, model_variant, model_path_override))
|
self.override_models((model_category, model_variant, model_path_override))
|
||||||
model_path = self.model_path(model_category, model_variant)
|
model_path = self.model_path(model_category, model_variant)
|
||||||
if model_path.suffix == '.h5' and Path(model_path.stem).exists():
|
|
||||||
# prefer SavedModel over HDF5 format if it exists
|
|
||||||
model_path = Path(model_path.stem)
|
|
||||||
if model_category == 'ocr':
|
|
||||||
model = self._load_ocr_model(variant=model_variant)
|
|
||||||
elif model_category == 'num_to_char':
|
|
||||||
model = self._load_num_to_char()
|
|
||||||
elif model_category == 'characters':
|
|
||||||
model = self._load_characters()
|
|
||||||
elif model_category == 'trocr_processor':
|
|
||||||
from transformers import TrOCRProcessor
|
|
||||||
model = TrOCRProcessor.from_pretrained(model_path)
|
|
||||||
else:
|
|
||||||
try:
|
try:
|
||||||
# avoid wasting VRAM on non-transformer models
|
# avoid wasting VRAM on non-transformer models
|
||||||
model = load_model(model_path, compile=False)
|
model = load_model(model_path, compile=False)
|
||||||
|
|
@ -184,6 +197,7 @@ class EynollahModelZoo:
|
||||||
model_path, compile=False,
|
model_path, compile=False,
|
||||||
custom_objects=dict(PatchEncoder=PatchEncoder,
|
custom_objects=dict(PatchEncoder=PatchEncoder,
|
||||||
Patches=Patches))
|
Patches=Patches))
|
||||||
|
assert isinstance(model, KerasModel)
|
||||||
model._name = model_category
|
model._name = model_category
|
||||||
if resized:
|
if resized:
|
||||||
model = wrap_layout_model_resized(model)
|
model = wrap_layout_model_resized(model)
|
||||||
|
|
@ -193,6 +207,13 @@ class EynollahModelZoo:
|
||||||
model._name = model_category + '_patched'
|
model._name = model_category + '_patched'
|
||||||
else:
|
else:
|
||||||
model.jit_compile = True
|
model.jit_compile = True
|
||||||
|
|
||||||
|
if model_category == 'ocr':
|
||||||
|
model = KerasModel(
|
||||||
|
model.get_layer(name="image").input, # type: ignore
|
||||||
|
model.get_layer(name="dense2").output, # type: ignore
|
||||||
|
)
|
||||||
|
|
||||||
model.make_predict_function()
|
model.make_predict_function()
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
@ -201,26 +222,34 @@ class EynollahModelZoo:
|
||||||
raise ValueError(f'Model "{model_category}" not previously loaded with "load_model(..)"')
|
raise ValueError(f'Model "{model_category}" not previously loaded with "load_model(..)"')
|
||||||
return self._loaded[model_category]
|
return self._loaded[model_category]
|
||||||
|
|
||||||
def _load_ocr_model(self, variant: str) -> AnyModel:
|
def _load_ocr_model(self, variant: str, device: str = "") -> AnyModel:
|
||||||
"""
|
"""
|
||||||
Load OCR model
|
Load OCR model
|
||||||
"""
|
"""
|
||||||
from tensorflow.keras.models import Model as KerasModel
|
model_dir = self.model_path('ocr', variant)
|
||||||
from tensorflow.keras.models import load_model
|
|
||||||
|
|
||||||
ocr_model_dir = self.model_path('ocr', variant)
|
|
||||||
if variant == 'tr':
|
if variant == 'tr':
|
||||||
from transformers import VisionEncoderDecoderModel
|
from transformers import VisionEncoderDecoderModel
|
||||||
ret = VisionEncoderDecoderModel.from_pretrained(ocr_model_dir)
|
import torch
|
||||||
|
ret = VisionEncoderDecoderModel.from_pretrained(model_dir)
|
||||||
assert isinstance(ret, VisionEncoderDecoderModel)
|
assert isinstance(ret, VisionEncoderDecoderModel)
|
||||||
return ret
|
dev = torch.device('cpu')
|
||||||
|
if not device and torch.cuda.is_available():
|
||||||
|
device = 'GPU' # try
|
||||||
|
if device and device.startswith('GPU'):
|
||||||
|
try:
|
||||||
|
dev = torch.device('cuda', int(device[3:] or 0))
|
||||||
|
name = torch.cuda.get_device_name(dev)
|
||||||
|
self.logger.info("using GPU %s (%s) for model ocr:tr", dev, name)
|
||||||
|
except:
|
||||||
|
self.logger.exception("cannot configure GPU device")
|
||||||
|
dev = torch.device('cpu')
|
||||||
|
if dev.type == 'cuda':
|
||||||
|
ret.to(dev)
|
||||||
else:
|
else:
|
||||||
ocr_model = load_model(ocr_model_dir, compile=False)
|
self.logger.warning("no GPU device available")
|
||||||
assert isinstance(ocr_model, KerasModel)
|
return ret
|
||||||
return KerasModel(
|
|
||||||
ocr_model.get_layer(name="image").input, # type: ignore
|
return self.load_model('ocr', model_variant=variant, device=device)
|
||||||
ocr_model.get_layer(name="dense2").output, # type: ignore
|
|
||||||
)
|
|
||||||
|
|
||||||
def _load_characters(self) -> List[str]:
|
def _load_characters(self) -> List[str]:
|
||||||
"""
|
"""
|
||||||
|
|
@ -273,5 +302,6 @@ class EynollahModelZoo:
|
||||||
"""
|
"""
|
||||||
if hasattr(self, '_loaded') and getattr(self, '_loaded'):
|
if hasattr(self, '_loaded') and getattr(self, '_loaded'):
|
||||||
for needle in list(self._loaded.keys()):
|
for needle in list(self._loaded.keys()):
|
||||||
|
if isinstance(self._loaded[needle], Predictor):
|
||||||
self._loaded[needle].shutdown()
|
self._loaded[needle].shutdown()
|
||||||
del self._loaded[needle]
|
del self._loaded[needle]
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue