diff --git a/src/eynollah/model_zoo/model_zoo.py b/src/eynollah/model_zoo/model_zoo.py index c66c349..d028004 100644 --- a/src/eynollah/model_zoo/model_zoo.py +++ b/src/eynollah/model_zoo/model_zoo.py @@ -281,7 +281,19 @@ class EynollahModelZoo: self._configure_tf_device(model_category, device=device) model = tf.saved_model.load(model_path) model.predict_on_batch = model.serve - model.input_shape = tuple(model.signatures.get('serving_default').inputs[0].shape) + spec = model.signatures['serving_default'] + # some models receive lots of additional/internal + # (unknown) captured inputs polluting .inputs + # TF>=2.16 has spec.function_type.flat_inputs + # this non-public API works: + # input_spec = spec.inputs[:len(spec._arg_keywords)] + # but perhaps this is most reliable: + input_spec = tf.nest.flatten(spec.structured_input_signature, True) + input_spec = [tuple(i.shape) for i in input_spec] + if len(input_spec) > 1: + model.input_shape = tuple(input_spec) + else: + model.input_shape = input_spec[0] return model