ModelZoo.load_model: no XLA compilation

This commit is contained in:
Robert Sachunsky 2026-05-19 02:08:14 +02:00
parent f329e10a80
commit 481c286da9

View file

@ -35,7 +35,7 @@ class EynollahModelZoo:
self._overrides = [] self._overrides = []
if model_overrides: if model_overrides:
self.override_models(*model_overrides) self.override_models(*model_overrides)
self._loaded: Dict[str, Predictor] = {} self._loaded: Dict[str, Union[Predictor, AnyModel]] = {}
@property @property
def model_overrides(self): def model_overrides(self):
@ -197,6 +197,7 @@ class EynollahModelZoo:
model_path, compile=False, model_path, compile=False,
custom_objects=dict(PatchEncoder=PatchEncoder, custom_objects=dict(PatchEncoder=PatchEncoder,
Patches=Patches)) Patches=Patches))
model.make_predict_function()
assert isinstance(model, KerasModel) assert isinstance(model, KerasModel)
model._name = model_category model._name = model_category
if resized: if resized:
@ -206,7 +207,10 @@ class EynollahModelZoo:
model = wrap_layout_model_patched(model) model = wrap_layout_model_patched(model)
model._name = model_category + '_patched' model._name = model_category + '_patched'
else: else:
model.jit_compile = True # increases required VRAM, does not always work
# (depending on CUDA/libcudnn/TF version):
#model.jit_compile = True
pass
if model_category == 'ocr': if model_category == 'ocr':
model = KerasModel( model = KerasModel(
@ -214,10 +218,9 @@ class EynollahModelZoo:
model.get_layer(name="dense2").output, # type: ignore model.get_layer(name="dense2").output, # type: ignore
) )
model.make_predict_function()
return model return model
def get(self, model_category: str) -> Predictor: def get(self, model_category: str) -> Union[Predictor, AnyModel]:
if model_category not in self._loaded: if model_category not in self._loaded:
raise ValueError(f'Model "{model_category}" not previously loaded with "load_model(..)"') raise ValueError(f'Model "{model_category}" not previously loaded with "load_model(..)"')
return self._loaded[model_category] return self._loaded[model_category]