Code for "CP-ViT: Cascade Vision Transformer Pruning via Progressive Sparsity Prediction" on CIFAR-10/100.
CP-ViT is a cascade pruning framework that predicts sparsity in ViT models progressively and dynamically, reducing computational redundancy while minimizing accuracy loss. Specifically, we define a cumulative score to reserve the informative patches and heads across the ViT model for better accuracy, and we propose a dynamic pruning-ratio adjustment technique based on the layer-aware attention range. CP-ViT is broadly applicable for practical deployment: it can be applied to a wide range of ViT models and achieves superior accuracy with or without finetuning.
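To illustrate the idea of reserving informative patches by a cumulative score, here is a minimal NumPy sketch. The scoring rule (mean class-token attention per head, summed over layers) and the function names are illustrative assumptions, not the repo's actual implementation:

```python
import numpy as np

def cumulative_patch_scores(attn_layers):
    """Accumulate the attention each patch receives across layers.

    attn_layers: list of (num_heads, num_tokens, num_tokens) attention maps.
    Hypothetical scoring rule: attention from the class token (index 0)
    to every token, averaged over heads and summed over layers.
    """
    scores = np.zeros(attn_layers[0].shape[-1])
    for attn in attn_layers:
        scores += attn[:, 0, :].mean(axis=0)
    return scores

def keep_top_patches(scores, keep_ratio):
    """Indices of the highest-scoring patches to reserve (rest are pruned)."""
    k = max(1, int(round(keep_ratio * len(scores))))
    return np.sort(np.argsort(scores)[-k:])
```

In the actual cascade scheme, this selection would be applied progressively layer by layer, with the keep ratio adjusted per layer.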
We have tested our codes under the following environments:
python == 3.9.5
pytorch == 1.9.0
torchvision == 0.10.0
CUDA == 11.2
To start with, you can first download pre-trained models from:
and place them under the folder ./CP-ViT/output/.
You can also download other pre-trained models from Google Cloud.
We can then prune the ViT model without finetuning by:
python3 eval.py \
--name="CP-ViT test" \
--dataset="cifar10" \
--model_type="ViT-B_16" \
--pretrained_dir='output/cifar10_checkpoint.pth' \
--eval_batch_size=64
We can finetune the CP-ViT model by:
python3 train.py \
--name="CP-ViT finetune" \
--dataset="cifar10" \
--model_type="ViT-B_16" \
--pretrained_dir='output/cifar10_checkpoint.pth' \
--train_batch_size=64 \
--eval_every=3125 \
--learning_rate=3e-2 \
--num_steps=10000 \
--decay_type="cosine"
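The `--decay_type="cosine"` flag selects a cosine learning-rate decay over `--num_steps`. A minimal sketch of such a schedule is below; the linear warmup and its `warmup_steps` value are a common recipe assumed for illustration, not taken from this repo:

```python
import math

def cosine_lr(step, base_lr=3e-2, total_steps=10000, warmup_steps=500):
    """Cosine learning-rate decay with linear warmup.

    Ramps linearly from 0 to base_lr over warmup_steps, then follows a
    half-cosine from base_lr down to 0 at total_steps.
    """
    if step < warmup_steps:
        return base_lr * step / warmup_steps
    progress = (step - warmup_steps) / (total_steps - warmup_steps)
    return 0.5 * base_lr * (1.0 + math.cos(math.pi * progress))
```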
PyTorch Image Models: https://github.com/rwightman/pytorch-image-models