From ded668a2562d2dc59646554a06338303cf2a6034 Mon Sep 17 00:00:00 2001 From: Robert Sachunsky Date: Tue, 12 May 2026 18:17:43 +0200 Subject: [PATCH] =?UTF-8?q?model=5Fzoo:=20fix=20clash=20between=20Predicto?= =?UTF-8?q?r=20and=20direct=20(OCR)=20use-cases=E2=80=A6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - `load_models()`: uniformly handle arg types - `load_model()`: move handling of non-model categories to `load_models()` - `load_model()`: move SavedModel preference over HDF5 to `model_path()` - `_load_ocr_model()`: add user-selected device handling and reporting for Torch (as for TF) - `_load_ocr_model()`: move (TF-based) CNN-RNN case to `load_model()` (including Keras layer mapping) - `shutdown()`: only apply `shutdown()` to Predictor model types --- src/eynollah/model_zoo/model_zoo.py | 144 +++++++++++++++++----------- 1 file changed, 87 insertions(+), 57 deletions(-) diff --git a/src/eynollah/model_zoo/model_zoo.py b/src/eynollah/model_zoo/model_zoo.py index fffd389..7f3cd6c 100644 --- a/src/eynollah/model_zoo/model_zoo.py +++ b/src/eynollah/model_zoo/model_zoo.py @@ -70,6 +70,9 @@ class EynollahModelZoo: model_path = Path(self.model_basedir).joinpath(spec.filename) else: model_path = Path(spec.filename) + if model_path.suffix == '.h5' and Path(model_path.stem).exists(): + # prefer SavedModel over HDF5 format if it exists + model_path = Path(model_path.stem) return model_path def load_models( @@ -82,28 +85,50 @@ class EynollahModelZoo: """ ret = {} # cannot use self._loaded here, yet – first spawn all predictors for load_args in all_load_args: + load_kwargs = dict(device=device) if isinstance(load_args, str): - model_category = load_args - load_args = [model_category] + model_category, model_variant = load_args, "" + elif len(load_args) > 2: + # for calls to self.model_path + self.override_models(load_args) + # for calls to Predictor.load_model + model_category, model_variant, model_path = load_args + load_kwargs["model_variant"] = model_variant + load_kwargs["model_path_override"] = model_path else: - model_category = load_args[0] - load_kwargs = {} + model_category, model_variant = load_args + load_kwargs["model_variant"] = model_variant + if model_category.endswith('_resized'): - load_args[0] = model_category[:-8] + model_category = model_category[:-8] load_kwargs["resized"] = True elif model_category.endswith('_patched'): - load_args[0] = model_category[:-8] + model_category = model_category[:-8] load_kwargs["patched"] = True - ret[model_category] = Predictor(self.logger, self) - ret[model_category].load_model(*load_args, **load_kwargs, device=device) + + if model_category == 'ocr': + model = self._load_ocr_model(variant=model_variant, device=device) + elif model_category == 'num_to_char': + model = self._load_num_to_char() + elif model_category == 'characters': + model = self._load_characters() + elif model_category == 'trocr_processor': + from transformers import TrOCRProcessor + model_path = self.model_path(model_category, model_variant) + model = TrOCRProcessor.from_pretrained(model_path) + else: + model = Predictor(self.logger, self) + model.load_model(model_category, **load_kwargs) + + ret[model_category] = model self._loaded.update(ret) return self._loaded def load_model( - self, - model_category: str, - model_variant: str = '', - model_path_override: Optional[str] = None, + self, + model_category: str, + model_variant: str = '', + model_path_override: Optional[str] = None, patched: bool = False, resized: bool = False, device: str = '', @@ -117,6 +142,7 @@ class EynollahModelZoo: import tensorflow as tf from tensorflow.keras.models import load_model + from tensorflow.keras.models import Model as KerasModel from ..patch_encoder import ( PatchEncoder, @@ -162,38 +188,33 @@ class EynollahModelZoo: if model_path_override: self.override_models((model_category, model_variant, model_path_override)) model_path = self.model_path(model_category, model_variant) - if model_path.suffix == '.h5' and Path(model_path.stem).exists(): - # prefer SavedModel over HDF5 format if it exists - model_path = Path(model_path.stem) - if model_category == 'ocr': - model = self._load_ocr_model(variant=model_variant) - elif model_category == 'num_to_char': - model = self._load_num_to_char() - elif model_category == 'characters': - model = self._load_characters() - elif model_category == 'trocr_processor': - from transformers import TrOCRProcessor - model = TrOCRProcessor.from_pretrained(model_path) + try: + # avoid wasting VRAM on non-transformer models + model = load_model(model_path, compile=False) + except Exception as e: + self.logger.error(e) + model = load_model( + model_path, compile=False, + custom_objects=dict(PatchEncoder=PatchEncoder, + Patches=Patches)) + assert isinstance(model, KerasModel) + model._name = model_category + if resized: + model = wrap_layout_model_resized(model) + model._name = model_category + '_resized' + elif patched: + model = wrap_layout_model_patched(model) + model._name = model_category + '_patched' else: - try: - # avoid wasting VRAM on non-transformer models - model = load_model(model_path, compile=False) - except Exception as e: - self.logger.error(e) - model = load_model( - model_path, compile=False, - custom_objects=dict(PatchEncoder=PatchEncoder, - Patches=Patches)) - model._name = model_category - if resized: - model = wrap_layout_model_resized(model) - model._name = model_category + '_resized' - elif patched: - model = wrap_layout_model_patched(model) - model._name = model_category + '_patched' - else: - model.jit_compile = True - model.make_predict_function() + model.jit_compile = True + + if model_category == 'ocr': + model = KerasModel( + model.get_layer(name="image").input, # type: ignore + model.get_layer(name="dense2").output, # type: ignore + ) + + model.make_predict_function() return model def get(self, model_category: str) -> Predictor: @@ -201,26 +222,34 @@ class EynollahModelZoo: raise ValueError(f'Model "{model_category}" not previously loaded with "load_model(..)"') return self._loaded[model_category] - def _load_ocr_model(self, variant: str) -> AnyModel: + def _load_ocr_model(self, variant: str, device: str = "") -> AnyModel: """ Load OCR model """ - from tensorflow.keras.models import Model as KerasModel - from tensorflow.keras.models import load_model - - ocr_model_dir = self.model_path('ocr', variant) + model_dir = self.model_path('ocr', variant) if variant == 'tr': from transformers import VisionEncoderDecoderModel - ret = VisionEncoderDecoderModel.from_pretrained(ocr_model_dir) + import torch + ret = VisionEncoderDecoderModel.from_pretrained(model_dir) assert isinstance(ret, VisionEncoderDecoderModel) + dev = torch.device('cpu') + if not device and torch.cuda.is_available(): + device = 'GPU' # try + if device and device.startswith('GPU'): + try: + dev = torch.device('cuda', int(device[3:] or 0)) + name = torch.cuda.get_device_name(dev) + self.logger.info("using GPU %s (%s) for model ocr:tr", dev, name) + except: + self.logger.exception("cannot configure GPU device") + dev = torch.device('cpu') + if dev.type == 'cuda': + ret.to(dev) + else: + self.logger.warning("no GPU device available") return ret - else: - ocr_model = load_model(ocr_model_dir, compile=False) - assert isinstance(ocr_model, KerasModel) - return KerasModel( - ocr_model.get_layer(name="image").input, # type: ignore - ocr_model.get_layer(name="dense2").output, # type: ignore - ) + + return self.load_model('ocr', model_variant=variant, device=device) def _load_characters(self) -> List[str]: """ @@ -273,5 +302,6 @@ class EynollahModelZoo: """ if hasattr(self, '_loaded') and getattr(self, '_loaded'): for needle in list(self._loaded.keys()): - self._loaded[needle].shutdown() + if isinstance(self._loaded[needle], Predictor): + self._loaded[needle].shutdown() del self._loaded[needle]