Model with dynamic dimensions
bminixhofer opened this issue · 9 comments
Hi, thanks for this library!
I've been trying to run a model with dynamic input dimensions, but it doesn't work due to a NonMatchingDimensions error.
model here.
Here's how I'd use the model from the onnxruntime Python bindings:
import onnxruntime # v1.2.0
session = onnxruntime.InferenceSession("model.onnx")
outputs = session.run(None, {"input_ids": [[1, 2, 3]], "attention_mask": [[1, 1, 1]]})[0]
print(outputs.shape)
Using your Rust bindings:
let env = Environment::builder().with_name("env").build()?;
let session = env
    .new_session_builder()?
    .with_optimization_level(GraphOptimizationLevel::Basic)?
    .with_model_from_file("model.onnx")?;

println!("{:#?}", session.inputs);
println!("{:#?}", session.outputs);

let input_ids = Array2::<f32>::from_shape_vec((1, 3), vec![1f32, 2f32, 3f32])?;
let attention_mask = Array2::<f32>::from_shape_vec((1, 3), vec![1f32, 1f32, 1f32])?;
let outputs: Vec<OrtOwnedTensor<f32, _>> = session.run(vec![input_ids, attention_mask])?;
This prints:
[
    Input {
        name: "input_ids",
        input_type: Int64,
        dimensions: [
            4294967295,
            4294967295,
        ],
    },
    Input {
        name: "attention_mask",
        input_type: Int64,
        dimensions: [
            4294967295,
            4294967295,
        ],
    },
]
[
    Output {
        name: "output",
        output_type: Float,
        dimensions: [
            4294967295,
            4294967295,
            94,
        ],
    },
]
thread 'main' panicked at 'called `Result::unwrap()` on an `Err` value: NonMatchingDimensions { input: [1, 3], model: [4294967295, 4294967295] }'
Apparently the dynamic dimensions lead to an integer overflow (they are encoded as -1 in ONNX iirc).
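To illustrate the suspected overflow: the sketch below assumes the runtime reports a dynamic axis as -1 in a signed 64-bit integer, as suggested above. Casting that value to `u32` with `as` wraps around to 4294967295, which is exactly the number in the printed dimensions. The helper `to_optional_dim` is a hypothetical name, not part of the crate; it only shows one way such a value could be kept as "unknown" rather than wrapped.

```rust
use std::convert::TryFrom;

/// Hypothetical helper (not part of onnxruntime-rs): map a raw dimension,
/// where a negative value is assumed to mark a dynamic axis, to `Option<u32>`.
fn to_optional_dim(raw: i64) -> Option<u32> {
    // `try_from` fails for negative values (and for values above u32::MAX),
    // so a dynamic axis becomes `None` instead of wrapping around.
    u32::try_from(raw).ok()
}

fn main() {
    let raw: i64 = -1;

    // A plain `as` cast keeps only the low 32 bits, which are all ones for -1.
    assert_eq!(raw as u32, 4_294_967_295);
    assert_eq!(raw as u32, u32::MAX);

    // The output shape from the report, with the dynamic axes kept as `None`.
    let dims: Vec<Option<u32>> = [-1i64, -1, 94]
        .iter()
        .map(|&d| to_optional_dim(d))
        .collect();
    println!("{:?}", dims); // prints [None, None, Some(94)]
}
```

With a representation like this, a dimension check could treat `None` as "matches any size" instead of comparing against 4294967295.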
I'm also a bit skeptical about the constraint on .run to have the same output type as input type - does that handle models with int64 input and float output correctly?
I appreciate any help!
Thanks for trying this out! Sorry for the delay, I moved recently and don't have as much free time as I would like.
I tried it and see the same problem. I'll see if I can do something about it.
Thanks, no problem. I'm using the onnxruntime Python bindings with PyO3 for now so this isn't blocking for me :)
I am taking a look at this and I have some questions.
If I print the inputs from python:
for idx, inputs in enumerate(session.get_inputs()):
    print("idx:", idx)
    print(" Name:", inputs.name)
    print(" Shape:", inputs.shape)
I get this:
idx: 0
Name: input_ids
Shape: ['batch', 'seq']
idx: 1
Name: attention_mask
Shape: ['batch', 'seq']
I guess I can assume that the Python lib interprets the -1 stored in the ONNX file for each dimension as the string batch for the first dimension and the string seq for the second? If I open the ONNX file in Netron I see the same strings, batch and seq. Is that a convention? Where do those strings come from? It's kind of hard to search for these strings :D
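As far as I can tell from onnx.proto, each dimension of a tensor shape holds either a fixed integer (`dim_value`) or a free-form name (`dim_param`). Names such as batch and seq would then be whatever the person exporting the model chose, not a fixed convention. The sketch below is only an illustration of that idea: the `Dimension` enum and `format_shape` are made-up names, not types from the onnx or onnxruntime crates, and the 94 in the second shape is borrowed from the output above purely as an example of a fixed size.

```rust
use std::fmt;

/// Illustrative stand-in for one entry of an ONNX tensor shape.
/// `Value` plays the role of `dim_value`, `Param` the role of `dim_param`.
enum Dimension {
    Value(i64),
    Param(String),
}

impl fmt::Display for Dimension {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            // Fixed sizes are shown as plain numbers.
            Dimension::Value(v) => write!(f, "{}", v),
            // Symbolic sizes are shown as quoted names, like in the Python output.
            Dimension::Param(name) => write!(f, "'{}'", name),
        }
    }
}

/// Render a shape in the same style as the Python bindings print it.
fn format_shape(dims: &[Dimension]) -> String {
    let parts: Vec<String> = dims.iter().map(|d| d.to_string()).collect();
    format!("[{}]", parts.join(", "))
}

fn main() {
    let input_shape = vec![
        Dimension::Param("batch".to_string()),
        Dimension::Param("seq".to_string()),
    ];
    println!("{}", format_shape(&input_shape)); // prints ['batch', 'seq']

    let mixed_shape = vec![Dimension::Param("batch".to_string()), Dimension::Value(94)];
    println!("{}", format_shape(&mixed_shape)); // prints ['batch', 94]
}
```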
Investigating this, I found a bug related to the inputs.
Multiple inputs, like in your example, should be treated as inputs to the model that should be used together, not as a list to loop over to perform inference.
So this loop is wrong: https://github.com/nbigaouette/onnxruntime-rs/blob/751d996/onnxruntime/src/session.rs#L355-L404
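A rough sketch of the intended behaviour, not the crate's actual code: all inputs are checked as one set belonging to a single inference call, and each concrete shape is compared against the declared shape of the input at the same position. `InputSpec`, `shape_matches` and `validate_inputs` are hypothetical names, and representing a dynamic axis as `None` is an assumption made for the example.

```rust
/// Hypothetical description of one model input; `None` marks a dynamic axis.
struct InputSpec {
    name: &'static str,
    dims: Vec<Option<usize>>,
}

/// Compare one concrete shape against one declared shape.
/// A dynamic axis (`None`) accepts any size; a fixed axis must match exactly.
fn shape_matches(declared: &[Option<usize>], actual: &[usize]) -> bool {
    declared.len() == actual.len()
        && declared
            .iter()
            .zip(actual)
            .all(|(d, a)| d.map_or(true, |d| d == *a))
}

/// Validate the complete set of inputs for a single inference call.
fn validate_inputs(specs: &[InputSpec], shapes: &[Vec<usize>]) -> Result<(), String> {
    // The model needs every input at once, so the counts must agree.
    if specs.len() != shapes.len() {
        return Err(format!(
            "model expects {} inputs, got {}",
            specs.len(),
            shapes.len()
        ));
    }
    // Pair each supplied shape with the declared input at the same position.
    for (spec, shape) in specs.iter().zip(shapes) {
        if !shape_matches(&spec.dims, shape) {
            return Err(format!(
                "input '{}': shape {:?} does not match {:?}",
                spec.name, shape, spec.dims
            ));
        }
    }
    Ok(())
}

fn main() {
    let specs = vec![
        InputSpec { name: "input_ids", dims: vec![None, None] },
        InputSpec { name: "attention_mask", dims: vec![None, None] },
    ];

    // Both inputs supplied together for one call: accepted.
    assert!(validate_inputs(&specs, &[vec![1, 3], vec![1, 3]]).is_ok());

    // Supplying only one of the two inputs is rejected.
    assert!(validate_inputs(&specs, &[vec![1, 3]]).is_err());
}
```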
Thanks!! :D
Can you take a look at branch issue22 (PR #23)?
It needs cleanup but it is able to load your model and seems to correctly perform inference with it. Feel free to comment there!
Hi, sorry, was a bit busy.
I assume the questions from your first comment are resolved? If not, there's more info here and in the onnx.proto file, but I don't know where exactly the names are stored either.
Thanks for the fix, I'll have a look!
I ran into another issue trying to set up the branch:
Cargo.toml
[package]
name = "onnxruntimerstest"
version = "0.1.0"
authors = ["Benjamin Minixhofer <bminixhofer@gmail.com>"]
edition = "2018"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
onnxruntime = { git = "https://github.com/nbigaouette/onnxruntime-rs", branch = "issue22" }
cargo build
Compiling onnxruntime-sys v0.0.8 (https://github.com/nbigaouette/onnxruntime-rs?branch=issue22#300d616d)
error: failed to run custom build command for `onnxruntime-sys v0.0.8 (https://github.com/nbigaouette/onnxruntime-rs?branch=issue22#300d616d)`
Caused by:
process didn't exit successfully: `/home/bminixhofer/Documents/Experiments/onnxruntimerstest/target/debug/build/onnxruntime-sys-c85e57555e099ab4/build-script-build` (exit code: 101)
--- stdout
strategy: "unknown"
Creating directory "/home/bminixhofer/Documents/Experiments/onnxruntimerstest/target/debug/build/onnxruntime-sys-140fcf9a1cff8e0c/out"
Downloading https://github.com/microsoft/onnxruntime/releases/download/v1.4.0/onnxruntime-linux-x64-1.4.0.tgz into /home/bminixhofer/Documents/Experiments/onnxruntimerstest/target/debug/build/onnxruntime-sys-140fcf9a1cff8e0c/out/onnxruntime-linux-x64-1.4.0.tgz
--- stderr
thread 'main' panicked at 'ERROR: Failed to download https://github.com/microsoft/onnxruntime/releases/download/v1.4.0/onnxruntime-linux-x64-1.4.0.tgz: Response[status: 404, status_text: Not Found]', /home/bminixhofer/.cargo/git/checkouts/onnxruntime-rs-11c88f9192edf7c4/300d616/onnxruntime-sys/build.rs:86:9
note: run with `RUST_BACKTRACE=1` environment variable to display a backtrace
The same seems to be happening with version 0.0.8:
Cargo.toml
[package]
name = "onnxruntimerstest"
version = "0.1.0"
authors = ["Benjamin Minixhofer <bminixhofer@gmail.com>"]
edition = "2018"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
onnxruntime = "0.0.8"
Maybe it has something to do with onnxruntime v1.5.1 (released 6 days ago)?
Yeah I don't understand why this fails. I have the same issue in the branch's CI. Downloading the archive has been working for some time, but then I get a 404. It might be caused by the new release as you suggested, but it's weird.
You can try to download the archive manually and extract it somewhere. Then use ORT_ENV_STRATEGY=system and ORT_ENV_SYSTEM_LIB_LOCATION (see https://github.com/nbigaouette/onnxruntime-rs/blob/master/onnxruntime-sys/build.rs#L14-L23 ) to point to the extraction folder. I'll see if I can fix this problem.
Regarding the 404 failure, it was a bug in ureq, which is used to download the archives. See PR #24 and also algesten/ureq#179.