mirror of
https://github.com/qurator-spk/eynollah.git
synced 2026-06-16 09:59:13 +02:00
ModelZoo ONNX backend for inference: support multi-input or -output
This commit is contained in:
parent
38fe4d33ad
commit
27ca9733db
1 changed files with 18 additions and 7 deletions
|
|
@ -331,14 +331,25 @@ class EynollahModelZoo:
|
|||
model_path,
|
||||
providers=providers)
|
||||
# FIXME: notify about selected provider/device
|
||||
input_name = model.get_inputs()[0].name
|
||||
output_name = model.get_outputs()[0].name
|
||||
model_inputs = [model_input.name
|
||||
for model_input in model.get_inputs()]
|
||||
model_outputs = [model_output.name
|
||||
for model_output in model.get_outputs()]
|
||||
def predict_onnx(inputs):
|
||||
# models expect data_type() == 'tensor(float)', but np.float16 is 'tensor(float16)'
|
||||
# FIXME: do this dynamically (but how to convert .type to np.dtype?)
|
||||
inputs = inputs.astype(np.float32)
|
||||
return model.run(
|
||||
[output_name], {input_name: inputs})[0]
|
||||
if len(model_inputs) == 1:
|
||||
inputs = [inputs]
|
||||
outputs = model.run(model_outputs, {
|
||||
model_input:
|
||||
input_data.astype(
|
||||
# models expect data_type() == 'tensor(float)', but np.float16 is 'tensor(float16)'
|
||||
# FIXME: do this dynamically (but how to convert .type to np.dtype?)
|
||||
np.float32 if input_data.dtype in [np.float16, np.float64] else
|
||||
input_data.dtype)
|
||||
for model_input, input_data in zip(model_inputs, inputs)
|
||||
})
|
||||
if len(model_outputs) == 1:
|
||||
outputs = outputs[0]
|
||||
return outputs
|
||||
model.predict_on_batch = predict_onnx
|
||||
model.input_shape = model.get_inputs()[0].shape
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue