diff --git a/src/eynollah/model_zoo/model_zoo.py b/src/eynollah/model_zoo/model_zoo.py index f1d8824..054552a 100644 --- a/src/eynollah/model_zoo/model_zoo.py +++ b/src/eynollah/model_zoo/model_zoo.py @@ -35,7 +35,7 @@ class EynollahModelZoo: self._overrides = [] if model_overrides: self.override_models(*model_overrides) - self._loaded: Dict[str, Predictor] = {} + self._loaded: Dict[str, Union[Predictor, AnyModel]] = {} @property def model_overrides(self): @@ -197,6 +197,7 @@ class EynollahModelZoo: model_path, compile=False, custom_objects=dict(PatchEncoder=PatchEncoder, Patches=Patches)) + model.make_predict_function() assert isinstance(model, KerasModel) model._name = model_category if resized: @@ -206,7 +207,10 @@ class EynollahModelZoo: model = wrap_layout_model_patched(model) model._name = model_category + '_patched' else: - model.jit_compile = True + # increases required VRAM, does not always work + # (depending on CUDA/libcudnn/TF version): + #model.jit_compile = True + pass if model_category == 'ocr': model = KerasModel( @@ -214,10 +218,9 @@ class EynollahModelZoo: model.get_layer(name="dense2").output, # type: ignore ) - model.make_predict_function() return model - def get(self, model_category: str) -> Predictor: + def get(self, model_category: str) -> Union[Predictor, AnyModel]: if model_category not in self._loaded: raise ValueError(f'Model "{model_category}" not previously loaded with "load_model(..)"') return self._loaded[model_category]