/ResNet

Resnet train & deploy demo

Primary LanguageC++

1.Train Model

# 直接运行train.py程序即可完成Resnet训练任务,会把训练结果保存在data/pth中
cd train
python3 train.py

训练代码是参考https://www.bilibili.com/video/BV1Xw411f7FW 中讲解的ResNet框架。

这里没用ImageNet训练,而是使用了最简单的MNIST手写数字体数据集做RestNet18的演示demo,训练4个epoch基本就能达到98%的正确率。

2.Deploy Model

部署代码采用的思路是:

pytorch --> pth file --> onnx file --> trt file --> TensorRT Engine(Python/C++)

# onnx 转 tensorRT file指令
trtexec --onnx=path-to-onnx-model/xxx.onnx --saveEngine=path-to-save_trt_model/xxx.trt

这里面最难的是:配置环境 和 TensorRT使用

配置环境请参考:

3.进展更新

  • 2022.04.06 ubuntu18.04+RTX3080 配置好以上NVIDIA环境,并完成训练代码和Python部署推理代码,C++学了下TensorRT样例