mirror of
https://github.com/qurator-spk/eynollah.git
synced 2026-06-16 09:59:13 +02:00
ModelZoo TF-Serving backend: deal with buggy .inputs signature…
work around TF bug that adds captured/unknown inputs to function signature
This commit is contained in:
parent
45c92eada2
commit
94082bc64a
1 changed files with 13 additions and 1 deletions
|
|
@ -281,7 +281,19 @@ class EynollahModelZoo:
|
||||||
self._configure_tf_device(model_category, device=device)
|
self._configure_tf_device(model_category, device=device)
|
||||||
model = tf.saved_model.load(model_path)
|
model = tf.saved_model.load(model_path)
|
||||||
model.predict_on_batch = model.serve
|
model.predict_on_batch = model.serve
|
||||||
model.input_shape = tuple(model.signatures.get('serving_default').inputs[0].shape)
|
spec = model.signatures['serving_default']
|
||||||
|
# some models receive lots of additional/internal
|
||||||
|
# (unknown) captured inputs polluting .inputs
|
||||||
|
# TF>=2.16 has spec.function_type.flat_inputs
|
||||||
|
# this non-public API works:
|
||||||
|
# input_spec = spec.inputs[:len(spec._arg_keywords)]
|
||||||
|
# but perhaps this is most reliable:
|
||||||
|
input_spec = tf.nest.flatten(spec.structured_input_signature, True)
|
||||||
|
input_spec = [tuple(i.shape) for i in input_spec]
|
||||||
|
if len(input_spec) > 1:
|
||||||
|
model.input_shape = tuple(input_spec)
|
||||||
|
else:
|
||||||
|
model.input_shape = input_spec[0]
|
||||||
|
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue