本项目是对ThiNet的一个复现,和论文中稍微有点不同 不同之处在于: 1.论文中提出的方法是每剪枝一层就微调一次,然而这在pytorch中很难去实现,因此我的做法是将所有层剪枝完之后再微调。 2.论文中没有对全连接层进行剪枝,我的做法是在全连接层使用基于阈值的方法将低于阈值的神经元删除。
numpy==1.22.3
torch==1.11.0+cu113
torchvision==0.12.0+cu113
这在requirements.txt文件中可以看到
运行train_vgg.py文件,参数设置:
batch_size = 64
test_batch_size = 256
lr = 0.01
momentum = 0.9
weight_decay = 1e-4
epochs = 200
训练结果:
模型 | 数据集 | FLOPS | 模型大小 | 测试集精度 |
---|---|---|---|---|
vgg_19 | cifar10 | 1.24G | 148MB | 92.1% |
在命令行中输入tensorboard --logdir vgg_tb,可以在tensorboard中查看训练过程的可视化
运行vgg_prune.py文件,微调epoch=20,剪枝之前训练20个epoch所需时间为1026s,剪枝前的精度为92.1%.
conv(0.3)表示剪掉卷积层的30%,fc(0.5)表示剪掉全连接层的50%
结果:
微调后精度 | 剪枝后模型大小 | 剪枝前FLOPS | 剪枝后FLOPS | 剪枝时长 | 微调时长 | m | |
---|---|---|---|---|---|---|---|
conv(0.3) | 90.2% | 111MB | 1.24G | 0.63G | 29s | 723s | 16 |
conv(0.5) | 88% | 93.6MB | 1.24G | 0.35G | 59s | 574s | 16 |
conv(0.3)+fc(0.5) | 90.27% | 59.4MB | 1.24G | 0.60G | 29s | 650s | 16 |
conv(0.3)+fc(0.75) | 90.1% | 45.3MB | 1.24G | 0.59G | 29s | 649s | 16 |
conv(0.5)+fc(0.85) | 88.7% | 24.0MB | 1.24G | 0.31G | 59s | 570s | 16 |
conv(0.5)+fc(0.85) | 88.79% | 24.0MB | 1.24G | 0.31G | 647s | 552s | 256 |
在上表中可以看到,对卷积层剪枝的比例越大,剪枝所花时间越长,微调后的精度也越低,但是训练时间大幅降低。对全连接层剪枝的比例越大,模型压缩的越小,精度越低。因此需要在卷积层和全连接层的剪枝比例上做一个协调。
运行vgg_prune_distillation.py文件,剪枝之后使用知识蒸馏进行微调,原来的模型作为老师模型,剪枝之后的模型作为学生模型。
微调后精度 | 剪枝后模型大小 | 剪枝前FLOPS | 剪枝后FLOPS | 剪枝时长 | 微调时长 | m | |
---|---|---|---|---|---|---|---|
conv(0.3) | 90.33% | 111MB | 1.24G | 0.63G | 29s | 934s | 16 |
conv(0.5) | 89.02% | 93.6MB | 1.24G | 0.35G | 63s | 794s | 16 |
conv(0.3)+fc(0.5) | 90.43% | 59.4MB | 1.24G | 0.60G | 38s | 1040s | 16 |
conv(0.3)+fc(0.75) | 90.16% | 45.3MB | 1.24G | 0.59G | 30s | 873s | 16 |
conv(0.5)+fc(0.85) | 89.2% | 24.0MB | 1.24G | 0.31G | 60s | 773s | 16 |
conv(0.5)+fc(0.85) | 89.37% | 24.0MB | 1.24G | 0.31G | 655s | 770s | 256 |
传统微调精度 | 知识蒸馏微调精度 | 传统微调时长 | 知识蒸馏微调时长 | m | |
---|---|---|---|---|---|
conv(0.3) | 90.2% | 90.33% | 723s | 934s | 16 |
conv(0.5) | 88% | 89.02% | 574s | 794s | 16 |
conv(0.3)+fc(0.5) | 90.27% | 90.43% | 650s | 1040s | 16 |
conv(0.3)+fc(0.75) | 90.1% | 90.16% | 649s | 873s | 16 |
conv(0.5)+fc(0.85) | 88.7% | 89.2% | 570s | 773s | 16 |
conv(0.5)+fc(0.85) | 88.79% | 89.37% | 552s | 770s | 256 |
通过上表可以看出来,在精度恢复方面,知识蒸馏的微调方式要比普通的微调方式效果更好,但是使用知识蒸馏的微调方式所需的微调时间要更长,这是因为知识蒸馏过程不仅要计算学生模型的前向传播,还要计算教师模型的前向传播。当把m设为256时,剪枝所花的时间却高了大约十倍左右(相比m=16时),然而微调之后的精度却仅仅提高了0.1%~%0.2左右,这与花费的时间相比,这不是一个令人满意的结果。