nbigaouette/onnxruntime-rs

Model with dynamic dimensions

bminixhofer opened this issue · 9 comments

Hi, thanks for this library!

I've been trying to run a model with dynamic input dimensions, but it fails with a NonMatchingDimensions error (model here).

Here's how I'd use the model from the onnxruntime Python bindings:

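import numpy as np
import onnxruntime  # v1.2.0

session = onnxruntime.InferenceSession("model.onnx")
# Both inputs are int64 tensors of shape (batch, seq); here batch = 1, seq = 3.
input_ids = np.array([[1, 2, 3]], dtype=np.int64)
attention_mask = np.array([[1, 1, 1]], dtype=np.int64)
outputs = session.run(None, {"input_ids": input_ids, "attention_mask": attention_mask})[0]
print(outputs.shape)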

Using your Rust bindings:

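use ndarray::Array2;
use onnxruntime::{environment::Environment, tensor::OrtOwnedTensor, GraphOptimizationLevel};

let env = Environment::builder().with_name("env").build()?;
// `run` takes `&mut self`, so the session has to be declared `mut`.
let mut session = env
            .new_session_builder()?
            .with_optimization_level(GraphOptimizationLevel::Basic)?
            .with_model_from_file("model.onnx")?;

println!("{:#?}", session.inputs);
println!("{:#?}", session.outputs);

// The model declares Int64 inputs, but f32 arrays are passed because `run`
// requires the input and output element types to match (the output is Float).
let input_ids = Array2::<f32>::from_shape_vec((1, 3), vec![1f32, 2f32, 3f32])?;
let attention_mask = Array2::<f32>::from_shape_vec((1, 3), vec![1f32, 1f32, 1f32])?;

// Fails with NonMatchingDimensions: the dynamic model dimensions show up as 4294967295.
let outputs: Vec<OrtOwnedTensor<f32, _>> = session.run(vec![input_ids, attention_mask])?;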

This prints:

[
    Input {
        name: "input_ids",
        input_type: Int64,
        dimensions: [
            4294967295,
            4294967295,
        ],
    },
    Input {
        name: "attention_mask",
        input_type: Int64,
        dimensions: [
            4294967295,
            4294967295,
        ],
    },
]
[
    Output {
        name: "output",
        output_type: Float,
        dimensions: [
            4294967295,
            4294967295,
            94,
        ],
    },
]
thread 'main' panicked at 'called `Result::unwrap()` on an `Err` value: NonMatchingDimensions { input: [1, 3], model: [4294967295, 4294967295] }'

Apparently the dynamic dimensions lead to an integer overflow: they are encoded as -1 in ONNX (IIRC), and -1 reinterpreted as an unsigned 32-bit integer is 4294967295 (u32::MAX).
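To make the overflow concrete, and to show one way around it, here is a minimal sketch. It assumes dimensions are read as i64 (the ONNX Runtime C API reports them as int64_t) and treats any negative value as "any size". `Dim`, `naive_cast`, `parse_dims` and `dims_match` are hypothetical names, not part of onnxruntime-rs.

```rust
/// One model dimension: `None` = dynamic (any size), `Some(n)` = fixed size n.
pub type Dim = Option<usize>;

/// What a plain cast does to the -1 sentinel: it wraps to u32::MAX (4294967295).
pub fn naive_cast(d: i64) -> u32 {
    d as u32
}

/// Convert raw i64 dimensions into `Dim`s, mapping any negative value
/// (the -1 sentinel) to `None` instead of casting it to an unsigned type.
pub fn parse_dims(raw: &[i64]) -> Vec<Dim> {
    raw.iter()
        .map(|&d| if d < 0 { None } else { Some(d as usize) })
        .collect()
}

/// Ranks must agree, and every fixed model dimension must equal the
/// corresponding input dimension; dynamic ones accept any size.
pub fn dims_match(input: &[usize], model: &[Dim]) -> bool {
    input.len() == model.len()
        && input
            .iter()
            .zip(model.iter())
            .all(|(&i, &m)| m.map_or(true, |fixed| fixed == i))
}
```

With this, the `[1, 3]` input from the example matches the model's `[-1, -1]` dimensions instead of being compared against `[4294967295, 4294967295]`.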
I'm also a bit skeptical about the constraint that .run must have the same output type as input type: does that handle models with int64 inputs and float outputs correctly?
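One way such a constraint could be lifted is to give inputs and outputs separate type parameters. A minimal sketch under that assumption; `Infer` and `ToyModel` are hypothetical, and the "model" is a trivial stand-in rather than ONNX inference.

```rust
/// Hypothetical inference interface with independent input (`TIn`) and
/// output (`TOut`) element types; not the onnxruntime-rs API.
pub trait Infer<TIn, TOut> {
    fn run(&mut self, inputs: Vec<Vec<TIn>>) -> Vec<Vec<TOut>>;
}

/// Toy stand-in for an int64-in / float-out model. It expects
/// `[input_ids, attention_mask]` and, per position, emits the token id
/// as f32 where the mask is 1 and 0.0 elsewhere.
pub struct ToyModel;

impl Infer<i64, f32> for ToyModel {
    fn run(&mut self, inputs: Vec<Vec<i64>>) -> Vec<Vec<f32>> {
        let (ids, mask) = (&inputs[0], &inputs[1]);
        let out: Vec<f32> = ids
            .iter()
            .zip(mask.iter())
            .map(|(&id, &m)| if m == 1 { id as f32 } else { 0.0 })
            .collect();
        vec![out]
    }
}
```

Because `TIn` and `TOut` are separate, the i64-in / f32-out combination is expressible without converting the inputs to f32 first.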

I appreciate any help!

Thanks for trying this out! Sorry for the delay; I moved recently and don't have as much free time as I would like.

I tried it and see the same problem. I'll see if I can do something about it.

Thanks, no problem. I'm using the onnxruntime Python bindings with PyO3 for now so this isn't blocking for me :)

I am taking a look at this and I have some questions.

If I print the inputs from python:

import onnxruntime

session = onnxruntime.InferenceSession("model.onnx")
for idx, inp in enumerate(session.get_inputs()):
    print("idx:", idx)
    print("   Name:", inp.name)
    print("   Shape:", inp.shape)

I get this:

idx: 0
   Name: input_ids
   Shape: ['batch', 'seq']
idx: 1
   Name: attention_mask
   Shape: ['batch', 'seq']

I guess I can assume that the Python lib interprets the -1 in the ONNX file as the string batch for the first dimension and the string seq for the second? If I open the ONNX file in Netron, I see the same strings, batch and seq. Is that a convention? Where do those strings come from? It's kind of hard to search for these strings :D
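For context: per onnx.proto, each dimension of a tensor shape holds either a fixed `dim_value` or a symbolic `dim_param` name. Strings like "batch" and "seq" are `dim_param` values stored in the model file by whoever created it, which is why Netron shows them too, while numeric-only views (like the dimensions printed by the Rust bindings) collapse them to -1. A hand-written sketch of that distinction; `Dimension`, `python_style` and `numeric` are hypothetical names, not generated protobuf types.

```rust
/// Hand-written mirror of `TensorShapeProto.Dimension` in onnx.proto, whose
/// value is `oneof { int64 dim_value; string dim_param; }`.
/// Illustration only, not generated protobuf code.
#[derive(Debug, Clone, PartialEq)]
pub enum Dimension {
    /// Fixed size (`dim_value`), e.g. the 94 in the output shape.
    Value(i64),
    /// Symbolic name (`dim_param`), e.g. "batch" or "seq".
    Param(String),
}

/// Shape entry as the Python bindings print it: a number for a fixed
/// size, the quoted name for a symbolic dimension.
pub fn python_style(d: &Dimension) -> String {
    match d {
        Dimension::Value(v) => v.to_string(),
        Dimension::Param(name) => format!("'{}'", name),
    }
}

/// Numeric-only view: the symbolic name is dropped and the dimension
/// collapses to the -1 sentinel.
pub fn numeric(d: &Dimension) -> i64 {
    match d {
        Dimension::Value(v) => *v,
        Dimension::Param(_) => -1,
    }
}
```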

Investigating this, I found a bug related to the inputs.

Multiple inputs, like in your example, should be fed to the model together in a single inference call, not treated as a list to loop over with one inference per element.
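A toy sketch of the difference, assuming named inputs like the dict passed to `session.run` in the Python example. `Feeds` and `run_once` are hypothetical, and the "model" just counts unmasked tokens: feeding both tensors at once succeeds, while running once per input leaves every call missing a tensor.

```rust
use std::collections::HashMap;

/// Hypothetical feed: input name -> flattened int64 tensor.
pub type Feeds<'a> = HashMap<&'a str, Vec<i64>>;

/// Toy model standing in for real inference: it needs *both* named inputs
/// in the same call and returns the number of unmasked tokens.
pub fn run_once(feeds: &Feeds<'_>) -> Result<usize, String> {
    let ids = feeds.get("input_ids").ok_or("missing input_ids")?;
    let mask = feeds.get("attention_mask").ok_or("missing attention_mask")?;
    if ids.len() != mask.len() {
        return Err("input_ids and attention_mask lengths differ".to_string());
    }
    Ok(mask.iter().filter(|&&m| m == 1).count())
}
```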

So this loop is wrong: https://github.com/nbigaouette/onnxruntime-rs/blob/751d996/onnxruntime/src/session.rs#L355-L404

Thanks!! :D

Can you take a look at branch issue22 (PR #23)?

It needs cleanup, but it can load your model and seems to perform inference with it correctly. Feel free to comment there!

Hi, sorry, I was a bit busy.

I assume the questions from your first comment are resolved? If not, there's more info here and in the onnx.proto file, but I don't know where exactly the names are stored either.

Thanks for the fix, I'll have a look!

I ran into another issue trying to set up the branch:

Cargo.toml

[package]
name = "onnxruntimerstest"
version = "0.1.0"
authors = ["Benjamin Minixhofer <bminixhofer@gmail.com>"]
edition = "2018"

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[dependencies]
onnxruntime = { git = "https://github.com/nbigaouette/onnxruntime-rs", branch = "issue22" }

cargo build

   Compiling onnxruntime-sys v0.0.8 (https://github.com/nbigaouette/onnxruntime-rs?branch=issue22#300d616d)
error: failed to run custom build command for `onnxruntime-sys v0.0.8 (https://github.com/nbigaouette/onnxruntime-rs?branch=issue22#300d616d)`

Caused by:
  process didn't exit successfully: `/home/bminixhofer/Documents/Experiments/onnxruntimerstest/target/debug/build/onnxruntime-sys-c85e57555e099ab4/build-script-build` (exit code: 101)
  --- stdout
  strategy: "unknown"
  Creating directory "/home/bminixhofer/Documents/Experiments/onnxruntimerstest/target/debug/build/onnxruntime-sys-140fcf9a1cff8e0c/out"
  Downloading https://github.com/microsoft/onnxruntime/releases/download/v1.4.0/onnxruntime-linux-x64-1.4.0.tgz into /home/bminixhofer/Documents/Experiments/onnxruntimerstest/target/debug/build/onnxruntime-sys-140fcf9a1cff8e0c/out/onnxruntime-linux-x64-1.4.0.tgz

  --- stderr
  thread 'main' panicked at 'ERROR: Failed to download https://github.com/microsoft/onnxruntime/releases/download/v1.4.0/onnxruntime-linux-x64-1.4.0.tgz: Response[status: 404, status_text: Not Found]', /home/bminixhofer/.cargo/git/checkouts/onnxruntime-rs-11c88f9192edf7c4/300d616/onnxruntime-sys/build.rs:86:9
  note: run with `RUST_BACKTRACE=1` environment variable to display a backtrace
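The archive URL in the log follows a fixed release-naming pattern. A tiny sketch that rebuilds it, assuming the pattern holds for other versions and platforms (only `linux-x64` and `1.4.0` are confirmed by the log); `release_url` is a hypothetical helper. Fetching the resulting URL with a browser or another tool is a quick way to tell whether the archive itself is missing or the download code is at fault.

```rust
/// Build the release-archive URL using the pattern visible in the log:
/// .../releases/download/v{version}/onnxruntime-{platform}-{version}.tgz
/// Hypothetical helper; not taken from onnxruntime-sys.
pub fn release_url(platform: &str, version: &str) -> String {
    format!(
        "https://github.com/microsoft/onnxruntime/releases/download/v{v}/onnxruntime-{p}-{v}.tgz",
        v = version,
        p = platform
    )
}
```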

The same seems to be happening with version 0.0.8:

Cargo.toml

[package]
name = "onnxruntimerstest"
version = "0.1.0"
authors = ["Benjamin Minixhofer <bminixhofer@gmail.com>"]
edition = "2018"

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[dependencies]
onnxruntime = "0.0.8"

Maybe it has something to do with onnxruntime v1.5.1 (released 6 days ago)?

Yeah, I don't understand why this fails. I see the same issue in the branch's CI. Downloading the archive had been working for some time, but now I get a 404. It might be caused by the new release, as you suggested, but it's weird.

You can try to download the archive manually and extract it somewhere. Then set ORT_ENV_STRATEGY=system and point ORT_ENV_SYSTEM_LIB_LOCATION to the extraction folder (see https://github.com/nbigaouette/onnxruntime-rs/blob/master/onnxruntime-sys/build.rs#L14-L23). I'll see if I can fix this problem.
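On the linking side, that workaround boils down to two standard Cargo build-script directives. A minimal sketch, assuming the extracted archive keeps its shared library under `lib/`; `link_directives` is hypothetical and is not the actual onnxruntime-sys build.rs (see the linked lines for that). In a real build script you would pass `std::env::var("ORT_ENV_STRATEGY").ok().as_deref()` (and likewise for the location) and `println!` each returned line; when building, both variables are simply exported before running cargo build.

```rust
use std::path::{Path, PathBuf};

/// Hypothetical helper: given the values of ORT_ENV_STRATEGY and
/// ORT_ENV_SYSTEM_LIB_LOCATION, return the Cargo directives needed to link
/// against a manually extracted archive. Assumes a `lib/` subdirectory.
pub fn link_directives(
    strategy: Option<&str>,
    location: Option<&str>,
) -> Result<Vec<String>, String> {
    match strategy {
        Some("system") => {
            let root = location
                .ok_or("ORT_ENV_SYSTEM_LIB_LOCATION must be set when ORT_ENV_STRATEGY=system")?;
            let lib_dir: PathBuf = Path::new(root).join("lib");
            Ok(vec![
                format!("cargo:rustc-link-search=native={}", lib_dir.display()),
                "cargo:rustc-link-lib=onnxruntime".to_string(),
            ])
        }
        // Any other strategy (e.g. the default download) is out of scope here.
        other => Err(format!("strategy {:?} not handled in this sketch", other)),
    }
}
```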

The 404 failure turned out to be a bug in ureq, the HTTP client used to download the archives. See PR #24, as well as algesten/ureq#179.