Paper link: https://arxiv.org/abs/2312.10376
This repository contains the official PyTorch implementation for SA2VP.
Our code builds on the framework from https://github.com/microsoft/unilm/tree/master/beit
We use the following datasets for evaluation:
- FGVC: https://github.com/KMnP/vpt
- VTAB-1k: https://github.com/dongzelian/SSF
- HTA: https://github.com/shikiw/DAM-VP
This code is tested with Python 3.7.13, PyTorch 1.12.1, and CUDA 11.4, and requires the following dependencies:
- timm 0.6.7
We also provide requirement.txt for reference.
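A quick way to compare your environment against the tested versions is a small check like the one below. It is a sketch, not part of the repository: `TESTED`, `installed_versions`, and `compare` are hypothetical names, and the check assumes that local build suffixes such as `+cu113` on the PyTorch version string can be ignored.

```python
import importlib
import sys
from typing import Dict, Optional

# Versions reported as tested in this README.
TESTED = {"python": "3.7.13", "torch": "1.12.1", "timm": "0.6.7"}


def installed_versions() -> Dict[str, Optional[str]]:
    """Collect installed versions; None means the package could not be imported."""
    found: Dict[str, Optional[str]] = {"python": "%d.%d.%d" % sys.version_info[:3]}
    for name in ("torch", "timm"):
        try:
            module = importlib.import_module(name)
            found[name] = getattr(module, "__version__", None)
        except ImportError:
            found[name] = None
    return found


def compare(tested: Dict[str, str], found: Dict[str, Optional[str]]) -> Dict[str, str]:
    """Label each entry as 'ok', 'missing', or 'differs (have X)'."""
    report = {}
    for name, want in tested.items():
        have = found.get(name)
        if have is None:
            report[name] = "missing"
        elif have.split("+")[0] == want:  # drop local suffixes such as '+cu113'
            report[name] = "ok"
        else:
            report[name] = "differs (have %s)" % have
    return report
```

Usage: `print(compare(TESTED, installed_versions()))`. A version that differs is not necessarily a problem; it only means that combination was not the one tested.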
- ./backbone_ckpt: stores the pre-trained ViT and Swin Transformer checkpoints.
- ./data: download and set up the input datasets here, covering FGVC and VTAB-1k:
│SA2VP/
├──data/
│ ├──fgvc/
│ │ ├──CUB_200_2011/
│ │ ├──OxfordFlower/
│ │ ├──Stanford-cars/
│ │ ├──Stanford-dogs/
│ │ ├──nabirds/
│ ├──vtab-1k/
│ │ ├──caltech101/
│ │ ├──cifar/
│ │ ├──.......
├──backbone_ckpt/
│ ├──imagenet21k_ViT-B_16.npz
│ ├──swin_base_patch4_window7_224_22k.pth
- ./model_save: stores the final checkpoints.
- ./log_save: stores the logs.
- ./vpt_main: the VPT code, which we use to initialize the model.
  - 👉 ./vpt_main/src/models/vit_backbones/vit_tinypara.py: SA2VP with a ViT backbone.
  - 👉 ./vpt_main/src/models/vit_backbones/vit_tinypara_acc.py: SA2VP with a ViT backbone and accelerated attention computation.
  - 👉 ./vpt_main/src/models/vit_backbones/swin_transformer_tinypara.py: SA2VP with a Swin Transformer backbone.
  - ./vpt_main/src/models/build_swin_backbone.py: packages the Swin-based SA2VP model; it imports the model from swin_transformer_tinypara.py.
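Before training, a short script can confirm that the folders and checkpoint files from the tree above are in place. This is a hypothetical helper, not part of the repository; `EXPECTED` lists only the paths spelled out in the tree (the elided VTAB-1k folders are not guessed), and paths are assumed to be relative to the SA2VP/ root.

```python
from pathlib import Path
from typing import List

# Paths copied from the directory tree above, relative to the SA2VP/ root.
# The elided VTAB-1k entries ("......") are intentionally left out.
EXPECTED = [
    "data/fgvc/CUB_200_2011",
    "data/fgvc/OxfordFlower",
    "data/fgvc/Stanford-cars",
    "data/fgvc/Stanford-dogs",
    "data/fgvc/nabirds",
    "data/vtab-1k/caltech101",
    "data/vtab-1k/cifar",
    "backbone_ckpt/imagenet21k_ViT-B_16.npz",
    "backbone_ckpt/swin_base_patch4_window7_224_22k.pth",
]


def missing_paths(root: str = ".") -> List[str]:
    """Return the expected paths that do not exist under root."""
    base = Path(root)
    return [rel for rel in EXPECTED if not (base / rel).exists()]
```

For example, `missing_paths("SA2VP")` should return an empty list once the datasets and checkpoints are set up.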
- datasets.py: defines all datasets.
- engine_for_train.py: engine for training and testing.
- 👉 vit_train_sa2vp.py: run this to train SA2VP with a ViT backbone. At line 37, you can switch to the accelerated version by appending '_acc' to the model name.
- 👉 vit_train_swin.py: run this to train SA2VP with a Swin Transformer backbone.
- 👉 Train_nature.sh/Train_special.sh/Train_struct.sh: scripts for automated training.
1. Download the pre-trained ViT and Swin checkpoints from VPT. Use ViT-B/16 Supervised and Swin-B Supervised.
2. Set the pre-trained checkpoint name and path in vit_train_sa2vp.py (line 48) and vit_train_swin.py (line 47).
3. Set the training weights of the different branches in engine_for_train.py (lines 26 and 177).
4. Set the dataset paths in datasets.py (lines 1160/1161: prefix_fgvc/prefix_vtab). Note that you also need to choose the FGVC or VTAB transform in lines 1157/1158, and pay attention to the dataset names used in the code that follows.
5. Change the model config. For SA2VP with a ViT backbone, set inter-dim in vit_tinypara.py (lines 280/281/334/428) and inter-weight in line 427. For SA2VP with a Swin backbone, set inter-dim in vit_train_swin.py (lines 169/170/675) and inter-weight in line 596. The default lr is 1e-3 and the default weight_decay is 1e-4.
- For ViT: on VTAB, the following tasks need special handling (format: task-inter-dim-inter-weight): SVHN-16-0.5; Resisc45-16-0.5; ds/ori-16-0.1; sn/ele-32-0.5. VTAB tasks with a special lr: Pets 5e-4; Clevr/Count 5e-4.
Hyperparameter | CUB | Nabirds | Flower | DOG | CAR |
---|---|---|---|---|---|
inter-dim | 16 | 32 | 8 | 32 | 64 |
inter-weight | 0.1 | 0.1 | 0.1 | 0.1 | 1.5 |
batch size | 64/128 | 64/128 | 64/128 | 64/128 | 64/128 |
Hyperparameter | vtab-Natural | vtab-Special | vtab-Structure | HTA |
---|---|---|---|---|
inter-dim | 8 | 16 | 32 | 64 |
inter-weight | 0.1 | 1.5 | 1.5 | 0.1 |
batch size | 40/64 | 40 | 40 | 64/128 |
- For Swin:
Hyperparameter | vtab-Natural | vtab-Special | vtab-Structure |
---|---|---|---|
inter-dim | 8 | 8 | 8 |
inter-weight | 0.1/0.5 | 0.5/1.5 | 1.5 |
batch size | 40/64 | 40 | 40 |
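The tables above can also be kept as a small lookup structure, for example when scripting many runs. The sketch below is hypothetical (SA2VP itself reads these values from the source lines listed in step 5): the keys are the column labels used in this README rather than the repository's --data_set names, "a/b" entries are stored as candidate tuples, and reading the ViT task notes as task-inter-dim-inter-weight is our interpretation.

```python
from dataclasses import dataclass, replace
from typing import Dict, Optional, Tuple


@dataclass(frozen=True)
class HParams:
    inter_dim: int
    inter_weights: Tuple[float, ...]  # candidates, e.g. "0.1/0.5" -> (0.1, 0.5)
    batch_sizes: Tuple[int, ...]      # candidates, e.g. "64/128" -> (64, 128)
    lr: float = 1e-3                  # README default
    weight_decay: float = 1e-4        # README default


# ViT backbone, copied from the FGVC and VTAB/HTA tables.
VIT: Dict[str, HParams] = {
    "CUB": HParams(16, (0.1,), (64, 128)),
    "Nabirds": HParams(32, (0.1,), (64, 128)),
    "Flower": HParams(8, (0.1,), (64, 128)),
    "DOG": HParams(32, (0.1,), (64, 128)),
    "CAR": HParams(64, (1.5,), (64, 128)),
    "vtab-Natural": HParams(8, (0.1,), (40, 64)),
    "vtab-Special": HParams(16, (1.5,), (40,)),
    "vtab-Structure": HParams(32, (1.5,), (40,)),
    "HTA": HParams(64, (0.1,), (64, 128)),
}

# Per-task ViT exceptions from the VTAB note, read as (inter-dim, inter-weight).
VIT_TASK_OVERRIDES: Dict[str, Tuple[int, float]] = {
    "SVHN": (16, 0.5),
    "Resisc45": (16, 0.5),
    "ds/ori": (16, 0.1),
    "sn/ele": (32, 0.5),
}
VIT_LR_OVERRIDES: Dict[str, float] = {"Pets": 5e-4, "Clevr/Count": 5e-4}

# Swin backbone (VTAB-1k table).
SWIN: Dict[str, HParams] = {
    "vtab-Natural": HParams(8, (0.1, 0.5), (40, 64)),
    "vtab-Special": HParams(8, (0.5, 1.5), (40,)),
    "vtab-Structure": HParams(8, (1.5,), (40,)),
}


def vit_hparams(group: str, task: Optional[str] = None) -> HParams:
    """Look up a ViT setting by table column, then apply per-task overrides."""
    hp = VIT[group]
    if task in VIT_TASK_OVERRIDES:
        dim, weight = VIT_TASK_OVERRIDES[task]
        hp = replace(hp, inter_dim=dim, inter_weights=(weight,))
    if task in VIT_LR_OVERRIDES:
        hp = replace(hp, lr=VIT_LR_OVERRIDES[task])
    return hp
```

For example, `vit_hparams("vtab-Natural", "SVHN")` gives inter-dim 16 and inter-weight 0.5 instead of the Natural-group defaults, while keeping the default lr of 1e-3.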
Training scripts:
- Single GPU
CUDA_VISIBLE_DEVICES=1 python vit_train_sa2vp.py --data_set CUB --output_dir ./model_save/CUB --update_freq 1 --warmup_epochs 10 --epochs 100 --drop_path 0.0 --lr 1e-3 --weight_decay 1e-4 --nb_classes 200 --log_dir ./log_save --batch_size 64 --my_mode train_val --min_lr 1e-7
- Multiple GPUs
CUDA_VISIBLE_DEVICES=0,1 python -m torch.distributed.launch --nproc_per_node=2 vit_train_sa2vp.py --data_set CIFAR --output_dir ./model_save/CIFAR --update_freq 1 --warmup_epochs 10 --epochs 100 --drop_path 0.0 --lr 1e-3 --weight_decay 1e-4 --nb_classes 100 --log_dir ./log_save --batch_size 40 --my_mode train_val --min_lr 1e-7
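When launching many runs (as the Train_*.sh scripts do), it can help to generate these commands programmatically. The helper below is hypothetical and not part of the repository: it only reproduces the flags and values shown in the two commands above, and lr/weight_decay are passed as strings so they appear exactly as written here.

```python
import shlex
from typing import List, Sequence


def train_command(data_set: str, nb_classes: int, batch_size: int,
                  gpus: Sequence[int] = (0,), lr: str = "1e-3",
                  weight_decay: str = "1e-4", epochs: int = 100,
                  my_mode: str = "train_val") -> str:
    """Build a shell command matching the single-/multi-GPU examples above."""
    env = "CUDA_VISIBLE_DEVICES=" + ",".join(str(g) for g in gpus)
    if len(gpus) > 1:
        # Multi-GPU runs go through torch.distributed.launch, one process per GPU.
        launcher: List[str] = ["python", "-m", "torch.distributed.launch",
                               "--nproc_per_node=%d" % len(gpus)]
    else:
        launcher = ["python"]
    args = [
        "vit_train_sa2vp.py",
        "--data_set", data_set,
        "--output_dir", "./model_save/" + data_set,
        "--update_freq", "1",
        "--warmup_epochs", "10",
        "--epochs", str(epochs),
        "--drop_path", "0.0",
        "--lr", lr,
        "--weight_decay", weight_decay,
        "--nb_classes", str(nb_classes),
        "--log_dir", "./log_save",
        "--batch_size", str(batch_size),
        "--my_mode", my_mode,
        "--min_lr", "1e-7",
    ]
    # shlex.quote leaves plain tokens untouched and quotes anything unsafe.
    return " ".join([env] + [shlex.quote(a) for a in launcher + args])
```

For example, `train_command("CUB", 200, 64, gpus=(1,))` reproduces the single-GPU command, and `train_command("CIFAR", 100, 40, gpus=(0, 1))` reproduces the multi-GPU one.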
Test scripts:
- For VTAB-1k
CUDA_VISIBLE_DEVICES=1 python vit_train_sa2vp.py --data_set DS_LOC --eval --batch_size 64 --resume ./model_save/DS_LOC/checkpoint-99.pth --nb_classes 16 --my_mode trainval_test
- For FGVC
CUDA_VISIBLE_DEVICES=1 python vit_train_sa2vp.py --data_set CAR --eval --batch_size 64 --resume ./model_save/CAR/checkpoint-best.pth --nb_classes 196 --my_mode trainval_test
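The evaluation commands follow the same pattern and differ mainly in the checkpoint passed to --resume. The helper below is hypothetical and mirrors the two examples; reading checkpoint-99.pth as the last epoch of a 100-epoch run (epochs - 1) is our assumption, not something this README states.

```python
def eval_command(data_set: str, nb_classes: int, use_best: bool,
                 epochs: int = 100, gpu: int = 1, batch_size: int = 64) -> str:
    """Build an --eval command like the VTAB-1k and FGVC examples above."""
    # Assumption: epoch checkpoints are numbered from 0, so a 100-epoch run
    # ends with checkpoint-99.pth.
    ckpt = "checkpoint-best.pth" if use_best else "checkpoint-%d.pth" % (epochs - 1)
    return (
        "CUDA_VISIBLE_DEVICES=%d python vit_train_sa2vp.py --data_set %s --eval "
        "--batch_size %d --resume ./model_save/%s/%s --nb_classes %d "
        "--my_mode trainval_test" % (gpu, data_set, batch_size, data_set, ckpt, nb_classes)
    )
```

For example, `eval_command("DS_LOC", 16, use_best=False)` yields the VTAB-1k command and `eval_command("CAR", 196, use_best=True)` the FGVC one.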
Note: --my_mode selects the train/val/test splits. With train_val, the model is trained on the train set and the best model is selected on the val set during training. With trainval_test, the model is trained on the train and val sets together and accuracy is reported on the test set. We follow the VPT protocol.
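The two --my_mode values differ only in which splits are used for fitting and which for evaluation. The mapping below is a hypothetical summary of the note above, not code taken from the repository.

```python
from typing import Dict, Tuple

# mode -> (splits used for training, split used for evaluation)
MY_MODE_SPLITS: Dict[str, Tuple[Tuple[str, ...], str]] = {
    "train_val": (("train",), "val"),             # pick the best model on val
    "trainval_test": (("train", "val"), "test"),  # report accuracy on test
}


def describe_mode(mode: str) -> str:
    """Return a one-line description of the splits used by a --my_mode value."""
    if mode not in MY_MODE_SPLITS:
        raise ValueError("unknown --my_mode: %r" % mode)
    train, evaluate = MY_MODE_SPLITS[mode]
    return "train on %s, evaluate on %s" % ("+".join(train), evaluate)
```

A typical workflow is to tune with train_val first, then rerun with trainval_test using the chosen settings for the final number.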
If you find our work helpful in your research, please cite it as:
@inproceedings{pei2024sa2vp,
title={SA^2VP: Spatially Aligned-and-Adapted Visual Prompt},
author={Pei, Wenjie and Xia, Tongqi and Chen, Fanglin and Li, Jinsong and Tian, Jiandong and Lu, Guangming},
booktitle = {Proceedings of the AAAI Conference on Artificial Intelligence},
year={2024}
}
The code is released under MIT License (see LICENSE file for details).