ModelZoo.load_model: no XLA compilation

This commit is contained in:
Robert Sachunsky 2026-05-19 02:08:14 +02:00
parent f329e10a80
commit 481c286da9

View file

@ -35,7 +35,7 @@ class EynollahModelZoo:
self._overrides = []
if model_overrides:
self.override_models(*model_overrides)
self._loaded: Dict[str, Predictor] = {}
self._loaded: Dict[str, Union[Predictor, AnyModel]] = {}
@property
def model_overrides(self):
@ -197,6 +197,7 @@ class EynollahModelZoo:
model_path, compile=False,
custom_objects=dict(PatchEncoder=PatchEncoder,
Patches=Patches))
model.make_predict_function()
assert isinstance(model, KerasModel)
model._name = model_category
if resized:
@ -206,7 +207,10 @@ class EynollahModelZoo:
model = wrap_layout_model_patched(model)
model._name = model_category + '_patched'
else:
model.jit_compile = True
# increases required VRAM, does not always work
# (depending on CUDA/libcudnn/TF version):
#model.jit_compile = True
pass
if model_category == 'ocr':
model = KerasModel(
@ -214,10 +218,9 @@ class EynollahModelZoo:
model.get_layer(name="dense2").output, # type: ignore
)
model.make_predict_function()
return model
def get(self, model_category: str) -> Predictor:
def get(self, model_category: str) -> Union[Predictor, AnyModel]:
if model_category not in self._loaded:
raise ValueError(f'Model "{model_category}" not previously loaded with "load_model(..)"')
return self._loaded[model_category]