ModelZoo: ensure exported TensorShape is converted to plain tuple

This commit is contained in:
Robert Sachunsky 2026-05-22 12:35:44 +02:00
parent 0836230c6b
commit 26afc5ddab

View file

@ -207,13 +207,14 @@ class EynollahModelZoo:
model_path = self.model_path(model_category, model_variant)
try:
if model_path.is_dir() and not (model_path / "keras_metadata.pb").exists():
# short-cut to avoid warning for exported models
raise ValueError()
model = load_model(model_path, compile=False)
model.make_predict_function()
except (AttributeError, ValueError):
model = tf.saved_model.load(model_path)
model.predict_on_batch = model.serve
model.input_shape = model.signatures.get('serving_default').inputs[0].shape
model.input_shape = tuple(model.signatures.get('serving_default').inputs[0].shape)
model._name = model_category
if resized:
model = wrap_layout_model_resized(model)