grimoire/mmdetection-to-tensorrt

Wrong detection boxes with FCOS detector

guillaume-michel opened this issue · 8 comments

Describe the bug
When using demo/inference.py to perform inference with FCOS detector, the resulting detection boxes all have (x1, y1) coordinates equal to (0, 0). This leads to the following detections:
[image: screenshot of the resulting detections]

After investigating, we found that the bboxes computed here: https://github.com/grimoire/mmdetection-to-tensorrt/blob/master/mmdet2trt/models/dense_heads/fcos_head.py#L62 are correct (the x1 and y1 coordinates are usually nonzero), but after the end of the loop here: https://github.com/grimoire/mmdetection-to-tensorrt/blob/master/mmdet2trt/models/dense_heads/fcos_head.py#L67, the (x1, y1) coordinates are all zero except for the last scale.

Can you please have a look?
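
One way to quantify the symptom described above is to count how many detections have their top-left corner at the origin. The sketch below is only an illustration: the helper name `zero_origin_fraction` and the row layout `[x1, y1, x2, y2, ...]` are assumptions, not part of mmdet2trt, and the sample boxes are made up. It could be run on the boxes from each scale to see which ones are affected.

```python
from typing import Sequence


def zero_origin_fraction(boxes: Sequence[Sequence[float]], eps: float = 1e-6) -> float:
    """Return the fraction of boxes whose (x1, y1) corner is (0, 0).

    Assumes each row starts with x1, y1, x2, y2 (hypothetical layout).
    """
    if not boxes:
        return 0.0
    hits = sum(1 for b in boxes if abs(b[0]) <= eps and abs(b[1]) <= eps)
    return hits / len(boxes)


# Made-up boxes for illustration only.
healthy = [[12.0, 30.5, 80.0, 120.0], [0.0, 45.0, 60.0, 90.0]]
broken = [[0.0, 0.0, 80.0, 120.0], [0.0, 0.0, 60.0, 90.0]]

print(zero_origin_fraction(healthy))
print(zero_origin_fraction(broken))
```

A fraction close to 1 on every scale but the last would match the behaviour reported here.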

To Reproduce
python3 demo/inference.py base_path/demo/demo.jpg base_path/configs/fcos/fcos_r50_caffe_fpn_gn-head_1x_coco.py base_path/checkpoints/fcos_r50_caffe_fpn_gn-head_1x_coco-821213aa.pth /dev/shm/fcos.engine

Environment:

  • OS: Ubuntu 20.04
  • python_version: 3.8.10
  • pytorch_version: 1.9.0+cu111
  • cuda_version: cuda-11.3
  • cudnn_version: 8.2.4.15
  • mmdetection_version: fe396572839f30351aa862d2b2702f57698cefea
  • TensorRT: 8.0.3

Additional context
Executed on RTX 3060 for Laptop

Hi, I have tested on a 1660 Ti and a 2070S, with both PyTorch 1.8 and 1.9 and TensorRT 8.0.3, and the code works on my side.
Can you provide a Dockerfile (you can start from NGC) so I can reproduce the error?

I have the same problem: #74
Reconverting can help, but it is not a satisfactory solution. One model still had the problem even after repeated reconversion.

Hi,
Thanks for your reply.

Attached are two Dockerfiles: one where the bug is present and one where the bug is NOT present. The only differences are the CUDA, PyTorch and TensorRT versions. They are defined at the TOP of each Dockerfile.

mmdet2trt_not_working.zip
mmdet2trt_working.zip

For reference, the versions that WORK are the following:

  • Ubuntu: 20.04
  • CUDA: 11.0.3
  • Pytorch: 1.7.1
  • torchvision: 0.8.2
  • TensorRT: 8.0.0.1

Best regards!

@grimoire Hi, can you please tell me which CUDA version you are using with PyTorch 1.8.0/1.9.0 and TensorRT 8.0.3? Is it CUDA 11.1?

I can reproduce the error now, after updating the TensorRT version. I will try to fix it ASAP.
You can downgrade to TensorRT 7.2 for now.

Let me know if I can help

Hi,
I think the error is caused by clamping with a shape tensor.
I have a fix in this branch of torch2trt_dynamic. Please try it and see if it helps.
I will merge it after more tests.
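
For readers unfamiliar with the step under discussion, here is a minimal plain-Python sketch of FCOS-style box decoding followed by clamping to the image bounds. It only illustrates what the clamp is meant to do; it is not the code in mmdet2trt or torch2trt_dynamic, and the names `clamp` and `distance_to_bbox` as well as the sample numbers are assumptions chosen for illustration.

```python
from typing import Optional, Sequence, Tuple


def clamp(value: float, low: float, high: float) -> float:
    """Limit value to the closed interval [low, high]."""
    return max(low, min(value, high))


def distance_to_bbox(
    point: Sequence[float],
    distances: Sequence[float],
    max_shape: Optional[Tuple[int, int]] = None,
) -> Tuple[float, float, float, float]:
    """Decode (left, top, right, bottom) distances around a point into a box.

    max_shape is (height, width); when given, the box is clamped to the image.
    """
    px, py = point
    left, top, right, bottom = distances
    x1, y1, x2, y2 = px - left, py - top, px + right, py + bottom
    if max_shape is not None:
        height, width = max_shape
        # Only coordinates that fall outside the image should change here.
        x1 = clamp(x1, 0.0, width)
        y1 = clamp(y1, 0.0, height)
        x2 = clamp(x2, 0.0, width)
        y2 = clamp(y2, 0.0, height)
    return (x1, y1, x2, y2)


# Made-up example: x1 would be -20 before clamping, so it is clipped to 0.
box = distance_to_bbox((100.0, 50.0), (120.0, 20.0, 30.0, 40.0), max_shape=(200, 300))
print(box)  # (0.0, 30.0, 130.0, 90.0)
```

With a correct clamp, only coordinates outside the image are altered, so a result in which every (x1, y1) is zero suggests the bounds were not applied as intended.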

It fixes the problem! Thank you very much!
I now have warnings you may be interested in:

[TensorRT] WARNING: IElementWiseLayer with inputs (Unnamed Layer* 1143) [ElementWise]_output and (Unnamed Layer* 1147) [Shuffle]_output: first input has type Float but second input has type Int32.
[TensorRT] WARNING: IElementWiseLayer with inputs (Unnamed Layer* 1150) [ElementWise]_output and (Unnamed Layer* 1154) [Shuffle]_output: first input has type Float but second input has type Int32.
[TensorRT] WARNING: IElementWiseLayer with inputs (Unnamed Layer* 1157) [ElementWise]_output and (Unnamed Layer* 1161) [Shuffle]_output: first input has type Float but second input has type Int32.
[TensorRT] WARNING: IElementWiseLayer with inputs (Unnamed Layer* 1803) [ElementWise]_output and (Unnamed Layer* 1807) [Shuffle]_output: first input has type Float but second input has type Int32.
[TensorRT] WARNING: IElementWiseLayer with inputs (Unnamed Layer* 1810) [ElementWise]_output and (Unnamed Layer* 1814) [Shuffle]_output: first input has type Float but second input has type Int32.
[TensorRT] WARNING: IElementWiseLayer with inputs (Unnamed Layer* 1817) [ElementWise]_output and (Unnamed Layer* 1821) [Shuffle]_output: first input has type Float but second input has type Int32.
[TensorRT] WARNING: IElementWiseLayer with inputs (Unnamed Layer* 1824) [ElementWise]_output and (Unnamed Layer* 1828) [Shuffle]_output: first input has type Float but second input has type Int32.
[TensorRT] WARNING: IElementWiseLayer with inputs (Unnamed Layer* 2470) [ElementWise]_output and (Unnamed Layer* 2474) [Shuffle]_output: first input has type Float but second input has type Int32.
[TensorRT] WARNING: IElementWiseLayer with inputs (Unnamed Layer* 2477) [ElementWise]_output and (Unnamed Layer* 2481) [Shuffle]_output: first input has type Float but second input has type Int32.
[TensorRT] WARNING: IElementWiseLayer with inputs (Unnamed Layer* 2484) [ElementWise]_output and (Unnamed Layer* 2488) [Shuffle]_output: first input has type Float but second input has type Int32.
[TensorRT] WARNING: IElementWiseLayer with inputs (Unnamed Layer* 2491) [ElementWise]_output and (Unnamed Layer* 2495) [Shuffle]_output: first input has type Float but second input has type Int32.
[TensorRT] WARNING: IElementWiseLayer with inputs (Unnamed Layer* 3137) [ElementWise]_output and (Unnamed Layer* 3141) [Shuffle]_output: first input has type Float but second input has type Int32.
[TensorRT] WARNING: IElementWiseLayer with inputs (Unnamed Layer* 3144) [ElementWise]_output and (Unnamed Layer* 3148) [Shuffle]_output: first input has type Float but second input has type Int32.
[TensorRT] WARNING: IElementWiseLayer with inputs (Unnamed Layer* 3151) [ElementWise]_output and (Unnamed Layer* 3155) [Shuffle]_output: first input has type Float but second input has type Int32.
[TensorRT] WARNING: IElementWiseLayer with inputs (Unnamed Layer* 3158) [ElementWise]_output and (Unnamed Layer* 3162) [Shuffle]_output: first input has type Float but second input has type Int32.
[TensorRT] WARNING: IElementWiseLayer with inputs (Unnamed Layer* 3804) [ElementWise]_output and (Unnamed Layer* 3808) [Shuffle]_output: first input has type Float but second input has type Int32.
[TensorRT] WARNING: IElementWiseLayer with inputs (Unnamed Layer* 3811) [ElementWise]_output and (Unnamed Layer* 3815) [Shuffle]_output: first input has type Float but second input has type Int32.
[TensorRT] WARNING: IElementWiseLayer with inputs (Unnamed Layer* 3818) [ElementWise]_output and (Unnamed Layer* 3822) [Shuffle]_output: first input has type Float but second input has type Int32.
[TensorRT] WARNING: IElementWiseLayer with inputs (Unnamed Layer* 3825) [ElementWise]_output and (Unnamed Layer* 3829) [Shuffle]_output: first input has type Float but second input has type Int32.

CUDA: 11.3
TensorRT: 8.0.3
PyTorch: 1.10.0.dev20210921+cu113