ankane/onnxruntime-ruby

When input is bool tensor

mib32 opened this issue · 5 comments

mib32 commented

I am trying to use it with a network that takes a boolean tensor as input, but something goes wrong.

In inference_session.rb:226 there is this code:

if tensor_type == :bool
  tensor_type = :uchar
  flat_input = flat_input.map { |v| v ? 1 : 0 }
end
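The branch above maps each Ruby boolean to 0 or 1 so that it fits in a one-byte slot. Below is a minimal stdlib-only sketch of that encoding. It assumes one byte per bool element, and bool_tensor_bytes is a hypothetical helper, not part of onnxruntime-ruby:

```ruby
# Hypothetical helper, not part of onnxruntime-ruby: flattens a nested array
# of booleans and encodes each element as a single byte (1 for truthy, 0 otherwise).
def bool_tensor_bytes(input)
  flat_input = input.flatten
  # Array#pack("C*") writes each integer as an unsigned 8-bit value.
  flat_input.map { |v| v ? 1 : 0 }.pack("C*")
end

bytes = bool_tensor_bytes([[true, false], [false, true]])
p bytes.bytes    # => [1, 0, 0, 1]
p bytes.bytesize # => 4
```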

So it detects the 'bool' type from the ONNX model, which means the model is designed to accept bools, and then it switches the type to uchar.
What happens next for me is that inference fails with OnnxRuntime::Error: type 17 is not supported in this function, which, as I understand it, kind of makes sense.
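Switching to uchar leaves the payload unchanged: a bool tensor and a uint8 tensor of 0/1 values share the same bytes. What differs is the element-type code attached to them. The sketch below uses the ONNX enum values UINT8 = 2 and BOOL = 9. It assumes ONNX Runtime checks each input's element type against the type the model declares. The Tensor struct is hypothetical, and the sketch does not attempt to explain the exact type 17 error above:

```ruby
# Element type codes from the ONNX TensorProto.DataType enum, which ONNX
# Runtime's ONNXTensorElementDataType mirrors. Only the two used here are listed.
ELEMENT_TYPES = { uint8: 2, bool: 9 }.freeze

# Hypothetical description of an input tensor: raw bytes plus a type tag.
Tensor = Struct.new(:bytes, :element_type)

bytes = [1, 0, 1].pack("C*") # true, false, true as one byte each

as_uchar = Tensor.new(bytes, ELEMENT_TYPES[:uint8])
as_bool  = Tensor.new(bytes, ELEMENT_TYPES[:bool])

# Same payload, different declared type; under the assumption above, a model
# input declared as bool would only accept the second one.
p as_uchar.bytes == as_bool.bytes                # => true
p [as_uchar.element_type, as_bool.element_type] # => [2, 9]
```

If that assumption holds, one option is to keep the 0/1 bytes and declare the element type as bool. This would satisfy a bool input without Ruby needing a native fixed-size boolean.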

One workaround, I guess, would be to make the ONNX model accept the tensor as uchar and convert it back to bool inside its forward function. But for some reason that gives me weird and inconsistent Gather errors. Even worse, it requires changing the model's architecture, so in the end I can't use it with models that were trained before the architecture change.

Another thing I tried was this (adapted from FFI::Pointer#write_array_of_type):

if tensor_type == :bool
    size = ::FFI.type_size(::FFI::TYPE_BOOL)
    flat_input.each_with_index { |val, i|
      break unless i < input_tensor_values.size
      input_tensor_values.write(::FFI::TYPE_BOOL, val)
    }
end

But that totally doesn't work.
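One detail worth checking in the snippet above: as far as I know, FFI's Pointer#write(type, value) is equivalent to put(type, 0, value). That means the loop writes every value at offset 0, while the write_array_of_type loop it was adapted from passes i * size as the offset. The stdlib-only sketch below shows the difference. It uses a binary String as a stand-in for the FFI::MemoryPointer buffer and assumes one byte per bool. write_bools and write_bools_at_zero are hypothetical names:

```ruby
BOOL_SIZE = 1 # assumption: each bool element occupies one byte

# Hypothetical helper mirroring the loop in FFI::Pointer#write_array_of_type:
# element i goes to byte offset i * BOOL_SIZE of the buffer.
def write_bools(buffer, values)
  values.each_with_index do |val, i|
    break unless i * BOOL_SIZE < buffer.bytesize # stop at the end of the buffer
    buffer.setbyte(i * BOOL_SIZE, val ? 1 : 0)
  end
  buffer
end

# For contrast: writing every value at offset 0 only ever touches the first byte.
def write_bools_at_zero(buffer, values)
  values.each { |val| buffer.setbyte(0, val ? 1 : 0) }
  buffer
end

values = [true, false, true, true]
p write_bools("\0".b * 4, values).bytes         # => [1, 0, 1, 1]
p write_bools_at_zero("\0".b * 4, values).bytes # => [1, 0, 0, 0]
```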

I would love to hear whether you have any experience with this. What I don't understand is why it isn't possible to just send the array of bools natively. Why did you need the if tensor_type == :bool special case that converts them to bytes?
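Part of the answer is on the Ruby side. true and false are objects, not integers, so they cannot be packed into a byte buffer directly. They need an explicit mapping to 1/0 first. The small stdlib-only check below illustrates Ruby's behavior only. It is not a statement of the gem's design rationale:

```ruby
values = [true, false, true]

# Array#pack has no boolean directive, and true/false do not convert to Integer,
# so packing them as unsigned bytes raises TypeError.
begin
  values.pack("C*")
rescue TypeError => e
  puts "pack failed: #{e.message}"
end

# Mapping to 0/1 first gives one byte per element.
bytes = values.map { |v| v ? 1 : 0 }.pack("C*")
p bytes.bytes # => [1, 0, 1]
```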

mib32 commented

I also asked about this on the FFI side: ffi/ffi#867

mib32 commented

Sorry, my fault. It actually does work okay using:

if tensor_type == :bool
    size = ::FFI.type_size(::FFI::TYPE_BOOL)
    flat_input.each_with_index { |val, i|
      break unless i < input_tensor_values.size
      input_tensor_values.write(::FFI::TYPE_BOOL, val)
    }
end

I recommend making this the default for the bool type instead of the current approach. I can send the PR.

ankane commented

Hey @mib32, I fixed an error and added tests for the bool type, but feel free to send a PR if there's still an issue or a better approach.
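A test for the bool path can check that values survive a round trip through the byte encoding. The stdlib-only sketch below uses the same one-byte-per-element assumption. encode_bools and decode_bools are hypothetical names, and this is not the gem's actual test suite:

```ruby
# Hypothetical helpers, not part of onnxruntime-ruby.
def encode_bools(values)
  values.map { |v| v ? 1 : 0 }.pack("C*")
end

# Any nonzero byte is read back as true; shape is assumed to be [rows, cols].
def decode_bools(bytes, shape)
  rows, cols = shape
  flat = bytes.unpack("C*").map { |b| b != 0 }
  flat.each_slice(cols).first(rows)
end

input = [[true, false, true], [false, false, true]]
output = decode_bools(encode_bools(input.flatten), [2, 3])
p output == input # => true
```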

mib32 commented

@ankane Thanks, it's perfect now 👏

ankane commented

Great, just pushed out a new release.