/Kolors-TensorRT-libtorch

Kolors with TensorRT and libtorch

Primary LanguageC++

Kolors-TensorRT-libtorch

TensorRTlibtorch简单实现了Kolors模型的pipeline推理。

准备

  • 安装TensorRT, TensorRT10的api相较于TensorRT8以下版本变化较大, 目前本仓库做了TensorRT10的适配, 建议用TensorRT10以上的版本。
  • 从huggingface下载模型。
  • 安装pytorch, onnx等依赖。

导出3个onnx模型用于pipeline

修改export_onnx.py中的路径相关信息。 执行:

python export_onnx.py

你会得到text_encoder, unet, vae三个onnx模型。 你可以用onnxsim将它们简化。 pr-336适配了超过2GB的onnx简化报错,可以尝试安装最新的onnxsim。

执行:

onnxsim text_encoder.onnx text_encoder-sim.onnx --save-as-external-data
onnxsim unet.onnx unet-sim.onnx --save-as-external-data
onnxsim vae.onnx vae-sim.onnx

onnx很大的情况下, 简化的耗时也很长。

onnx转换到tensorrt

这里我用了trtexec转化, 比较省事。 目前测试text_encoder部分fp16掉点情况比较大,建议回退到fp32。

trtexec --onnx=text_encoder-sim.onnx --saveEngine=text_encoder.plan --noTF32
trtexec --onnx=unet-sim.onnx --saveEngine=unet.plan --fp16
trtexec --onnx=vae-sim.onnx --saveEngine=vae.plan --fp16

tensorrt转换的过程也很慢。

编译安装python包

执行:

python setup.py install

包名是: py_kolors

推理一个文生图

修改run.py中的3个模型路径, 修改推理步数, 默认50比较慢.

执行:

python run.py

生成的图片会保存为tmp.jpg