/trashNet

垃圾分类:从训练到预测到实现 Web APP

Primary LanguagePython

手机拍照垃圾分类 APP

文件及目录结构

├── templates      
│   └── home.html     # 前端模板
├── onnx_rt_infer.py  # 模型推理
├── server.py         # 服务后端
├── trash40.onnx      # 预训练模型 ONNX 文件
├── requirements.txt  # Python 依赖包
└── run.sh            # OneFlow 云上一键启动脚本

原理架构

server.py 中将通过 Flask 启动过一个 Web 服务。这个 FLASK 服务器前接用户来自浏览器的请求,后接用于推理图片结果的 ONNX Runtime。架构如下所示:

┌───────┐           ┌───────┐        ┌───────┐
│       │    AJAX   │       │        │       │
│       ├───────────►       ├────────►       │
│ USER  │           │ FLASK │        │ONNX RT│
│       ◄───────────┤       ◄────────┤       │
│       │    JSON   │       │        │       │
└───────┘           └───────┘        └───────┘

前端

前端实现代码在 templates/home.html 中,它提供了一个 uploader 用于用户拍照/上传图片:

<van-uploader id="image" name="image" :after-read="afterRead" :max-count="1" />

当用户拍照/上传图片后,通过 axios 发送 POST 请求给后端,并且将后端返回结果(图片地址和预测结果)更新到前端页面上:

              axios.post("", formData).then((res) => {
                console.log('Upload success');
                vue_obj.image_url = res.data.image_url;
                vue_obj.prediction = res.data.prediction;
                vant.Toast(res.data.prediction);
              });

因为用户拍摄的照片往往较大,所以上传照片前,会预先对照片进行压缩:

          new Compressor(file.file, {
              //...
          });

后端服务

server.py 中,注册了两个路由:

@app.route("/v1/index", methods=["GET", "POST"])
#...

@app.route("/images/<filename>", methods=["GET"])
#...

/v1/index 是核心路由,当接受 GET 请求时,返回主页面(home.html)。当接受 POST 请求时,则做两件事情:

  1. 保存图片(方便之后被前端引用显示)
  2. 对图片进行预测并返回预测结果

保存图片相关代码:

                filename = generate_filenames(image_file.filename)
                filePath = os.path.join(app.config["UPLOAD_FOLDER"], filename)
                image_file.save(filePath)

对图片进行预测相关代码:

def predict(filename):
    image_url = url_for("images", filename=filename)
    image_file_path = os.path.join(app.config["UPLOAD_FOLDER"], filename)
    prediction = make_prediction(image_file_path)
    json_obj = {"image_url": image_url, "prediction": prediction}
    return json.dumps(json_obj)

后端推理

后端推理使用 ONNX Runtime,通过读取已经转好的 trash40.onnx 文件初始化推理 session。

self.ort_sess = ort.InferenceSession("trash40.onnx")

通过调用 sessoinrun 方法进行推理分类结果:

output = self.ort_sess.run(None, {"input": image})

注意,为了减少对 torch vision 的依赖,使用 Numpy 对数据增强部分做了重写:

    def image_process(self, image):
        # resize with keeping ratio
        # ...

        # crop
        # ...

        # normalize
        # ...

根据推理得到的概率,找到概率最大的分类,并反查得到分类的名称:

        output = self.ort_sess.run(None, {"input": image})
        output = output[0]
        predict_idx = self.find_max(output[0].tolist())
        real_index = self.INDEX_MAP[predict_idx]
        return self.NAME_MAP[str(real_index)]

如何部署

部署

一、 点击 部署 按钮

image.png

二、如图,勾选 trashNet 文件夹后,点击下一步 image.png

三、填写必要的信息后,点击“下一步”

image.png

四、选择对应的镜像,在启动命令中填入 sh /workspace/trashNet/run.sh,端口选择 5000(因为我们的 server.py 中服务监听的是5000端口。点击“确定”

image.png

运行

可以在部署成功后直接点击“运行”,也可以在在线推理页面,点击“运行”

image.png

选择对应的环境后,点击“确定”即可。

image.png

测试使用

项目启动后,点击 “测试”即可获取对应的 URL

image.png

该项目只适配了移动端,未做 PC 端的适配。

image.png