/AliProducts

一个通用的图像分类模板,天池/CVPR AliProducts挑战赛 3/688

Primary LanguagePythonMIT LicenseMIT

CVPR 2020 AliProducts Challenge

一个通用的图像分类模板,天池/CVPR AliProducts Challenge 3/688🍟

队伍:薯片分类器!

preview

Features

  • Backbone

    • ResNet(101)
    • ResNeXt(101)
    • ResNeSt(101, 200)
    • Res2Net(101)
    • iResNet(101, 152, 200)
    • EffiCientNet(B-5, B-7)
  • 优化器

    • Adam
    • SGD
    • Ranger(RAdam+Look Ahead)
  • Scheduler

    • Cos
    • 自定义scheduler
  • Input Pipeline

    • 裁剪和切割
    • 随机翻折和旋转
    • 随机放大
    • 随机色相
    • 随机饱和度
    • 随机亮度
    • Norm_input
  • 其他tricks

    • label smooth
    • model ensemble
    • TTA

Prerequisites

python >= 3.6
torch >= 1.0
tensorboardX >= 1.6
utils-misc >= 0.0.5
torch-template >= 0.0.4
mscv >= 0.0.3

都是很好装的库,不需要编译。

Code Usage

Code Usage:
Training:
    python train.py --tag your_tag --model ResNeSt101 --epochs 20 -b 24 --lr 0.0001 --gpu 0

Finding Best Hyper Params:  # 需先设置好sweep.yml
    python grid_search.py --run

Resume Training (or fine-tune):
    python train.py --tag your_tag --model ResNeSt101 --epochs 20 -b 24 --load checkpoints/your_tag/9_ResNeSt101.pt --resume --gpu 0

Eval:
    python eval.py --model ResNeSt101 -b 96 --load checkpoints/your_tag/9_ResNeSt101.pt --gpu 1

Generate Submission:
    python submit.py --model ResNeSt101 --load checkpoints/your_tag/9_ResNeSt101.pt -b 96 --gpu 0

Check Running Log:
    cat logs/your_tag/log.txt

Clear(delete all files with the tag, BE CAREFUL to use):
    python clear.py --tag your_tag

See ALL Running Commands:
    cat run_log.txt

参数用法:

--tag参数是一次操作(traineval)的标签,日志会保存在logs/标签目录下,保存的模型会保存在checkpoints/标签目录下。

--model是使用的模型,所有可用的模型定义在network/__init__.py中。

--epochs是训练的代数。

-b参数是batch_size,可以根据显存的大小调整。

--lr是初始学习率。

--load是加载预训练模型。

--resume配合--load使用,会恢复上次训练的epoch和优化器。

--gpu指定gpu id,目前只支持单卡训练。

--debug以debug模式运行,debug模式下每个epoch只会训练前几个batch。

另外还可以通过参数调整优化器、学习率衰减、验证和保存模型的频率等,详细请查看python train.py --help

如何添加自定义的模型:

如何添加新的模型:

① 复制network目录下的Default文件夹,改成另外一个名字(比如MyNet)。

② 在network/__init__.py中import你的Model并且在models = {}中添加它。
    from MyNet.Model import Model as MyNet
    models = {
        'default': Default,
        'MyNet': MyNet,
    }

③ 尝试 python train.py --model MyNet --debug 看能否成功运行