fabio-sim/LightGlue-ONNX

The output shape of lightglue's onnx model is dynamic. Does tensorrt support dynamic output?

weihaoysgs opened this issue · 4 comments

@fabio-sim Hi, I noticed that the new script you uploaded for running LightGlue inference with TensorRT hard-codes the output shapes:

# TODO: Still haven't figured out dynamic output shapes yet:
if binding == "matches0":
    shape = (512, 2)
elif binding == "mscores0":
    shape = (512,)

However, the ONNX model appears to be exported with dynamic output shapes, so I have two questions:

  1. What happens if the number of matches exceeds 512? The model you are using seems to be configured for at most 512 keypoints, so the count should not exceed that. In other words, should the output dimension simply be set as large as possible?

  2. Since a fixed shape is set at inference time, can the output be given a fixed shape directly when exporting the ONNX model? If so, how?

I ask because I ran into the following error when running inference through the TensorRT C++ API, and I suspect it is caused by the issue above.

[12/18/2023-23:29:12] [I] [TRT] Loaded engine size: 52 MiB
[12/18/2023-23:29:12] [I] [TRT] [MemUsageChange] TensorRT-managed allocation in engine deserialization: CPU +0, GPU +43, now: CPU 0, GPU 43 (MiB)
[12/18/2023-23:29:12] [I] [TRT] Loaded engine size: 3 MiB
[12/18/2023-23:29:12] [I] [TRT] [MemUsageChange] TensorRT-managed allocation in engine deserialization: CPU +0, GPU +2, now: CPU 0, GPU 45 (MiB)
[12/18/2023-23:29:12] [I] [TRT] [MemUsageChange] TensorRT-managed allocation in IExecutionContext creation: CPU +0, GPU +582, now: CPU 0, GPU 627 (MiB)
[12/18/2023-23:29:12] [I] [TRT] [MemUsageChange] TensorRT-managed allocation in IExecutionContext creation: CPU +0, GPU +582, now: CPU 0, GPU 627 (MiB)
[12/18/2023-23:29:13] [I] [TRT] [MS] Running engine with multi stream info
[12/18/2023-23:29:13] [I] [TRT] [MS] Number of aux streams is 2
[12/18/2023-23:29:13] [I] [TRT] [MS] Number of total worker streams is 3
[12/18/2023-23:29:13] [I] [TRT] [MS] The main stream provided by execute/enqueue calls is the first worker stream
[12/18/2023-23:29:13] [I] [TRT] [MemUsageChange] TensorRT-managed allocation in IExecutionContext creation: CPU +0, GPU +1137, now: CPU 0, GPU 1182 (MiB)
[12/18/2023-23:29:13] [E] [TRT] 1: [runner.cpp::executeMyelinGraph::715] Error Code 1: Myelin ([myelinGraphExecute] Called without resolved dynamic shapes.)

Looking forward to your reply!

I tried changing the export.py script as follows:

torch.onnx.export(
    lightglue,
    (kpts0, kpts1, desc0, desc1),
    lightglue_path,
    input_names=["kpts0", "kpts1", "desc0", "desc1"],
    output_names=["matches0", "mscores0"],
    opset_version=17,
    dynamic_axes={
        "kpts0": {1: "num_keypoints0"},
        "kpts1": {1: "num_keypoints1"},
        "desc0": {1: "num_keypoints0"},
        "desc1": {1: "num_keypoints1"},
        # "matches0": {0: "num_matches0"},
        # "mscores0": {0: "num_matches0"},
    },
)

The dynamic axes for "matches0" and "mscores0" are commented out, but the exported ONNX model still seems to have dynamic outputs. Is that expected?


Hi @weihaoysgs

I'm no expert at TensorRT, so I'm also still not sure how to make dynamic output shapes work there. However, I suspect that the following error is about a different thing.

[12/18/2023-23:29:13] [E] [TRT] 1: [runner.cpp::executeMyelinGraph::715] Error Code 1: Myelin ([myelinGraphExecute] Called without resolved dynamic shapes.)

At runtime, a shape still needs to be set for each input, e.g.:

LightGlue-ONNX/trt_infer.py

Lines 103 to 104 in bcf96b7

for name, shape in shapes.items():
    context.set_input_shape(name, tuple(shape))
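As a rough sketch of how the `shapes` dictionary in that loop could be built, the helper below derives the four input shapes from the keypoint counts. The helper names `lightglue_input_shapes` and `resolve_input_shapes` are hypothetical; the input names and the descriptor dimension of 256 are taken from the export call and the C++ snippet in this thread. `_RecordingContext` is only a stand-in so the example runs without TensorRT; with a real execution context the same `set_input_shape` calls would be issued.

```python
def lightglue_input_shapes(num_kpts0, num_kpts1, desc_dim=256):
    """Build concrete input shapes for one image pair (batch size 1).

    Hypothetical helper: names follow input_names in the export call.
    """
    return {
        "kpts0": (1, num_kpts0, 2),
        "kpts1": (1, num_kpts1, 2),
        "desc0": (1, num_kpts0, desc_dim),
        "desc1": (1, num_kpts1, desc_dim),
    }


def resolve_input_shapes(context, shapes):
    """Set every dynamic input shape on the context before execution."""
    for name, shape in shapes.items():
        context.set_input_shape(name, tuple(shape))


class _RecordingContext:
    """Stand-in for a TensorRT execution context (assumption, for illustration)."""

    def __init__(self):
        self.calls = []

    def set_input_shape(self, name, shape):
        self.calls.append((name, shape))
        return True


if __name__ == "__main__":
    ctx = _RecordingContext()
    resolve_input_shapes(ctx, lightglue_input_shapes(512, 300))
    for name, shape in ctx.calls:
        print(name, shape)
```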

Regarding the ONNX model, regardless of whether dynamic axes were specified during export or not, the output is still dynamic due to the filter_matches() function here:

def filter_matches(scores: torch.Tensor, th: float):
    """obtain matches from a log assignment matrix [BxMxN]"""
    max0 = torch.topk(scores, k=1, dim=2, sorted=False)  # scores.max(2)
    max1 = torch.topk(scores, k=1, dim=1, sorted=False)  # scores.max(1)
    m0, m1 = max0.indices[:, :, 0], max1.indices[:, 0, :]
    indices0 = torch.arange(m0.shape[1], device=m0.device)[None]
    # indices1 = torch.arange(m1.shape[1], device=m1.device)[None]
    mutual0 = indices0 == m1.gather(1, m0)
    # mutual1 = indices1 == m0.gather(1, m1)
    max0_exp = max0.values[:, :, 0].exp()
    zero = max0_exp.new_tensor(0)
    mscores0 = torch.where(mutual0, max0_exp, zero)
    # mscores1 = torch.where(mutual1, mscores0.gather(1, m1), zero)
    valid0 = mscores0 > th
    # valid1 = mutual1 & valid0.gather(1, m1)
    # m0 = torch.where(valid0, m0, -1)
    # m1 = torch.where(valid1, m1, -1)
    # return m0, m1, mscores0, mscores1
    m_indices_0 = indices0[valid0]
    m_indices_1 = m0[0][m_indices_0]
    matches = torch.stack([m_indices_0, m_indices_1], -1)
    mscores = mscores0[0][m_indices_0]
    return matches, mscores

One way to avoid this and have a computable (shape-dependent, but no longer data-dependent) output shape is to perform this filtering as post-processing outside the model, similar to #58.
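A minimal NumPy sketch of that post-processing, assuming the exported model returns the log assignment matrix `scores` that `filter_matches()` receives, for a single image pair (batch size 1, shape M x N). The function name `filter_matches_np` and the example values are illustrative and not part of the repository.

```python
import numpy as np


def filter_matches_np(scores, th):
    """Mutual-nearest-neighbour filtering of an (M, N) log assignment matrix.

    Illustrative stand-in for filter_matches(), run outside the model.
    """
    m0 = scores.argmax(axis=1)  # best column for each row, shape (M,)
    m1 = scores.argmax(axis=0)  # best row for each column, shape (N,)
    indices0 = np.arange(scores.shape[0])
    mutual0 = indices0 == m1[m0]  # keep row i only if its best column picks i back
    max0_exp = np.exp(scores[indices0, m0])
    mscores0 = np.where(mutual0, max0_exp, 0.0)
    valid0 = mscores0 > th
    # The number of rows below depends on the data, not only on M and N.
    matches = np.stack([indices0[valid0], m0[valid0]], axis=-1)
    return matches, mscores0[valid0]


if __name__ == "__main__":
    # Toy assignment probabilities for 3 keypoints in image 0 and 2 in image 1.
    probs = np.array([[0.9, 0.1], [0.2, 0.6], [0.8, 0.05]])
    matches, mscores = filter_matches_np(np.log(probs), th=0.1)
    print(matches, mscores)
```

With this split, the engine output has shape (M, N) for inputs with M and N keypoints, so its buffer can be sized from the input shapes alone. Only the NumPy step produces a data-dependent number of rows.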

@fabio-sim Hi, I have set the dynamic input shapes like this:

const int keypoints_0_index = mEngine->getBindingIndex(lgConfig.inputTensorNames[0].c_str());
const int keypoints_1_index = mEngine->getBindingIndex(lgConfig.inputTensorNames[1].c_str());
const int descriptors_0_index = mEngine->getBindingIndex(lgConfig.inputTensorNames[2].c_str());
const int descriptors_1_index = mEngine->getBindingIndex(lgConfig.inputTensorNames[3].c_str());

const int output_matcher0_index = mEngine->getBindingIndex(lgConfig.outputTensorNames[0].c_str());
const int output_score0_index = mEngine->getBindingIndex(lgConfig.outputTensorNames[1].c_str());

mContext->setBindingDimensions(keypoints_0_index, nvinfer1::Dims3(1, features0.cols(), 2));
mContext->setBindingDimensions(keypoints_1_index, nvinfer1::Dims3(1, features1.cols(), 2));
mContext->setBindingDimensions(descriptors_0_index, nvinfer1::Dims3(1, features0.cols(), 256));
mContext->setBindingDimensions(descriptors_1_index, nvinfer1::Dims3(1, features1.cols(), 256));

I will run more detailed tests. Thanks for your reply.

@fabio-sim Hi, thank you for your suggestion. I kept the post-processing in C++ rather than putting it into the ONNX model, and the error above disappeared. I will close this issue and open a new one if anything else comes up.