COMCAT: Towards Efficient Compression and Customization of Attention-Based Vision Models [PMLR] [arXiv]

@InProceedings{pmlr-v202-xiao23e,
  title = 	 {{COMCAT}: Towards Efficient Compression and Customization of Attention-Based Vision Models},
  author =       {Xiao, Jinqi and Yin, Miao and Gong, Yu and Zang, Xiao and Ren, Jian and Yuan, Bo},
  booktitle = 	 {Proceedings of the 40th International Conference on Machine Learning},
  pages = 	 {38125--38136},
  year = 	 {2023},
  editor = 	 {Krause, Andreas and Brunskill, Emma and Cho, Kyunghyun and Engelhardt, Barbara and Sabato, Sivan and Scarlett, Jonathan},
  volume = 	 {202},
  series = 	 {Proceedings of Machine Learning Research},
  month = 	 {23--29 Jul},
  publisher =    {PMLR},
  pdf = 	 {https://proceedings.mlr.press/v202/xiao23e/xiao23e.pdf},
  url = 	 {https://proceedings.mlr.press/v202/xiao23e.html}
}

What is ComCAT?

ComCAT explores an efficient method for compressing vision transformers, enriching the toolset for obtaining compact attention-based vision models. Based on a new insight into the multi-head attention (MHA) layer, we develop a highly efficient low-rank ViT compression solution that outperforms state-of-the-art pruning methods.

When compressing DeiT-small and DeiT-base models on ImageNet, ComCAT achieves 0.45% and 0.76% higher top-1 accuracy, respectively, even with fewer parameters. ComCAT can also be applied to improve the customization efficiency of text-to-image diffusion models, with much faster training (up to 2.6× speedup) and lower extra storage cost (up to 1927.5× reduction) than existing work.

Vision Transformer Compression

Compressing vision transformers using low-rank MHA layers and automatic rank selection.
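
To give a rough feel for why a low-rank layer is smaller, here is a minimal sketch of the parameter arithmetic for one weight matrix factored into two thin matrices. It is an illustration only, not the repository's rank search: the function names are hypothetical, biases are ignored, and a single uniform rank for one square matrix is a simplifying assumption (the automatic rank selection enabled by --search-rank chooses ranks itself).

```python
def dense_params(d_in: int, d_out: int) -> int:
    """Parameter count of a dense d_in x d_out weight matrix (bias ignored)."""
    return d_in * d_out


def low_rank_params(d_in: int, d_out: int, rank: int) -> int:
    """Parameter count after factoring W as U (d_in x rank) @ V (rank x d_out)."""
    return rank * (d_in + d_out)


def max_rank_for_reduction(d_in: int, d_out: int, reduction: float) -> int:
    """Largest rank whose factored form removes at least `reduction` of the parameters.

    Hypothetical helper: it treats one matrix in isolation, which is a
    simplification of how a whole model would be budgeted.
    """
    budget = (1.0 - reduction) * dense_params(d_in, d_out)
    return int(budget // (d_in + d_out))


if __name__ == "__main__":
    d = 384  # embedding width of DeiT-small
    r = max_rank_for_reduction(d, d, reduction=0.5)
    print("rank:", r)
    print("low-rank params:", low_rank_params(d, d, r))
    print("dense params:", dense_params(d, d))
```

For a 384 x 384 matrix and a 0.5 reduction target, this yields rank 96, that is 73,728 parameters instead of 147,456. The factored form only saves parameters when the rank is below d_in * d_out / (d_in + d_out).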

Setup

conda install -c pytorch pytorch torchvision
pip install timm==0.5.4

Training

# DeiT-small

wget https://dl.fbaipublicfiles.com/deit/deit_small_patch16_224-cd65a155.pth

python main.py --model deit_small_patch16_224 --data-path /path/to/imagenet/ --batch-size 512 --load deit_small_patch16_224-cd65a155.pth  --output_dir small_auto  --epochs 30 --warmup-epochs 0 --search-rank --distillation-type hard --teacher-model deit_small_patch16_224 --teacher-path  https://dl.fbaipublicfiles.com/deit/deit_small_patch16_224-cd65a155.pth --with-align --distillation-without-token --batch-size-search 64 --target-params-reduction 0.5 > small_auto.log

python -m torch.distributed.launch --nproc_per_node=4 --use_env  main.py --model deit_small_patch16_224 --data-path /path/to/imagenet/ --batch-size 256 --finetune-rank-dir small_auto  --output_dir small_auto_finetune --warmup-epochs 0 --distillation-type hard --teacher-model deit_small_patch16_224 --teacher-path  https://dl.fbaipublicfiles.com/deit/deit_small_patch16_224-cd65a155.pth --with-align --distillation-without-token --attn2-with-bias  --lr 1e-4 --min-lr 5e-6 --weight-decay 5e-3 > small_auto_finetune.log


# DeiT-base

wget https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth

python main.py --model deit_base_patch16_224 --data-path /path/to/imagenet/ --batch-size 320 --load deit_base_patch16_224-b5f2ef4d.pth --output_dir base_auto  --epochs 30 --warmup-epochs 0 --search-rank --distillation-type hard --teacher-model deit_base_patch16_224 --teacher-path  https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth --with-align --distillation-without-token --batch-size-search 16 --target-params-reduction 0.6 > base_auto.log

python -m torch.distributed.launch --nproc_per_node=4 --use_env  main.py --model deit_base_patch16_224 --data-path /path/to/imagenet/ --batch-size 256  --output_dir base_auto_finetune  --lr 1e-4 --min-lr 1e-6  --finetune-rank-dir base_auto --unscale-lr --distillation-type hard --teacher-model deit_base_patch16_224 --teacher-path https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth --distillation-without-token --warmup-epochs 0 --attn2-with-bias  > base_auto_finetune.log

Inference

mkdir small_79.58_0.44
wget -O small_79.58_0.44/checkpoint.pth https://github.com/jinqixiao/ComCAT/releases/download/comcatv1.0/small_checkpoint.pth
wget -O small_79.58_0.44/ranks.txt https://github.com/jinqixiao/ComCAT/releases/download/comcatv1.0/small_ranks.txt
python -m torch.distributed.launch --nproc_per_node=4 --use_env  main.py --model deit_small_patch16_224 --data-path /path/to/imagenet/ --batch-size 256 --finetune-rank-dir small_79.58_0.44 --attn2-with-bias --eval

mkdir base_82.28_0.61
wget -O base_82.28_0.61/checkpoint.pth https://github.com/jinqixiao/ComCAT/releases/download/comcatv1.0/checkpoint.pth
wget -O base_82.28_0.61/ranks.txt https://github.com/jinqixiao/ComCAT/releases/download/comcatv1.0/ranks.txt
python -m torch.distributed.launch --nproc_per_node=4 --use_env  main.py --model deit_base_patch16_224 --data-path /path/to/imagenet/ --batch-size 256 --finetune-rank-dir base_82.28_0.61 --attn2-with-bias --eval
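
Before launching evaluation, it can help to confirm that each directory passed to --finetune-rank-dir actually contains its files. The sketch below is a hypothetical helper, not part of the repository. It assumes, based only on the commands above, that a rank directory is expected to hold checkpoint.pth and ranks.txt, and it treats an empty file as missing, since that usually means the download did not land where intended.

```python
from pathlib import Path
from typing import List

# Assumption inferred from the download commands: each rank directory
# holds exactly these two files.
EXPECTED_FILES = ("checkpoint.pth", "ranks.txt")


def missing_files(rank_dir: str) -> List[str]:
    """Return the expected file names that are absent or empty in rank_dir."""
    root = Path(rank_dir)
    return [
        name
        for name in EXPECTED_FILES
        if not (root / name).is_file() or (root / name).stat().st_size == 0
    ]


if __name__ == "__main__":
    for rank_dir in ("small_79.58_0.44", "base_82.28_0.61"):
        problems = missing_files(rank_dir)
        print(rank_dir, "ok" if not problems else "missing or empty: %s" % problems)
```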

Fast Fine-tuning for Text-to-image Diffusion Model

Customizing a text-to-image diffusion model using a low-rank MHA mechanism.

Setup

pip install diffusers==0.10.2
pip install transformers==4.25.1

Generate images

Refer to finetune_goldendoodle.ipynb

Fine-tuning

sh run.sh goldendoodle