
Official PyTorch implementation of our ECCV 2022 paper "Sliced Recursive Transformer"

Primary language: Python. License: MIT.

Sliced Recursive Transformer (SReT)

PyTorch implementation of our paper Sliced Recursive Transformer (ECCV 2022) by Zhiqiang Shen, Zechun Liu and Eric Xing.

FLOPs and Params Comparison

Our Approach

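  • Recursion operation: share weights across the depth of the transformer by reusing the same blocks recursively.
  • Sliced Group Self-Attention: approximate full self-attention with multiple sliced group self-attentions across recursive layers to reduce the extra computation.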
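The recursion operation can be illustrated with a toy sketch: one block's weights are applied several times in a row, so effective depth grows while the parameter count stays fixed. This is not the repository's implementation; `make_block`, `recursive_forward`, and the residual MLP (standing in for a full attention + MLP transformer block) are hypothetical names chosen for illustration.

```python
import numpy as np

def make_block(dim, rng):
    """Toy residual MLP block, a hypothetical stand-in for a transformer block."""
    w1 = rng.standard_normal((dim, 4 * dim)) * 0.02
    w2 = rng.standard_normal((4 * dim, dim)) * 0.02
    params = [w1, w2]

    def block(x):
        h = np.maximum(x @ w1, 0.0)  # ReLU used here for simplicity
        return x + h @ w2            # residual connection
    return block, params

def recursive_forward(x, block, num_recursions):
    """Apply the same block num_recursions times: depth grows, parameters do not."""
    for _ in range(num_recursions):
        x = block(x)
    return x

rng = np.random.default_rng(0)
dim = 8
block, params = make_block(dim, rng)
tokens = rng.standard_normal((16, dim))  # 16 tokens with 8 channels each

out = recursive_forward(tokens, block, num_recursions=2)
n_params = sum(p.size for p in params)
print(out.shape, n_params)  # (16, 8) 512
```

Running the block twice doubles the compute of that stage but adds no parameters, which is why the extra computation (not parameter count) is what the sliced attention below targets.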

Abstract

We present a neat yet effective recursive operation on vision transformers that can improve parameter utilization without involving additional parameters. This is achieved by sharing weights across the depth of transformer networks. The proposed method obtains a substantial gain of about 2% using only a naive recursive operation, requires no special or sophisticated knowledge of network design principles, and introduces minimal computational overhead to the training procedure. To reduce the additional computation caused by the recursive operation while maintaining its superior accuracy, we propose an approximation method that uses multiple sliced group self-attentions across recursive layers, reducing computational cost by 10-30% with minimal performance loss. We call our model Sliced Recursive Transformer (SReT), a novel and parameter-efficient vision transformer design that is compatible with a broad range of other designs for efficient ViT architectures. Our best model achieves significant improvements on ImageNet-1K over state-of-the-art methods while containing fewer parameters. Its flexible scalability also shows great potential for scaling up models and constructing extremely deep vision transformers.
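The cost argument behind sliced group self-attention can be sketched as follows: if the N tokens are sliced into G equal groups and attention is computed only within each group, each group builds an (N/G) x (N/G) attention matrix, so the total number of attention scores drops from N^2 to N^2/G. This is a minimal NumPy sketch assuming single-head attention and contiguous, equal-size slices; how SReT actually assigns tokens to groups at each recursive layer is not shown here, and `self_attention` and `sliced_group_attention` are hypothetical names.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)  # subtract max for numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def self_attention(x, wq, wk, wv):
    """Single-head scaled dot-product self-attention over tokens x of shape (N, D)."""
    q, k, v = x @ wq, x @ wk, x @ wv
    scores = q @ k.T / np.sqrt(q.shape[-1])  # (N, N) attention matrix
    return softmax(scores) @ v

def sliced_group_attention(x, wq, wk, wv, groups):
    """Slice the N tokens into `groups` contiguous groups and attend within each.

    Each group builds an (N/groups, N/groups) attention matrix, so the total
    number of attention scores drops from N*N to N*N/groups.
    """
    n = x.shape[0]
    assert n % groups == 0, "token count must be divisible by the number of groups"
    slices = np.split(x, groups, axis=0)
    return np.concatenate([self_attention(s, wq, wk, wv) for s in slices], axis=0)

rng = np.random.default_rng(0)
n_tokens, dim = 196, 32  # e.g. a 14x14 patch grid
x = rng.standard_normal((n_tokens, dim))
wq, wk, wv = (rng.standard_normal((dim, dim)) * 0.1 for _ in range(3))

full = self_attention(x, wq, wk, wv)
sliced = sliced_group_attention(x, wq, wk, wv, groups=2)
print(full.shape, sliced.shape)       # (196, 32) (196, 32)
print(n_tokens**2, n_tokens**2 // 2)  # 38416 19208
```

Only the attention-matrix term shrinks; projections and MLPs cost the same, which is consistent with the whole-model saving (10-30%) being smaller than the 1/G reduction of the attention scores alone.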

SReT Models

Install timm using:

pip install git+https://github.com/rwightman/pytorch-image-models.git

Create SReT models:

import torch
import SReT  # SReT.py from this repository

model = SReT.SReT_S(pretrained=False)  # randomly initialized SReT-S
print(model(torch.randn(1, 3, 224, 224)))  # forward a dummy 224x224 RGB batch
...

Load pre-trained SReT models:

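import torch
import SReT

model = SReT.SReT_S(pretrained=False)
# The checkpoint stores the weights under the 'model' key;
# map_location='cpu' lets it load on machines without a GPU.
checkpoint = torch.load('./pre-trained/SReT_S.pth', map_location='cpu')
model.load_state_dict(checkpoint['model'])
model.eval()  # switch to inference mode (e.g. disables dropout)
print(model(torch.randn(1, 3, 224, 224)))
...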

Train SReT models with knowledge distillation (we recommend training with FKD, which is faster and achieves higher performance):

import torch
import SReT
import kd_loss

criterion_kd = kd_loss.KDLoss()

model = SReT.SReT_S_distill(pretrained=False)
# 'images' is a training batch, 'teacher_outputs' are the teacher's logits
# for the same batch, and T is the distillation temperature (defined elsewhere)
student_outputs = model(images)
...
# We use only the soft labels for distillation, as in MEAL V2.
# Note that 'student_outputs' and 'teacher_outputs' are logits before softmax.
loss = criterion_kd(student_outputs/T, teacher_outputs/T)
...
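The definition of `kd_loss.KDLoss` is not shown here. A common soft-label distillation loss, sketched below as an assumption rather than the repository's exact code, is the cross-entropy between the teacher's softmax and the student's log-softmax (equivalent to KL divergence up to a constant that does not depend on the student). `soft_label_kd_loss` and the temperature value are hypothetical; both inputs are logits already divided by `T`, matching the call above.

```python
import numpy as np

def log_softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)  # subtract max for numerical stability
    return z - np.log(np.exp(z).sum(axis=axis, keepdims=True))

def soft_label_kd_loss(student_logits, teacher_logits):
    """Soft cross-entropy: -sum_c p_teacher(c) * log p_student(c), batch-averaged.

    Both inputs are logits already divided by the temperature T.
    """
    p_teacher = np.exp(log_softmax(teacher_logits))
    return float(-(p_teacher * log_softmax(student_logits)).sum(axis=-1).mean())

T = 1.0  # hypothetical temperature
student = np.array([[2.0, 0.5, -1.0]])
teacher = np.array([[1.5, 1.0, -0.5]])
loss = soft_label_kd_loss(student / T, teacher / T)
print(loss)
```

The loss is minimized when the student's distribution matches the teacher's, at which point it equals the entropy of the teacher's soft labels.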

Pre-trained Models

We currently provide last-epoch checkpoints and will add the best checkpoints, together with more models, soon. (⋇ marks models without slicing.) We found that a larger initial learning rate, $0.001 \times \frac{\text{batch size}}{512}$, combined with a longer warmup of 30 epochs gives better results for SReT.
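The learning-rate note above is a linear scaling rule: the initial learning rate grows in proportion to the batch size, with 0.001 at a batch size of 512. The sketch below computes it and, as an assumption, ramps up linearly over the 30 warmup epochs; the linear warmup shape, `start_lr=1e-6`, and the function names are hypothetical, not taken from the repository's configurations.

```python
def initial_lr(batch_size, base_lr=0.001, base_batch=512):
    """Linear scaling rule: lr = 0.001 * batch_size / 512."""
    return base_lr * batch_size / base_batch

def warmup_lr(epoch, peak_lr, warmup_epochs=30, start_lr=1e-6):
    """Hypothetical linear warmup from start_lr to peak_lr over warmup_epochs."""
    if epoch >= warmup_epochs:
        return peak_lr  # after warmup, the main schedule would take over
    return start_lr + (peak_lr - start_lr) * epoch / warmup_epochs

peak = initial_lr(1024)  # 0.002 for a batch size of 1024
for epoch in (0, 10, 20, 30):
    print(epoch, warmup_lr(epoch, peak))
```

Doubling the batch size from 512 to 1024 doubles the initial learning rate from 0.001 to 0.002.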

| Model | FLOPs | #Params | Top-1 Acc. (%) | Weights (last) | Weights (best) | Logs | Config |
|---|---|---|---|---|---|---|---|
| SReT_⋇T | 1.4G | 4.8M | 76.1 | link | TBA | link | link |
| SReT_T | 1.1G | 4.8M | 76.0 | link | TBA | link | link |
| SReT_⋇LT | 1.4G | 5.0M | 76.8 | link | TBA | link | link |
| SReT_LT [8-4-1,2-1-1] | 1.2G | 5.0M | 76.7 | link | TBA | link | link |
| SReT_LT [16-14-1,1-1-1] | 1.2G | 5.0M | 76.6 | link | TBA | link | link |
| SReT_⋇S | 4.7G | 20.9M | 82.0 | link | TBA | link | link |
| SReT_S | 4.2G | 20.9M | 81.9 | link | TBA | link | link |
| SReT_⋇T_Distill | 1.4G | 4.8M | 77.7 | link | TBA | link | link |
| SReT_T_Distill | 1.1G | 4.8M | 77.6 | link | TBA | link | link |
| SReT_⋇LT_Distill | 1.4G | 5.0M | 77.9 | link | TBA | link | link |
| SReT_LT_Distill | 1.2G | 5.0M | 77.7 | link | TBA | link | link |
| SReT_⋇T_Distill_Finetune384 | 6.4G | 4.9M | 79.7 | link | TBA | link | link |
| SReT_⋇S_Distill_Finetune384 | 18.5G | 21.0M | 83.8 | link | TBA | link | link |
| SReT_⋇S_Distill_Finetune512 | 42.8G | 21.3M | 84.3 | link | TBA | link | link |
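The effect of slicing can be read off the table by comparing each sliced model with its unsliced (⋇) counterpart. This short worked computation uses only the FLOPs values listed above; `relative_saving` is a hypothetical helper name.

```python
# (sliced, unsliced) FLOPs in GFLOPs, copied from the table above
pairs = {
    "SReT_T vs SReT_⋇T": (1.1, 1.4),
    "SReT_LT vs SReT_⋇LT": (1.2, 1.4),
    "SReT_S vs SReT_⋇S": (4.2, 4.7),
}

def relative_saving(sliced, unsliced):
    """Fraction of FLOPs saved by slicing: 1 - sliced / unsliced."""
    return 1.0 - sliced / unsliced

for name, (sliced, unsliced) in pairs.items():
    print(f"{name}: {relative_saving(sliced, unsliced):.1%} fewer FLOPs")
```

This prints savings of 21.4%, 14.3% and 10.6%, while top-1 accuracy drops by 0.1 points in each pair (for SReT_LT, the [8-4-1,2-1-1] configuration), in line with the 10-30% reduction stated in the abstract.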

Citation

If you find our code helpful for your research, please cite:

@article{shen2021sliced,
      title={Sliced Recursive Transformer}, 
      author={Zhiqiang Shen and Zechun Liu and Eric Xing},
      year={2021},
      journal={arXiv preprint arXiv:2111.05297}
}

Contact

Zhiqiang Shen (zhiqiangshen0214 at gmail.com or zhiqians at andrew.cmu.edu)