├── templates
│ └── home.html # 前端模板
├── onnx_rt_infer.py # 模型推理
├── server.py # 服务后端
├── trash40.onnx # 预训练模型 ONNX 文件
├── requirements.txt # Python 依赖包
└── run.sh # OneFlow 云上一键启动脚本
server.py
中将通过 Flask
启动过一个 Web 服务。这个 FLASK 服务器前接用户来自浏览器的请求,后接用于推理图片结果的 ONNX Runtime。架构如下所示:
┌───────┐ ┌───────┐ ┌───────┐
│ │ AJAX │ │ │ │
│ ├───────────► ├────────► │
│ USER │ │ FLASK │ │ONNX RT│
│ ◄───────────┤ ◄────────┤ │
│ │ JSON │ │ │ │
└───────┘ └───────┘ └───────┘
前端实现代码在 templates/home.html
中,它提供了一个 uploader
用于用户拍照/上传图片:
<van-uploader id="image" name="image" :after-read="afterRead" :max-count="1" />
当用户拍照/上传图片后,通过 axios
发送 POST 请求给后端,并且将后端返回结果(图片地址和预测结果)更新到前端页面上:
axios.post("", formData).then((res) => {
console.log('Upload success');
vue_obj.image_url = res.data.image_url;
vue_obj.prediction = res.data.prediction;
vant.Toast(res.data.prediction);
});
因为用户拍摄的照片往往较大,所以上传照片前,会预先对照片进行压缩:
new Compressor(file.file, {
//...
});
在 server.py
中,注册了两个路由:
@app.route("/v1/index", methods=["GET", "POST"])
#...
@app.route("/images/<filename>", methods=["GET"])
#...
/v1/index
是核心路由,当接受 GET
请求时,返回主页面(home.html
)。当接受 POST
请求时,则做两件事情:
- 保存图片(方便之后被前端引用显示)
- 对图片进行预测并返回预测结果
保存图片相关代码:
filename = generate_filenames(image_file.filename)
filePath = os.path.join(app.config["UPLOAD_FOLDER"], filename)
image_file.save(filePath)
对图片进行预测相关代码:
def predict(filename):
image_url = url_for("images", filename=filename)
image_file_path = os.path.join(app.config["UPLOAD_FOLDER"], filename)
prediction = make_prediction(image_file_path)
json_obj = {"image_url": image_url, "prediction": prediction}
return json.dumps(json_obj)
后端推理使用 ONNX Runtime,通过读取已经转好的 trash40.onnx
文件初始化推理 session。
self.ort_sess = ort.InferenceSession("trash40.onnx")
通过调用 sessoin
的 run
方法进行推理分类结果:
output = self.ort_sess.run(None, {"input": image})
注意,为了减少对 torch vision 的依赖,使用 Numpy
对数据增强部分做了重写:
def image_process(self, image):
# resize with keeping ratio
# ...
# crop
# ...
# normalize
# ...
根据推理得到的概率,找到概率最大的分类,并反查得到分类的名称:
output = self.ort_sess.run(None, {"input": image})
output = output[0]
predict_idx = self.find_max(output[0].tolist())
real_index = self.INDEX_MAP[predict_idx]
return self.NAME_MAP[str(real_index)]
一、 点击 部署 按钮
三、填写必要的信息后,点击“下一步”
四、选择对应的镜像,在启动命令中填入 sh /workspace/trashNet/run.sh
,端口选择 5000
(因为我们的 server.py
中服务监听的是5000端口。点击“确定”
可以在部署成功后直接点击“运行”,也可以在在线推理页面,点击“运行”
选择对应的环境后,点击“确定”即可。
项目启动后,点击 “测试”即可获取对应的 URL
该项目只适配了移动端,未做 PC 端的适配。