
[NeurIPS 2022 Spotlight] This is the official PyTorch implementation of "Fast Vision Transformers with HiLo Attention"


Fast Vision Transformers with HiLo Attention 👋 (NeurIPS 2022 Spotlight)


This is the official PyTorch implementation of Fast Vision Transformers with HiLo Attention.

By Zizheng Pan, Jianfei Cai, and Bohan Zhuang.

News

  • 17/11/2023. Updated detection scripts for mmdet v3.2.0.

  • 20/04/2023. Updated training scripts for PyTorch 2.0. Added support for ONNX and TensorRT model conversion, see here.

  • 15/12/2022. Released ImageNet-pretrained weights for different values of alpha.

  • 11/11/2022. LITv2 will be presented as a Spotlight!

  • 13/10/2022. Updated code for newer versions of timm. Compatible with PyTorch 1.12.1 + CUDA 11.3 + timm 0.6.11.

  • 30/09/2022. Added benchmarking results for a single attention layer. HiLo is super fast on both CPU and GPU!

  • 15/09/2022. LITv2 is accepted by NeurIPS 2022! 🔥🔥🔥

  • 16/06/2022. We released the source code for classification/detection/segmentation, along with the pretrained weights. Issues are welcome!

A Gentle Introduction


We introduce LITv2, a simple and effective ViT that performs favourably against existing state-of-the-art methods across a range of model sizes while running faster.


The core of LITv2: HiLo attention

HiLo is inspired by the insight that high frequencies in an image capture local fine details while low frequencies focus on global structures, whereas a standard multi-head self-attention layer ignores these different frequency characteristics. We therefore propose to disentangle the high- and low-frequency patterns in an attention layer by splitting the heads into two groups: one group encodes high frequencies via self-attention within each local window, and the other models global relationships between each query position in the input feature map and the average-pooled low-frequency keys from each window.
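To make the head split concrete, here is a minimal NumPy sketch of one HiLo forward pass. It is not the repository's `hilo.HiLo` module: the projection weights are random stand-ins for learned ones, the output projections, padding and dropout are left out, the helper name `hilo_sketch` is hypothetical, and we assume the number of Lo-Fi heads is `int(num_heads * alpha)` (which matches the alpha table further down in this README).

```python
import numpy as np


def softmax(x, axis=-1):
    # Numerically stable softmax along `axis`.
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def hilo_sketch(x, H, W, num_heads=12, window_size=2, alpha=0.5, seed=0):
    """Illustrative HiLo forward pass (hypothetical helper). x: (B, N, C), N == H * W.

    Random weights replace learned ones; output projections, padding and
    dropout are omitted.
    """
    B, N, C = x.shape
    ws = window_size
    assert N == H * W and C % num_heads == 0
    assert H % ws == 0 and W % ws == 0, "sketch assumes no padding is needed"
    rng = np.random.default_rng(seed)
    head_dim = C // num_heads
    l_heads = int(num_heads * alpha)      # Lo-Fi heads (global, pooled keys)
    h_heads = num_heads - l_heads         # Hi-Fi heads (local windows)
    l_dim, h_dim = l_heads * head_dim, h_heads * head_dim
    scale = head_dim ** -0.5
    fmap = x.reshape(B, H, W, C)
    outs = []

    if h_heads > 0:
        # Hi-Fi: self-attention inside each non-overlapping ws x ws window.
        hg, wg = H // ws, W // ws
        w_qkv = rng.standard_normal((C, 3 * h_dim)) / np.sqrt(C)
        win = fmap.reshape(B, hg, ws, wg, ws, C).transpose(0, 1, 3, 2, 4, 5)
        win = win.reshape(B, hg * wg, ws * ws, C)
        qkv = (win @ w_qkv).reshape(B, hg * wg, ws * ws, 3, h_heads, head_dim)
        q, k, v = qkv.transpose(3, 0, 1, 4, 2, 5)      # each (B, G, h, T, d)
        attn = softmax(q @ k.swapaxes(-1, -2) * scale)
        hi = (attn @ v).transpose(0, 1, 3, 2, 4)        # (B, G, T, h, d)
        hi = hi.reshape(B, hg, wg, ws, ws, h_dim).transpose(0, 1, 3, 2, 4, 5)
        outs.append(hi.reshape(B, N, h_dim))            # back to row-major tokens

    if l_heads > 0:
        # Lo-Fi: every query attends to one average-pooled key/value per window.
        w_q = rng.standard_normal((C, l_dim)) / np.sqrt(C)
        w_kv = rng.standard_normal((C, 2 * l_dim)) / np.sqrt(C)
        q = (x @ w_q).reshape(B, N, l_heads, head_dim).transpose(0, 2, 1, 3)
        pooled = fmap.reshape(B, H // ws, ws, W // ws, ws, C).mean(axis=(2, 4))
        pooled = pooled.reshape(B, -1, C)               # (B, N / ws^2, C)
        kv = (pooled @ w_kv).reshape(B, -1, 2, l_heads, head_dim)
        k, v = kv.transpose(2, 0, 3, 1, 4)              # each (B, l, M, d)
        attn = softmax(q @ k.swapaxes(-1, -2) * scale)  # (B, l, N, M)
        lo = (attn @ v).transpose(0, 2, 1, 3)           # (B, N, l, d)
        outs.append(lo.reshape(B, N, l_dim))

    # Concatenate Hi-Fi and Lo-Fi heads back to the full width C.
    return np.concatenate(outs, axis=-1)
```

With `alpha=0.0` the sketch reduces to pure window attention, and with `alpha=1.0` to pure attention over pooled keys; intermediate values trade local detail for global context within a single layer.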

A Simple Demo

To quickly understand HiLo attention, you only need to install PyTorch and try the following code in the root directory of this repo.

from hilo import HiLo
import torch

model = HiLo(dim=384, num_heads=12, window_size=2, alpha=0.5)

x = torch.randn(64, 196, 384) # batch_size x num_tokens x hidden_dimension
out = model(x, 14, 14)
print(out.shape)
print(model.flops(14, 14)) # the number of FLOPs

Output:

torch.Size([64, 196, 384])
83467776
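The printed FLOPs figure can be reproduced by hand. The function below is our own reconstruction, not code from the repo, and `hilo_flops` is a hypothetical name: it assumes `flops()` counts multiply-accumulates of the linear layers and the two attention matmuls in each branch, ignores softmax and pooling, and rounds H and W up to a multiple of the window size. Under those assumptions it gives exactly 83467776 for the demo settings.

```python
import math


def hilo_flops(dim, num_heads, window_size, alpha, H, W):
    """Reconstructed FLOP count for one HiLo layer (hypothetical helper)."""
    head_dim = dim // num_heads
    l_heads = int(num_heads * alpha)
    l_dim = l_heads * head_dim                     # Lo-Fi width
    h_dim = (num_heads - l_heads) * head_dim       # Hi-Fi width
    ws = window_size
    Hp, Wp = ws * math.ceil(H / ws), ws * math.ceil(W / ws)
    Np = Hp * Wp                                   # padded token count
    n_win = (Hp // ws) * (Wp // ws)                # windows == pooled keys
    # Hi-Fi: qkv projection, window attention (q@k and attn@v), output proj
    hifi = Np * dim * h_dim * 3
    hifi += n_win * (ws * ws) ** 2 * h_dim * 2
    hifi += Np * h_dim * h_dim
    # Lo-Fi: q projection, k/v on pooled tokens, global attention, output proj
    lofi = Np * dim * l_dim
    lofi += n_win * dim * l_dim * 2
    lofi += Np * l_dim * n_win * 2
    lofi += Np * l_dim * l_dim
    return hifi + lofi


print(hilo_flops(384, 12, 2, 0.5, 14, 14))  # 83467776
```

In this count the Hi-Fi window term grows as Np·ws² while the Lo-Fi attention term shrinks as Np²/ws², so the window size trades one branch's cost against the other.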

Installation

Requirements

  • Linux with Python ≥ 3.6
  • PyTorch ≥ 1.8.1
  • timm ≥ 0.3.2
  • CUDA 11.1
  • An NVIDIA GPU
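Before running the setup script, a quick check of the installed Python packages against the requirements above can save a failed build. The snippet below is a hypothetical helper (not part of the repo) built on the standard-library `importlib.metadata`; it only checks Python, PyTorch and timm versions and does not verify CUDA or the GPU.

```python
import sys
from importlib import metadata

# Minimum versions taken from the requirements list above.
MIN_VERSIONS = {"torch": (1, 8, 1), "timm": (0, 3, 2)}


def parse_version(text):
    """'1.8.1+cu111' -> (1, 8, 1); stops at the first non-numeric component."""
    parts = []
    for piece in text.split("+")[0].split("."):
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        if not digits:
            break
        parts.append(int(digits))
        if len(digits) < len(piece):  # e.g. '0rc1': keep 0, then stop
            break
    return tuple(parts)


def check_requirements():
    """Return {name: (installed version or None, meets minimum)}."""
    py = sys.version_info[:3]
    report = {"python": (".".join(map(str, py)), py >= (3, 6))}
    for name, minimum in MIN_VERSIONS.items():
        try:
            installed = metadata.version(name)
        except metadata.PackageNotFoundError:
            report[name] = (None, False)
            continue
        report[name] = (installed, parse_version(installed) >= minimum)
    return report


if __name__ == "__main__":
    for name, (version, ok) in check_requirements().items():
        print(f"{name:8s} {version or 'not installed':16s} {'OK' if ok else 'CHECK'}")
```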

Conda environment setup

Note: If you already have an environment set up for LITv1, you can reuse it. Otherwise, create a new Python virtual environment with the following script.

conda create -n lit python=3.7
conda activate lit

# Install PyTorch and TorchVision
pip install torch==1.8.1+cu111 torchvision==0.9.1+cu111 torchaudio==0.8.1 -f https://download.pytorch.org/whl/torch_stable.html

pip install timm
pip install ninja
pip install tensorboard

# Install NVIDIA apex
git clone https://github.com/NVIDIA/apex
cd apex
pip install -v --disable-pip-version-check --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./
cd ../
rm -rf apex/

# Build Deformable Convolution
cd mm_modules/DCN
python setup.py build install

pip install opencv-python==4.4.0.46 termcolor==1.1.0 yacs==0.1.8

Getting Started

For image classification on ImageNet, please refer to classification.

For object detection on COCO 2017, please refer to detection.

For semantic segmentation on ADE20K, please refer to segmentation.

Results and Model Zoo

Note: For your convenience, all models and logs are available on Google Drive (4.8 GB in total). Alternatively, we also provide download links on GitHub.

Image Classification on ImageNet-1K

All models are trained for 300 epochs with a total batch size of 1024 on 8 V100 GPUs.

| Model | Resolution | Params (M) | FLOPs (G) | Throughput (imgs/s) | Train Mem (GB) | Test Mem (GB) | Top-1 (%) | Download |
|---|---|---|---|---|---|---|---|---|
| LITv2-S | 224 | 28 | 3.7 | 1,471 | 5.1 | 1.2 | 82.0 | model & log |
| LITv2-M | 224 | 49 | 7.5 | 812 | 8.8 | 1.4 | 83.3 | model & log |
| LITv2-B | 224 | 87 | 13.2 | 602 | 12.2 | 2.1 | 83.6 | model & log |
| LITv2-B | 384 | 87 | 39.7 | 198 | 35.8 | 4.6 | 84.7 | model |

By default, throughput and memory footprint are measured on a single RTX 3090 with a batch size of 64. Memory is the peak usage reported by torch.cuda.max_memory_allocated(). Throughput is averaged over 30 runs.
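A timing loop in the spirit of this protocol might look like the sketch below. It is not the authors' benchmarking script: the helper name `measure_throughput`, the warm-up count and the choice to time all runs together are our assumptions; only the batch size of 64 and the 30 runs come from the text above.

```python
import time


def measure_throughput(run_batch, batch_size=64, runs=30, warmup=10, sync=None):
    """Images per second of `run_batch()`, averaged over `runs` timed calls.

    `sync` should block until queued device work has finished (for a CUDA
    model, torch.cuda.synchronize); otherwise asynchronous GPU kernels would
    still be running when the timer stops.
    """
    sync = sync or (lambda: None)
    for _ in range(warmup):      # untimed warm-up (kernel selection, caches)
        run_batch()
    sync()
    start = time.perf_counter()
    for _ in range(runs):
        run_batch()
    sync()
    elapsed = time.perf_counter() - start
    return batch_size * runs / elapsed
```

With PyTorch on a GPU, one might call `torch.cuda.reset_peak_memory_stats()`, then `measure_throughput(lambda: model(images), sync=torch.cuda.synchronize)` inside `torch.no_grad()`, and finally read `torch.cuda.max_memory_allocated() / 1024**3` for the peak test memory in GB.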

Pretrained LITv2-S with Different Values of Alpha

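| Alpha | Params (M) | Lo-Fi Heads | Hi-Fi Heads | FLOPs (G) | ImageNet Top-1 (%) | Download |
|---|---|---|---|---|---|---|
| 0.0 | 28 | 0 | 12 | 3.97 | 81.16 | github |
| 0.2 | 28 | 2 | 10 | 3.88 | 81.89 | github |
| 0.4 | 28 | 4 | 8 | 3.82 | 81.81 | github |
| 0.5 | 28 | 6 | 6 | 3.77 | 81.88 | github |
| 0.7 | 28 | 8 | 4 | 3.74 | 81.94 | github |
| 0.9 | 28 | 10 | 2 | 3.73 | 82.03 | github |
| 1.0 | 28 | 12 | 0 | 3.70 | 81.89 | github |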

Pretrained weights from the experiments of Figure 4: Effect of α based on LITv2-S.
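The Lo-Fi/Hi-Fi columns in the table above are consistent with truncating `num_heads * alpha` to an integer for the 12-head layer. The helper below is a hypothetical illustration of that rule, not code taken from the repo; it reproduces every row of the table.

```python
def lofi_hifi_heads(num_heads, alpha):
    """Split heads into (Lo-Fi, Hi-Fi), assuming Lo-Fi = int(num_heads * alpha)."""
    lofi = int(num_heads * alpha)  # truncation, not rounding
    return lofi, num_heads - lofi


for alpha in (0.0, 0.2, 0.4, 0.5, 0.7, 0.9, 1.0):
    print(alpha, lofi_hifi_heads(12, alpha))
```

Truncation matters here: rounding would give 11 Lo-Fi heads at alpha = 0.9, whereas the table lists 10.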

Object Detection on COCO 2017

All models are trained with the 1x schedule (12 epochs) with a total batch size of 16 on 8 V100 GPUs.

RetinaNet

| Backbone | Window Size | Params (M) | FLOPs (G) | FPS | box AP | Config | Download |
|---|---|---|---|---|---|---|---|
| LITv2-S | 2 | 38 | 242 | 18.7 | 44.0 | config | model & log |
| LITv2-S | 4 | 38 | 230 | 20.4 | 43.7 | config | model & log |
| LITv2-M | 2 | 59 | 348 | 12.2 | 46.0 | config | model & log |
| LITv2-M | 4 | 59 | 312 | 14.8 | 45.8 | config | model & log |
| LITv2-B | 2 | 97 | 481 | 9.5 | 46.7 | config | model & log |
| LITv2-B | 4 | 97 | 430 | 11.8 | 46.3 | config | model & log |

Mask R-CNN

| Backbone | Window Size | Params (M) | FLOPs (G) | FPS | box AP | mask AP | Config | Download |
|---|---|---|---|---|---|---|---|---|
| LITv2-S | 2 | 47 | 261 | 18.7 | 44.9 | 40.8 | config | model & log |
| LITv2-S | 4 | 47 | 249 | 21.9 | 44.7 | 40.7 | config | model & log |
| LITv2-M | 2 | 68 | 367 | 12.6 | 46.8 | 42.3 | config | model & log |
| LITv2-M | 4 | 68 | 315 | 16.0 | 46.5 | 42.0 | config | model & log |
| LITv2-B | 2 | 106 | 500 | 9.3 | 47.3 | 42.6 | config | model & log |
| LITv2-B | 4 | 106 | 449 | 11.5 | 46.8 | 42.3 | config | model & log |

Semantic Segmentation on ADE20K

All models are trained for 80K iterations with a total batch size of 16 on 8 V100 GPUs.

| Backbone | Params (M) | FLOPs (G) | FPS | mIoU | Config | Download |
|---|---|---|---|---|---|---|
| LITv2-S | 31 | 41 | 42.6 | 44.3 | config | model & log |
| LITv2-M | 52 | 63 | 28.5 | 45.7 | config | model & log |
| LITv2-B | 90 | 93 | 27.5 | 47.2 | config | model & log |

Benchmarking Throughput on More GPUs

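| Model | Params (M) | FLOPs (G) | A100 (imgs/s) | V100 (imgs/s) | RTX 6000 (imgs/s) | RTX 3090 (imgs/s) | Top-1 (%) |
|---|---|---|---|---|---|---|---|
| ResNet-50 | 26 | 4.1 | 1,424 | 1,123 | 877 | 1,279 | 80.4 |
| PVT-S | 25 | 3.8 | 1,460 | 798 | 548 | 1,007 | 79.8 |
| Twins-PCPVT-S | 24 | 3.8 | 1,455 | 792 | 529 | 998 | 81.2 |
| Swin-Ti | 28 | 4.5 | 1,564 | 1,039 | 710 | 961 | 81.3 |
| TNT-S | 24 | 5.2 | 802 | 431 | 298 | 534 | 81.3 |
| CvT-13 | 20 | 4.5 | 1,595 | 716 | 379 | 947 | 81.6 |
| CoAtNet-0 | 25 | 4.2 | 1,538 | 962 | 643 | 1,151 | 81.6 |
| CaiT-XS24 | 27 | 5.4 | 991 | 484 | 299 | 623 | 81.8 |
| PVTv2-B2 | 25 | 4.0 | 1,175 | 670 | 451 | 854 | 82.0 |
| XCiT-S12 | 26 | 4.8 | 1,727 | 761 | 504 | 1,068 | 82.0 |
| ConvNext-Ti | 28 | 4.5 | 1,654 | 762 | 571 | 1,079 | 82.1 |
| Focal-Tiny | 29 | 4.9 | 471 | 372 | 261 | 384 | 82.2 |
| LITv2-S | 28 | 3.7 | 1,874 | 1,304 | 928 | 1,471 | 82.0 |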

Single Attention Layer Benchmark

For details on the single attention layer benchmark shown below, please refer to vit-attention-benchmark.

[Figure: hilo_cpu_gpu, speed of a single HiLo attention layer on CPU and GPU]

Citation

If you use LITv2 in your research, please consider citing it with the following BibTeX entry and giving us a star 🌟.

@inproceedings{pan2022hilo,
  title={Fast Vision Transformers with HiLo Attention},
  author={Pan, Zizheng and Cai, Jianfei and Zhuang, Bohan},
  booktitle={NeurIPS},
  year={2022}
}

If you find the code useful, please also consider citing LITv1 with the following BibTeX entry:

@inproceedings{pan2022litv1,
  title={Less is More: Pay Less Attention in Vision Transformers},
  author={Pan, Zizheng and Zhuang, Bohan and He, Haoyu and Liu, Jing and Cai, Jianfei},
  booktitle={AAAI},
  year={2022}
}

License

This repository is released under the Apache 2.0 license as found in the LICENSE file.

Acknowledgement

This repository is built upon DeiT, Swin and LIT; we thank the authors for their open-source code.