/ThiNet

Primary LanguagePython

ThiNet: A Filter Level Pruning Method for Deep Neural Network Compression

本项目是对ThiNet的一个复现,和论文中稍微有点不同 不同之处在于: 1.论文中提出的方法是每剪枝一层就微调一次,然而这在pytorch中很难去实现,因此我的做法是将所有层剪枝完之后再微调。 2.论文中没有对全连接层进行剪枝,我的做法是在全连接层使用基于阈值的方法将低于阈值的神经元删除。

项目依赖

numpy==1.22.3

torch==1.11.0+cu113

torchvision==0.12.0+cu113

这在requirements.txt文件中可以看到

训练

运行train_vgg.py文件,参数设置:

batch_size = 64

test_batch_size = 256

lr = 0.01

momentum = 0.9

weight_decay = 1e-4

epochs = 200

训练结果:

模型 数据集 FLOPS 模型大小 测试集精度
vgg_19 cifar10 1.24G 148MB 92.1%

在命令行中输入tensorboard --logdir vgg_tb,可以在tensorboard中查看训练过程的可视化

剪枝-微调

运行vgg_prune.py文件,微调epoch=20,剪枝之前训练20个epoch所需时间为1026s,剪枝前的精度为92.1%.

conv(0.3)表示剪掉卷积层的30%,fc(0.5)表示剪掉全连接层的50%

结果:

微调后精度 剪枝后模型大小 剪枝前FLOPS 剪枝后FLOPS 剪枝时长 微调时长 m
conv(0.3) 90.2% 111MB 1.24G 0.63G 29s 723s 16
conv(0.5) 88% 93.6MB 1.24G 0.35G 59s 574s 16
conv(0.3)+fc(0.5) 90.27% 59.4MB 1.24G 0.60G 29s 650s 16
conv(0.3)+fc(0.75) 90.1% 45.3MB 1.24G 0.59G 29s 649s 16
conv(0.5)+fc(0.85) 88.7% 24.0MB 1.24G 0.31G 59s 570s 16
conv(0.5)+fc(0.85) 88.79% 24.0MB 1.24G 0.31G 647s 552s 256

在上表中可以看到,对卷积层剪枝的比例越大,剪枝所花时间越长,微调后的精度也越低,但是训练时间大幅降低。对全连接层剪枝的比例越大,模型压缩的越小,精度越低。因此需要在卷积层和全连接层的剪枝比例上做一个协调。

剪枝-知识蒸馏

运行vgg_prune_distillation.py文件,剪枝之后使用知识蒸馏进行微调,原来的模型作为老师模型,剪枝之后的模型作为学生模型。

微调后精度 剪枝后模型大小 剪枝前FLOPS 剪枝后FLOPS 剪枝时长 微调时长 m
conv(0.3) 90.33% 111MB 1.24G 0.63G 29s 934s 16
conv(0.5) 89.02% 93.6MB 1.24G 0.35G 63s 794s 16
conv(0.3)+fc(0.5) 90.43% 59.4MB 1.24G 0.60G 38s 1040s 16
conv(0.3)+fc(0.75) 90.16% 45.3MB 1.24G 0.59G 30s 873s 16
conv(0.5)+fc(0.85) 89.2% 24.0MB 1.24G 0.31G 60s 773s 16
conv(0.5)+fc(0.85) 89.37% 24.0MB 1.24G 0.31G 655s 770s 256

总结

传统微调精度 知识蒸馏微调精度 传统微调时长 知识蒸馏微调时长 m
conv(0.3) 90.2% 90.33% 723s 934s 16
conv(0.5) 88% 89.02% 574s 794s 16
conv(0.3)+fc(0.5) 90.27% 90.43% 650s 1040s 16
conv(0.3)+fc(0.75) 90.1% 90.16% 649s 873s 16
conv(0.5)+fc(0.85) 88.7% 89.2% 570s 773s 16
conv(0.5)+fc(0.85) 88.79% 89.37% 552s 770s 256

通过上表可以看出来,在精度恢复方面,知识蒸馏的微调方式要比普通的微调方式效果更好,但是使用知识蒸馏的微调方式所需的微调时间要更长,这是因为知识蒸馏过程不仅要计算学生模型的前向传播,还要计算教师模型的前向传播。当把m设为256时,剪枝所花的时间却高了大约十倍左右(相比m=16时),然而微调之后的精度却仅仅提高了0.1%~%0.2左右,这与花费的时间相比,这不是一个令人满意的结果。