Created by Yongming Rao, Wenliang Zhao, Benlin Liu, Jiwen Lu, Jie Zhou, Cho-Jui Hsieh
This repository contains PyTorch implementation for DynamicViT.
We introduce a dynamic token sparsification framework to prune redundant tokens in vision transformers progressively and dynamically based on the input.
Our code is based on pytorch-image-models, DeiT, and LV-ViT.
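As a rough illustration of the idea (a minimal sketch with hypothetical names such as `TokenPruner`, not the code in this repository, which additionally uses attention masking and Gumbel-Softmax to make the selection differentiable during training), each pruning stage scores the remaining tokens with a lightweight prediction head and keeps only the top-scoring fraction:

```python
import torch
import torch.nn as nn

class TokenPruner(nn.Module):
    """Hypothetical sketch of one pruning stage: score tokens, keep the top fraction."""

    def __init__(self, dim: int, keep_ratio: float = 0.7):
        super().__init__()
        self.keep_ratio = keep_ratio
        self.score_head = nn.Linear(dim, 1)  # lightweight head predicting a keep score per token

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, num_tokens, dim); the class token x[:, :1] is always kept
        cls_token, patches = x[:, :1], x[:, 1:]
        scores = self.score_head(patches).squeeze(-1)            # (batch, num_patches)
        num_keep = max(1, int(patches.shape[1] * self.keep_ratio))
        keep_idx = scores.topk(num_keep, dim=1).indices          # hard top-k selection (inference-style)
        keep_idx = keep_idx.unsqueeze(-1).expand(-1, -1, patches.shape[-1])
        pruned = torch.gather(patches, dim=1, index=keep_idx)    # (batch, num_keep, dim)
        return torch.cat([cls_token, pruned], dim=1)

x = torch.randn(2, 197, 384)          # DeiT-small: 196 patch tokens + 1 class token
print(TokenPruner(dim=384)(x).shape)  # torch.Size([2, 138, 384]) with keep_ratio=0.7
```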
We provide our DynamicViT models pretrained on ImageNet:

name | arch | rho | acc@1 | acc@5 | FLOPs | url |
---|---|---|---|---|---|---|
DynamicViT-256/0.7 | deit_256 | 0.7 | 76.532 | 93.118 | 1.3G | Google Drive / Tsinghua Cloud |
DynamicViT-384/0.7 | deit_small | 0.7 | 79.316 | 94.676 | 2.9G | Google Drive / Tsinghua Cloud |
DynamicViT-LV-S/0.5 | lvvit_s | 0.5 | 81.970 | 95.756 | 3.7G | Google Drive / Tsinghua Cloud |
DynamicViT-LV-S/0.7 | lvvit_s | 0.7 | 83.076 | 96.252 | 4.6G | Google Drive / Tsinghua Cloud |
DynamicViT-LV-M/0.7 | lvvit_m | 0.7 | 83.816 | 96.584 | 8.5G | Google Drive / Tsinghua Cloud |
Requirements:

- torch>=1.7.0
- torchvision>=0.8.1
- timm==0.4.5
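These can be installed with pip (choose torch/torchvision builds matching your CUDA setup), e.g. `pip install "torch>=1.7.0" "torchvision>=0.8.1" timm==0.4.5`.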
Data preparation: download and extract ImageNet images from http://image-net.org/. The directory structure should be:
│ILSVRC2012/
├──train/
│ ├── n01440764
│ │ ├── n01440764_10026.JPEG
│ │ ├── n01440764_10027.JPEG
│ │ ├── ......
│ ├── ......
├──val/
│ ├── n01440764
│ │ ├── ILSVRC2012_val_00000293.JPEG
│ │ ├── ILSVRC2012_val_00002138.JPEG
│ │ ├── ......
│ ├── ......
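This is the standard `torchvision.datasets.ImageFolder` layout. As a quick sanity check that the data is extracted correctly (the training and evaluation scripts build their own data pipelines), you can load the validation split directly:

```python
from torchvision import datasets, transforms

# Quick check that /path/to/ILSVRC2012/val follows the expected class-folder layout.
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
])
val_set = datasets.ImageFolder('/path/to/ILSVRC2012/val', transform=transform)
print(len(val_set), 'images in', len(val_set.classes), 'classes')  # expect 50000 images, 1000 classes
```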
Model preparation: download pre-trained DeiT and LV-ViT models for training DynamicViT:
sh download_pretrain.sh
We provide a Jupyter notebook where you can run the visualization of DynamicViT.
To run the demo, you need to install matplotlib.
To evaluate a pre-trained DynamicViT model on the ImageNet validation set with a single GPU, run the following (arch_name should match the arch column of the table above, e.g. deit_small or lvvit_s):
python infer.py --data-path /path/to/ILSVRC2012/ --arch arch_name --model-path /path/to/model --base_rate 0.7
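The acc@1 / acc@5 numbers reported above are standard top-1 / top-5 accuracy. For reference, a minimal sketch of how they are computed from logits (the evaluation script reports these for you; `topk_accuracy` below is just an illustrative helper):

```python
import torch

def topk_accuracy(logits: torch.Tensor, targets: torch.Tensor, ks=(1, 5)):
    # A prediction counts as correct@k if the target is among the k highest logits.
    top = logits.topk(max(ks), dim=1).indices    # (batch, max_k)
    correct = top.eq(targets.unsqueeze(1))       # (batch, max_k) boolean matches
    return [correct[:, :k].any(dim=1).float().mean().item() for k in ks]

logits = torch.randn(8, 1000)                  # dummy predictions over 1000 ImageNet classes
targets = torch.randint(0, 1000, (8,))
print(topk_accuracy(logits, targets))          # [top-1 accuracy, top-5 accuracy]
```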
To train DynamicViT models on ImageNet, run:
DeiT-small
python -m torch.distributed.launch --nproc_per_node=8 --use_env main_dynamic_vit.py --output_dir logs/dynamic-vit_deit-small --arch deit_small --input-size 224 --batch-size 96 --data-path /path/to/ILSVRC2012/ --epochs 30 --dist-eval --distill --base_rate 0.7
LV-ViT-S
python -m torch.distributed.launch --nproc_per_node=8 --use_env main_dynamic_vit.py --output_dir logs/dynamic-vit_lvvit-s --arch lvvit_s --input-size 224 --batch-size 64 --data-path /path/to/ILSVRC2012/ --epochs 30 --dist-eval --distill --base_rate 0.7
LV-ViT-M
python -m torch.distributed.launch --nproc_per_node=8 --use_env main_dynamic_vit.py --output_dir logs/dynamic-vit_lvvit-m --arch lvvit_m --input-size 224 --batch-size 48 --data-path /path/to/ILSVRC2012/ --epochs 30 --dist-eval --distill --base_rate 0.7
You can train models with different keeping ratios by adjusting `base_rate`. DynamicViT can also achieve comparable performance with only 15 epochs of training (around 0.1% lower accuracy compared to 30 epochs).
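In the paper, tokens are pruned hierarchically at three stages, each time keeping a `base_rate` fraction of the remaining patch tokens, so the surviving fraction after stage i is roughly `base_rate ** i`. A back-of-envelope sketch of what that means for DeiT-small (the exact pruning locations are set in the training code):

```python
# Approximate surviving patch-token counts for DeiT-small (196 patch tokens)
# under hierarchical pruning with three stages.
base_rate = 0.7
num_patches = 196
for stage in range(1, 4):
    frac = base_rate ** stage
    print(f'after stage {stage}: ~{int(num_patches * frac)} patch tokens ({frac:.2f} of the input)')
```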
This project is released under the MIT License.
If you find our work useful in your research, please consider citing:
@article{rao2021dynamicvit,
title={DynamicViT: Efficient Vision Transformers with Dynamic Token Sparsification},
author={Rao, Yongming and Zhao, Wenliang and Liu, Benlin and Lu, Jiwen and Zhou, Jie and Hsieh, Cho-Jui},
journal={arXiv preprint arXiv:2106.02034},
year={2021}
}
base_rate | GFLOPs | Acc |
---|---|---|
0.9 | 3.9 | 79.5 |
0.8 | 3.4 | 78.8 |
0.7 | 2.9 | 78.1 |