/Models

采用MegEngine实现的各种主流深度学习模型

Primary LanguagePythonOtherNOASSERTION

MegEngine Models

本仓库包含了采用MegEngine实现的各种主流深度学习模型。

official目录下提供了各种经典的图像分类、目标检测、图像分割以及自然语言模型的官方实现。每个模型同时提供了模型定义、推理以及训练的代码。

官方会一直维护official下的代码,保持适配MegEngine的最新API,提供最优的模型实现。同时,提供高质量的学习文档,帮助新手学习如何在MegEngine下训练自己的模型。

综述

对于每个模型,我们提供了至少四个脚本文件:模型定义(model.py)、模型推理(inference.py)、模型训练(train.py)、模型测试(test.py)。

每个模型目录下都对应有一个README,介绍了模型的详细信息,并详细描述了训练和测试的流程。例如 ResNet README

另外,official下定义的模型可以通过megengine.hub来直接加载,例如:

import megengine.hub

# 只加载网络结构
resnet18 = megengine.hub.load("megengine/models", "resnet18")
# 加载网络结构和预训练权重
resnet18 = megengine.hub.load("megengine/models", "resnet18", pretrained=True)

更多可以通过megengine.hub接口加载的模型见hubconf.py

安装和环境配置

在开始运行本仓库下的代码之前,用户需要通过以下步骤来配置本地环境:

  1. 克隆仓库
git clone https://github.com/MegEngine/Models.git
  1. 安装依赖包
pip3 install --user -r requirements.txt
  1. 添加目录到python环境变量中
export PYTHONPATH=/path/to/models:$PYTHONPATH

官方模型介绍

图像分类

图像分类是计算机视觉的基础任务。许多计算机视觉的其它任务(例如物体检测)都使用了基于图像分类的预训练模型。因此,我们提供了各种在ImageNet上预训练好的分类模型,包括ResNet系列, shufflenet系列等,这些模型在ImageNet验证集上的测试结果如下表:

模型 top1 acc top5 acc
ResNet18 70.312 89.430
ResNet34 73.960 91.630
ResNet50 76.254 93.056
ResNet101 77.944 93.844
ResNet152 78.582 94.130
ResNeXt50 32x4d 77.592 93.644
ShuffleNetV2 x0.5 60.696 82.190
ShuffleNetV2 x1.0 69.372 88.764
ShuffleNetV2 x1.5 72.806 90.792
ShuffleNetV2 x2.0 75.074 92.278

目标检测

目标检测同样是计算机视觉中的常见任务,我们提供了一个经典的目标检测模型retinanet,这个模型在COCO验证集上的测试结果如下:

模型 mAP
@5-95
retinanet-res50-1x-800size 36.0

图像分割

我们也提供了经典的语义分割模型--Deeplabv3plus,这个模型在PASCAL VOC验证集上的测试结果如下:

模型 Backbone mIoU_single mIoU_multi
Deeplabv3plus Resnet101 79.0 79.8

自然语言处理

我们同样支持一些常见的自然语言处理模型,模型的权重来自Google的pre-trained models, 用户可以直接使用megengine.hub轻松的调用预训练的bert模型。

另外,我们在bert中还提供了更加方便的脚本, 可以通过任务名直接获取到对应字典, 配置, 与预训练模型。

模型 字典 配置
wwm_cased_L-24_H-1024_A-16 link link
wwm_uncased_L-24_H-1024_A-16 link link
cased_L-12_H-768_A-12 link link
cased_L-24_H-1024_A-16 link link
uncased_L-12_H-768_A-12 link link
uncased_L-24_H-1024_A-16 link link
chinese_L-12_H-768_A-12 link link
multi_cased_L-12_H-768_A-12 link link

在glue_data/MRPC数据集中使用默认的超参数进行微调和评估,评估结果介于84%和88%之间。

Dataset pretrained_bert acc
glue_data/MRPC uncased_L-12_H-768_A-12 86.25%