版本:v1.0.20200502
数据集在文件夹pokemon内,共有5大分类,1168张图片。包含了妙蛙种子234张、小火龙238张、杰尼龟223张、皮卡丘234张、超梦239张。
模型使用的是ResNet18,训练了10个epochs,使用visdom可视化准确率和损失。最终对比了迁移学习和自定义的残差网络效果。项目主要文件如下:
- pokemon.py 用于加载和标记数据集
- resnet.py 自己写的残差网络
- train_scratch.py 使用自己写的残差网络训练
- train_transfer.py 使用已训练过的残差网络训练(迁移学习)
- utils.py 一些辅助方法(打平数据和显示图像)
- best_scratch.mdl / best_transfer.mdl 最佳准确率的参数,有则该无则增。可以在train之前删掉。
预处理主要有图片的resize操作,数据增强,标准化和转换成tensor,加载标记和切分数据集。
train_scratch是自己搭建的ResNet18,train_transfer是使用已经训练好的ResNet18进行迁移学习。后者效果更佳!
训练文件 | 最佳准确率 best acc | 测试集准确率 test acc |
---|---|---|
train_scratch | 89.6% | 92% |
train_transfer | 94.8% | 95.2% |
虽然整体的数据集较小(1168张图片),但是迁移学习比自定义的ResNet效果更佳,准确率可以达到95%+。
其他待研究:增加训练epochs,修改学习率,数据增强等操作能否再提高准确率?