ModelZoo ONNX backend for inference: support multi-input or -output

This commit is contained in:
Robert Sachunsky 2026-06-03 20:57:02 +02:00
parent 38fe4d33ad
commit 27ca9733db

View file

@ -331,14 +331,25 @@ class EynollahModelZoo:
model_path,
providers=providers)
# FIXME: notify about selected provider/device
input_name = model.get_inputs()[0].name
output_name = model.get_outputs()[0].name
model_inputs = [model_input.name
for model_input in model.get_inputs()]
model_outputs = [model_output.name
for model_output in model.get_outputs()]
def predict_onnx(inputs):
# models expect data_type() == 'tensor(float)', but np.float16 is 'tensor(float16)'
# FIXME: do this dynamically (but how to convert .type to np.dtype?)
inputs = inputs.astype(np.float32)
return model.run(
[output_name], {input_name: inputs})[0]
if len(model_inputs) == 1:
inputs = [inputs]
outputs = model.run(model_outputs, {
model_input:
input_data.astype(
# models expect data_type() == 'tensor(float)', but np.float16 is 'tensor(float16)'
# FIXME: do this dynamically (but how to convert .type to np.dtype?)
np.float32 if input_data.dtype in [np.float16, np.float64] else
input_data.dtype)
for model_input, input_data in zip(model_inputs, inputs)
})
if len(model_outputs) == 1:
outputs = outputs[0]
return outputs
model.predict_on_batch = predict_onnx
model.input_shape = model.get_inputs()[0].shape