
tensorrt-unet

update

For the latest TensorRT 8.x version of UNet, please check https://github.com/wang-xinyu/tensorrtx/tree/master/unet, which has more updates and better dependency handling.

Screenshot: original image (left) and segmentation result (right)

This is a TensorRT version of UNet, inspired by tensorrtx and pytorch-unet.
You can generate a TensorRT engine file with this project and customize parameters and the network structure to match the network you trained (FP32/FP16 precision, input size, different convolutions, activation functions, ...).
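For instance, the precision and batch-size knobs are set on the TensorRT builder and builder config before the engine is built. A minimal sketch using the TensorRT 7 C++ API (the function name and workspace size below are illustrative, not the repo's actual code):

    #include <NvInfer.h>

    // Illustrative only: where the FP32/FP16 and batch-size choices are made.
    void configureBuilder(nvinfer1::IBuilder* builder,
                          nvinfer1::IBuilderConfig* config,
                          bool useFp16) {
        builder->setMaxBatchSize(8);              // batch size used at inference time
        config->setMaxWorkspaceSize(1ULL << 30);  // scratch space for tactic selection
        if (useFp16) {
            config->setFlag(nvinfer1::BuilderFlag::kFP16);  // build an FP16 engine instead of FP32
        }
    }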

requirements

TensorRT 7.0 (you need to install TensorRT first)
CUDA 10.2
Python 3.7
OpenCV 4.4
CMake 3.18

train a .pth file and convert it to .wts

create env

pip install -r requirements.txt

train .pth file

Train on your own dataset by following pytorch-unet and generate a .pth file.

convert to .wts

Run gen_wts from the utils folder to convert the .pth file to a .wts file, then move the generated .wts file to the project folder (you need to run it inside the training environment).
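The .wts file follows the plain-text tensorrtx weight format: the first line is the number of weight blobs, then one line per blob with its name, element count, and values as hex-encoded floats. Below is a sketch of how the C++ side can read such a file, assuming that format (not necessarily the exact loader used in this repo):

    #include <NvInfer.h>
    #include <cassert>
    #include <fstream>
    #include <map>
    #include <string>

    // Parse a tensorrtx-style .wts file into named TensorRT Weights structs.
    std::map<std::string, nvinfer1::Weights> loadWeights(const std::string& file) {
        std::map<std::string, nvinfer1::Weights> weightMap;
        std::ifstream input(file);
        assert(input.is_open() && "unable to open .wts file");

        int32_t count;
        input >> count;  // number of weight blobs in the file
        while (count--) {
            nvinfer1::Weights wt{nvinfer1::DataType::kFLOAT, nullptr, 0};
            std::string name;
            uint32_t size;
            input >> name >> std::dec >> size;

            // Each value is stored as the 32-bit hex pattern of a float.
            uint32_t* values = new uint32_t[size];
            for (uint32_t i = 0; i < size; ++i) {
                input >> std::hex >> values[i];
            }
            wt.values = values;
            wt.count = size;
            weightMap[name] = wt;
        }
        return weightMap;
    }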

generate engine file and infer

create a build folder in the project folder

mkdir build

run cmake and make to generate the executable

cd build
cmake ..
make

generate the TensorRT engine file and run inference on images

unet -s
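Roughly speaking, the -s step builds the network from the .wts weights and serializes the engine to disk. A hedged sketch of that flow (the network definition is assumed to have been populated beforehand; names are illustrative, not the repo's actual functions):

    #include <NvInfer.h>
    #include <fstream>

    // Build the engine from a populated network definition and write it to disk.
    void serializeEngine(nvinfer1::IBuilder* builder,
                         nvinfer1::IBuilderConfig* config,
                         nvinfer1::INetworkDefinition* network,
                         const std::string& path) {
        nvinfer1::ICudaEngine* engine = builder->buildEngineWithConfig(*network, *config);
        nvinfer1::IHostMemory* serialized = engine->serialize();
        std::ofstream out(path, std::ios::binary);
        out.write(reinterpret_cast<const char*>(serialized->data()), serialized->size());
        serialized->destroy();
        engine->destroy();
    }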

A serialized engine file will then be generated. You can use unet -d to run inference on all images in a folder:

unet -d ../samples
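For reference, the -d path amounts to deserializing the saved engine and running batched inference. A simplified sketch under the TensorRT 7 implicit-batch API (binding order, buffer sizes, and preprocessing are assumptions, not the repo's exact code):

    #include <NvInfer.h>
    #include <cuda_runtime_api.h>
    #include <fstream>
    #include <iostream>
    #include <iterator>
    #include <vector>

    // Minimal logger required by the TensorRT runtime.
    class Logger : public nvinfer1::ILogger {
        void log(Severity severity, const char* msg) override {
            if (severity <= Severity::kWARNING) std::cout << msg << std::endl;
        }
    } gLogger;

    // Load a serialized engine and run one batch of inference.
    void runInference(const std::string& enginePath, const float* input, float* output,
                      int batchSize, size_t inputSize, size_t outputSize) {
        std::ifstream file(enginePath, std::ios::binary);
        std::vector<char> blob((std::istreambuf_iterator<char>(file)),
                               std::istreambuf_iterator<char>());

        nvinfer1::IRuntime* runtime = nvinfer1::createInferRuntime(gLogger);
        nvinfer1::ICudaEngine* engine =
            runtime->deserializeCudaEngine(blob.data(), blob.size(), nullptr);
        nvinfer1::IExecutionContext* context = engine->createExecutionContext();

        // Binding 0 = input, binding 1 = output (assumed order).
        void* buffers[2];
        cudaMalloc(&buffers[0], batchSize * inputSize * sizeof(float));
        cudaMalloc(&buffers[1], batchSize * outputSize * sizeof(float));

        cudaStream_t stream;
        cudaStreamCreate(&stream);
        cudaMemcpyAsync(buffers[0], input, batchSize * inputSize * sizeof(float),
                        cudaMemcpyHostToDevice, stream);
        context->enqueue(batchSize, buffers, stream, nullptr);
        cudaMemcpyAsync(output, buffers[1], batchSize * outputSize * sizeof(float),
                        cudaMemcpyDeviceToHost, stream);
        cudaStreamSynchronize(stream);

        cudaStreamDestroy(stream);
        cudaFree(buffers[0]);
        cudaFree(buffers[1]);
        context->destroy();
        engine->destroy();
        runtime->destroy();
    }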

efficiency

The TensorRT engine is much faster than the original PyTorch model (tested on a 2080 Ti):

             pytorch    TensorRT FP32          TensorRT FP16
input size   816x672    816x672                816x672
latency      58ms       43ms (batch size 8)    14ms (batch size 8)

Further development

  1. add INT8 calibrator
  2. add custom plugin
  etc.