/TensorRT_Python_API

用TensorRT的Python API去搭建trt网络

Primary LanguagePython

TensorRT_Python_API

用TensorRT的Python API去搭建trt网络

执行脚本python3 sample.py, 将会对网络进行训练, 训练完成之后, 将会在当前目录下将模型的所有参数保存成名为torchPara.npz的文件, 同时还会将PyTorch模型导出成名为test.onnx的onnx模型.

再执行python3 construct_your_model.py就可以采用TensorRT的原生API搭建模型了, 并将最后的模型序列化成mynet.engine文件. 该TRT模型是支持动态BatchSize的.

通过trtexec命令将生成的onnx文件也转成trt文件.

trtexec --onnx=test.onnx \
        --workspace=1600 \
        --explicitBatch \
        --optShapes=data:4x1x28x28 \
        --maxShapes=data:8x1x28x28 \
        --minShapes=data:1x1x28x28 \
        --shapes=data:4x1x28x28 \
        --saveEngine=test.engine