error when using TensorrtExecutionProvider for ONNX models
noahzn opened this issue · 18 comments
Hi, I have another issue. I first optimized only the matcher (lightglue.onnx) and then ran inference with --trt, but got the following error. When I exported the models to ONNX, I didn't use --dynamic.
[E:onnxruntime:, inference_session.cc:1785 operator()] Exception during initialization: /onnxruntime_src/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc:1552 SubGraphCollection_t onnxruntime::TensorrtExecutionProvider::GetSupportedList(SubGraphCollection_t, int, int, const onnxruntime::GraphViewer&, bool*) const [ONNXRuntimeError] : 1 : FAIL : TensorRT input: /NonZero_output_0 has no shape specified. Please run shape inference on the onnx model first. Details can be found in https://onnxruntime.ai/docs/execution-providers/TensorRT-ExecutionProvider.html#shape-inference-for-tensorrt-subgraphs
Do you have any ideas about this? Thank you in advance.
I solved the problem. Now I run shape inference on both the superpoint and lightglue ONNX models (the optimized one). But now I have a new problem:
onnxruntime.capi.onnxruntime_pybind11_state.RuntimeException: [ONNXRuntimeError] : 6 : RUNTIME_EXCEPTION : Exception during initialization: /onnxruntime_src/onnxruntime/core/optimizer/transformer_memcpy.cc:374 bool onnxruntime::TransformerMemcpyImpl::ProcessInitializers(const onnxruntime::KernelRegistryManager&, const InitializedTensorSet&) status.IsOK() was false. Failed to find kernel for Identity(16) (node Identity_440). Kernel not found
@fabio-sim Hi, since I am using the open-superpoint, I think the problem comes from here:
```python
class VGGBlock(nn.Sequential):
    def __init__(self, c_in, c_out, kernel_size, relu=True):
        padding = (kernel_size - 1) // 2
        conv = nn.Conv2d(
            c_in, c_out, kernel_size=kernel_size, stride=1, padding=padding
        )
        activation = nn.ReLU(inplace=True) if relu else nn.Identity()
        bn = nn.BatchNorm2d(c_out, eps=0.001)
        super().__init__(
            OrderedDict(
                [
                    ("conv", conv),
                    ("activation", activation),
                    ("bn", bn),
                ]
            )
        )
```
[2024-06-21 07:47:33 ERROR] If_302_OutputLayer: IIfConditionalOutputLayer inputs must have the same shape. Shapes are [-1,2] and [128].
Do you know how to fix this?
Hi @noahzn, thank you for your interest in LightGlue-ONNX.
At a glance, other than the shape inference problem, I'm not sure what's causing the error. Are you using a newer version of TensorRT?
My TensorRT version is 8.6.1. I don't know what causes the shape issue, but regarding the Failed to find kernel for Identity(16) issue, I think it's caused by the nn.Identity used in open-superpoint. In your optimize.py code for lightglue, the nn.Identity layer is removed. So maybe I also need such an optimize.py file to process the ONNX model of open-superpoint?
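Something like this rough sketch, maybe? (This uses onnx-simplifier (onnxsim) as a stand-in to fold away Identity and other no-op nodes; the file names are only examples, not an actual optimize.py for open-superpoint.)

```python
# Hypothetical cleanup pass: simplify the exported open-superpoint ONNX
# graph (folding Identity and other no-op nodes) before running shape
# inference and handing it to the TensorRT EP.
# Assumes onnx and onnx-simplifier (onnxsim) are installed.
import onnx
from onnxsim import simplify

model = onnx.load("superpoint_open.onnx")        # example input path
simplified, ok = simplify(model)
assert ok, "onnx-simplifier could not validate the simplified model"
onnx.save(simplified, "superpoint_open.simplified.onnx")
```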
I'd like to clarify something first so that it doesn't turn out to be a waste of time for you: Are you trying to run LightGlue matching using SuperPoint-Open as the extractor? Because currently LightGlue weights (the SuperPoint one) are not compatible with SuperPoint-Open features.
Yes, I'm using the SuperPoint-Open you shared with me last time. I have my own trained weights for both networks, and they work well when using CUDAExecutionProvider. Now I want to use the TensorrtExecutionProvider.
I see. The code you linked:

LightGlue-ONNX/lightglue_onnx/superpoint_open.py, lines 84 to 100 in fc1d67a

is only related to the model's creation and shouldn't be a factor during its forward pass. If I recall correctly, there shouldn't even be any If ONNX operators present in the model in the first place; those are what cause the TensorRT EP to convert to an IIfConditionalOutputLayer.

As for Failed to find kernel for Identity(16), I've never encountered this before. Does inference run fine on CUDA EP?
Thank you for your quick reply. Yes, they work well on CUDA EP.
For the If error, the only place with a conditional is the select-top-k keypoints function, so I think that's the cause. Did you set a value for max_num_keypoints during ONNX export/tracing?
Yes, I set max_num_keypoints to 128. Should I not set it?
Ah, I see the problem now. I think if you remove this line, the trace will always go through the TopK branch, and there won't be any conditional/If operators traced. TensorRT cannot handle data-dependent dynamic shapes.
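To illustrate the pattern (a hypothetical sketch, not the repository's exact code): a Python-level check against a tensor shape gets frozen into the trace or exported as an If branch, whereas unconditionally calling topk leaves no data-dependent control flow in the graph.

```python
import torch

def top_keypoints_with_guard(scores: torch.Tensor, k: int) -> torch.Tensor:
    # Python-level comparison against a tensor shape: tracing freezes this
    # decision (hence the TracerWarning), and exporting it as an If branch
    # produces output shapes TensorRT cannot reconcile.
    if k >= scores.shape[0]:
        return scores
    return torch.topk(scores, k).values

def top_keypoints_topk_only(scores: torch.Tensor, k: int) -> torch.Tensor:
    # Always take the TopK path, so no conditional is recorded in the graph.
    return torch.topk(scores, k).values
```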
Amazing! Although there was a warning when I exported the model: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. if k >= keypoints.shape[0]

Then, when I still use infer.py for inference, I get an error just like in this issue:
3: getPluginCreator could not find plugin: MultiHeadAttention version: 1
Should I use trt_infer.py (lines 1 to 122 in bcf96b7)?
Oh, to solve could not find MultiHeadAttention, don't run optimize.py. That fusion is specific to CUDA EP only (a com.microsoft contrib operator).
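For reference, a minimal sketch of creating a session on the unfused model with the TensorRT EP preferred (the provider options and path here are examples, not the repository's actual infer script):

```python
# Load the *unfused* model and let ONNX Runtime try TensorRT first,
# falling back to CUDA and then CPU if TensorRT cannot build an engine.
import onnxruntime as ort

providers = [
    ("TensorrtExecutionProvider", {"trt_engine_cache_enable": True}),
    "CUDAExecutionProvider",
    "CPUExecutionProvider",
]
session = ort.InferenceSession("lightglue.onnx", providers=providers)
```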
Now it works without error! But there are no points on the second image; I guess that's caused by my modifying the network code today. I will check tomorrow. I will close this issue since it has been solved. Thank you so much!!
BTW, if I now have a new network and I want to convert it to ONNX and run the TRT EP, how do I know which parts of the network I should modify to adapt them for ONNX conversion? Are there any tutorials I should learn from? Thank you.
You're welcome, @noahzn. Happy I could help.
I guess the closest thing would be PyTorch's tutorials in their docs, but in my experience most tutorials assume that the model to be converted can be converted.
The process of adapting a model so that it can be converted is rarely easy and usually requires customization.
Yes, I have no experience in customizing layers or operators; I have only tried converting simple models in which all the operations are already supported by ONNX. Thank you again!
[E:onnxruntime:, inference_session.cc:1785 operator()] Exception during initialization: /onnxruntime_src/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc:1552 SubGraphCollection_t onnxruntime::TensorrtExecutionProvider::GetSupportedList(SubGraphCollection_t, int, int, const onnxruntime::GraphViewer&, bool*) const [ONNXRuntimeError] : 1 : FAIL : TensorRT input: /NonZero_output_0 has no shape specified. Please run shape inference on the onnx model first. Details can be found in https://onnxruntime.ai/docs/execution-providers/TensorRT-ExecutionProvider.html#shape-inference-for-tensorrt-subgraphs
how to fix this? Thank you in advance.
@noahzn
@dashirenyu You need to run symbolic_shape_infer.py on the ONNX model so that every tensor has the shape information the TensorRT EP needs. Check the link in the error message you posted.
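For reference, a minimal sketch of running symbolic shape inference from Python (the paths are examples, and this assumes onnxruntime ships the symbolic_shape_infer tool under onnxruntime.tools; the script can also be run from the command line as described in the ONNX Runtime TensorRT EP docs):

```python
# Run ONNX Runtime's symbolic shape inference so every intermediate tensor
# (e.g. /NonZero_output_0) has a shape before the TensorRT EP partitions
# the graph. Assumes onnx and onnxruntime are installed.
import onnx
from onnxruntime.tools.symbolic_shape_infer import SymbolicShapeInference

model = onnx.load("superpoint.onnx")                   # example input path
inferred = SymbolicShapeInference.infer_shapes(model, auto_merge=True)
onnx.save(inferred, "superpoint.shape_inferred.onnx")
```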