mirror of
https://github.com/qurator-spk/eynollah.git
synced 2026-06-16 09:59:13 +02:00
ModelZoo ONNX backend for inference: support multi-input or -output
This commit is contained in:
parent
38fe4d33ad
commit
27ca9733db
1 changed files with 18 additions and 7 deletions
|
|
@ -331,14 +331,25 @@ class EynollahModelZoo:
|
||||||
model_path,
|
model_path,
|
||||||
providers=providers)
|
providers=providers)
|
||||||
# FIXME: notify about selected provider/device
|
# FIXME: notify about selected provider/device
|
||||||
input_name = model.get_inputs()[0].name
|
model_inputs = [model_input.name
|
||||||
output_name = model.get_outputs()[0].name
|
for model_input in model.get_inputs()]
|
||||||
|
model_outputs = [model_output.name
|
||||||
|
for model_output in model.get_outputs()]
|
||||||
def predict_onnx(inputs):
|
def predict_onnx(inputs):
|
||||||
# models expect data_type() == 'tensor(float)', but np.float16 is 'tensor(float16)'
|
if len(model_inputs) == 1:
|
||||||
# FIXME: do this dynamically (but how to convert .type to np.dtype?)
|
inputs = [inputs]
|
||||||
inputs = inputs.astype(np.float32)
|
outputs = model.run(model_outputs, {
|
||||||
return model.run(
|
model_input:
|
||||||
[output_name], {input_name: inputs})[0]
|
input_data.astype(
|
||||||
|
# models expect data_type() == 'tensor(float)', but np.float16 is 'tensor(float16)'
|
||||||
|
# FIXME: do this dynamically (but how to convert .type to np.dtype?)
|
||||||
|
np.float32 if input_data.dtype in [np.float16, np.float64] else
|
||||||
|
input_data.dtype)
|
||||||
|
for model_input, input_data in zip(model_inputs, inputs)
|
||||||
|
})
|
||||||
|
if len(model_outputs) == 1:
|
||||||
|
outputs = outputs[0]
|
||||||
|
return outputs
|
||||||
model.predict_on_batch = predict_onnx
|
model.predict_on_batch = predict_onnx
|
||||||
model.input_shape = model.get_inputs()[0].shape
|
model.input_shape = model.get_inputs()[0].shape
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue