
[AAAI 2022] This is the official PyTorch implementation of "Less is More: Pay Less Attention in Vision Transformers"


Less is More: Pay Less Attention in Vision Transformers


This is the official PyTorch implementation of the AAAI 2022 paper "Less is More: Pay Less Attention in Vision Transformers".

By Zizheng Pan, Bohan Zhuang, Haoyu He, Jing Liu and Jianfei Cai.

In our paper, we present a novel Less attention vIsion Transformer (LIT), building on the observation that, in recent hierarchical vision Transformers, the early self-attention layers still focus on local patterns and bring only minor benefits. LIT uses pure multi-layer perceptrons (MLPs) to encode rich local patterns in the early stages and applies self-attention modules to capture longer-range dependencies in deeper layers. We further propose a learned deformable token merging module that adaptively fuses informative patches in a non-uniform manner.
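
As a rough illustration of the design described above, the sketch below lays out a four-stage hierarchical backbone in which the early stages use MLP-only blocks and the later stages use self-attention blocks. The helper `plan_stages`, the patch size, the 2x downsampling between stages, and the choice of which stage first uses attention are all illustrative assumptions, not values read from this repository.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class StagePlan:
    index: int        # stage number, starting at 0
    block_type: str   # "mlp" (no attention) or "msa" (multi-head self-attention)
    tokens: int       # number of tokens the stage processes per image


def plan_stages(image_size=224, patch_size=4, num_stages=4, first_attention_stage=2):
    """Return a hypothetical per-stage layout for a LIT-style backbone.

    Assumptions (not taken from the repository): square inputs, an initial
    patch embedding with `patch_size`, and 2x spatial downsampling between
    consecutive stages.
    """
    plans = []
    side = image_size // patch_size
    for i in range(num_stages):
        # Early stages skip attention entirely; deeper stages use it.
        block_type = "mlp" if i < first_attention_stage else "msa"
        plans.append(StagePlan(index=i, block_type=block_type, tokens=side * side))
        side //= 2  # assumed: token merging halves each spatial dimension
    return plans


if __name__ == "__main__":
    for stage in plan_stages():
        print(f"stage-{stage.index}: {stage.block_type} blocks on {stage.tokens} tokens")
```

Under these assumptions, self-attention is only applied once the token count has dropped to a few hundred or fewer, which is where pairwise attention (quadratic in the number of tokens) is cheapest.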

If you use this code for a paper, please cite:

@inproceedings{pan2022litv1,
  title={Less is More: Pay Less Attention in Vision Transformers},
  author={Pan, Zizheng and Zhuang, Bohan and He, Haoyu and Liu, Jing and Cai, Jianfei},
  booktitle = {AAAI},
  year={2022}
}

Updates

  • 19/06/2022. We introduce LITv2, a faster and better Vision Transformer with a novel efficient HiLo attention. Code and pretrained weights have also been released here.

  • 10/03/2022. Added visualisation code for the attention maps in Figure 3. See the Attention Map Visualisation section below.

Usage

First, clone this repository.

git clone git@github.com:ziplab/LIT.git

Next, create a conda virtual environment.

# Make sure you have an NVIDIA GPU.
cd LIT/classification
bash setup_env.sh [conda_install_path] [env_name]

# For example
bash setup_env.sh /home/anaconda3 lit

Note: We use PyTorch 1.7.1 with CUDA 10.1 for all experiments. The setup_env.sh script lists all the dependencies used in our experiments. You may want to edit this file to install a different version of PyTorch or any other packages.
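
To check whether an environment matches the versions mentioned in the note, a small script along these lines can help. It is a sketch rather than part of this repository: `describe_torch_env` is a hypothetical helper name, and it only reports what it finds instead of enforcing anything.

```python
import importlib


def describe_torch_env():
    """Report the installed PyTorch and CUDA versions, if PyTorch is importable."""
    try:
        torch = importlib.import_module("torch")
    except ImportError:
        return {"torch": None, "cuda": None, "gpu_available": False}
    return {
        "torch": str(torch.__version__),        # installed PyTorch version string
        "cuda": torch.version.cuda,             # CUDA version PyTorch was built with, or None
        "gpu_available": torch.cuda.is_available(),
    }


if __name__ == "__main__":
    info = describe_torch_env()
    print(info)
    if info["torch"] is None:
        print("PyTorch is not installed in this environment.")
    elif not info["torch"].startswith("1.7.1"):
        print("Note: the experiments used PyTorch 1.7.1 with CUDA 10.1.")
```

A mismatch is not necessarily a problem, but it is worth knowing about before comparing results with the numbers reported here.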

Image Classification on ImageNet

We provide baseline LIT models pretrained on ImageNet-1K. For training and evaluation code, please refer to the classification directory.

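| Name   | Params (M) | FLOPs (G) | Top-1 Acc. (%) | Model               | Log |
|--------|------------|-----------|----------------|---------------------|-----|
| LIT-Ti | 19         | 3.6       | 81.1           | google drive/github | log |
| LIT-S  | 27         | 4.1       | 81.5           | google drive/github | log |
| LIT-M  | 48         | 8.6       | 83.0           | google drive/github | log |
| LIT-B  | 86         | 15.0      | 83.4           | google drive/github | log |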

Object Detection on COCO

For training and evaluation code, please refer to the detection directory.

RetinaNet

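| Backbone | Params (M) | Lr schd | box mAP | Config | Model  | Log |
|----------|------------|---------|---------|--------|--------|-----|
| LIT-Ti   | 30         | 1x      | 41.6    | config | github | log |
| LIT-S    | 39         | 1x      | 41.6    | config | github | log |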

Mask R-CNN

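| Backbone | Params (M) | Lr schd | box mAP | mask mAP | Config | Model  | Log |
|----------|------------|---------|---------|----------|--------|--------|-----|
| LIT-Ti   | 40         | 1x      | 42.0    | 39.1     | config | github | log |
| LIT-S    | 48         | 1x      | 42.9    | 39.6     | config | github | log |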

Semantic Segmentation on ADE20K

For training and evaluation code, please refer to the segmentation directory.

Semantic FPN

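| Backbone | Params (M) | Iters | mIoU | Config | Model  | Log |
|----------|------------|-------|------|--------|--------|-----|
| LIT-Ti   | 24         | 8k    | 41.3 | config | github | log |
| LIT-S    | 32         | 8k    | 41.7 | config | github | log |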

Offsets Visualisation


We provide a script for visualising the offsets learned by the proposed deformable token merging (DTM) modules. For example,

# activate your virtual env
conda activate lit
cd classification/code_for_lit_ti

# visualise
python visualize_offset.py --model lit_ti --resume [path/to/lit_ti.pth] --vis_image visualization/demo.JPEG

The plots are saved automatically under visualization/, in a folder named after the example image.

Attention Map Visualisation

We provide our method for visualising the attention maps in Figure 3. To save you time, we also provide a pretrained PVT model with standard MSA in all stages.

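| Name       | Params (M) | FLOPs (G) | Top-1 Acc. (%) | Model  | Log |
|------------|------------|-----------|----------------|--------|-----|
| PVT w/ MSA | 20         | 8.4       | 80.9           | github | log |

conda activate lit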
cd classification/code_for_lit_ti

# visualise
# by default, we save the results under 'classification/code_for_lit_ti/attn_results'
python generate_attention_maps.py --data-path [/path/to/imagenet] --resume [/path/to/pvt_full_msa.pth]

The resulting folder contains the following items:

.
├── attention_map
│   ├── stage-0
│   │   ├── block0
│   │   │   └── pixel-1260-block-0-head-0.png
│   │   ├── block1
│   │   │   └── pixel-1260-block-1-head-0.png
│   │   └── block2
│   │       └── pixel-1260-block-2-head-0.png
│   ├── stage-1
│   ├── stage-2
│   └── stage-3
└── full_msa_eval_maps.npy

where full_msa_eval_maps.npy contains the saved attention maps in each block and each stage. The folder attention_map contains the visualisation results.
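
To browse the saved images programmatically, the folder layout shown above can be indexed with a few lines of standard-library Python. This is a sketch, not part of the repository: `index_attention_maps` is a hypothetical helper, and it assumes that every image follows the `pixel-<n>-block-<n>-head-<n>.png` naming pattern seen in the tree. The meaning of the number after `pixel-` is not stated here, so it is kept as a plain integer.

```python
import re
from collections import defaultdict
from pathlib import Path

# File names in the tree above look like "pixel-1260-block-0-head-0.png".
_NAME = re.compile(r"pixel-(\d+)-block-(\d+)-head-(\d+)\.png$")


def index_attention_maps(root):
    """Group PNG paths by (stage, block), as laid out under attention_map/."""
    index = defaultdict(list)
    for png in sorted(Path(root).glob("stage-*/block*/*.png")):
        match = _NAME.match(png.name)
        if match is None:
            continue  # skip files that do not follow the naming pattern
        # The stage number comes from the "stage-<n>" directory name.
        stage = int(png.parent.parent.name.split("-")[1])
        pixel, block, head = (int(g) for g in match.groups())
        index[(stage, block)].append({"pixel": pixel, "head": head, "path": png})
    return dict(index)


if __name__ == "__main__":
    # Assumes the default output location, run from classification/code_for_lit_ti.
    maps = index_attention_maps("attn_results/attention_map")
    for (stage, block), items in sorted(maps.items()):
        print(f"stage {stage}, block {block}: {len(items)} image(s)")
```

Grouping by stage and block makes it easy to compare how the attention patterns change from the early stages to the deeper ones.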

License

This repository is released under the Apache 2.0 license as found in the LICENSE file.

Acknowledgement

This repository adopts code from DeiT, PVT and Swin. We thank the authors for open-sourcing their code.