/STViT-R-Object-Detection

This is an official implementation of "Making Vision Transformers Efficient from A Token Sparsification View"

Primary LanguagePythonMIT LicenseMIT

STViT-R for Detection

This is an official implementation for Making Vision Transformers Efficient from A Token Sparsification View. It is based on Swin Transformer.

Notes: We will further clean the code and release the checkpoints in the future.

Results on COCO

Model $AP^b$ $AP^b_{50}$ $AP^b_{75}$ $AP^b_s$ $AP^m$ $AP^m_{50}$ $AP^m_{75}$ $AP^m_{s}$ log
STViT-R-Swin-S 51.8 70.6 56.1 36.7 44.7 67.8 48.6 29.0 Link
STViT-R-Swin-B 52.2 70.8 56.8 36.5 45.2 68.3 49.1 29.5 Link

Usage

Installation

Please refer to MMDetection for installation and dataset preparation.

We use apex for mixed precision training by default. To install apex, run:

git clone https://github.com/NVIDIA/apex
cd apex
pip install -v --disable-pip-version-check --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./

Training

To train a Cascade Mask R-CNN model with a STVIT-R-Swin-S backbone and 8 gpus, run:

tools/dist_train.sh configs/swin/cascade_mask_rcnn_swin_small_patch4_window7_mstrain_480-800_giou_4conv1f_adamw_3x_coco.py 8 --cfg-options model.pretrained=<PRETRAIN_MODEL> 

To train a Cascade Mask R-CNN model with a STVIT-R-Swin-B backbone and 8 gpus, run:

tools/dist_train.sh configs/swin/cascade_mask_rcnn_swin_base_patch4_window7_mstrain_480-800_giou_4conv1f_adamw_3x_coco.py 8 --cfg-options model.pretrained=<PRETRAIN_MODEL> 

Citing STViT-R

@inproceedings{chang2023making,
  title={Making Vision Transformers Efficient from A Token Sparsification View},
  author={Chang, Shuning and Wang, Pichao and Lin, Ming and Wang, Fan and Zhang, David Junhao and Jin, Rong and Shou, Mike Zheng},
  booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
  pages={6195--6205},
  year={2023}
}