PyTorch implementation of our paper: Sliced Recursive Transformer, by Zhiqiang Shen, Zechun Liu, and Eric Xing (CMU and MBZUAI).
We present a neat yet effective recursive operation on vision transformers that improves parameter utilization without introducing additional parameters. This is achieved by sharing weights across the depth of the transformer network. The proposed method obtains a substantial gain (about 2%) from the naive recursive operation alone, requires no sophisticated network design principles, and introduces minimal computational overhead to the training procedure. To reduce the extra computation caused by the recursive operation while maintaining accuracy, we propose an approximation through multiple sliced group self-attentions across the recursive layers, which reduces computational cost by 10-30% with minimal performance loss. We call our model Sliced Recursive Transformer (SReT); it is compatible with a broad range of other designs for efficient vision transformers. Our best model achieves significant improvements on ImageNet-1K over state-of-the-art methods while containing fewer parameters.
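For intuition, below is a minimal, self-contained sketch of the two ideas above: a transformer block whose weights are reused across recursive steps (depth without new parameters), and a grouped (sliced) self-attention that attends within token groups so each attention call runs on a shorter sequence. This is an illustrative toy, not the repository's actual implementation; the class and parameter names (`RecursiveBlock`, `num_groups`, `recursions`) are ours.

```python
# Conceptual sketch only: recursive weight sharing + sliced group self-attention.
import torch
import torch.nn as nn


class RecursiveBlock(nn.Module):
    def __init__(self, dim: int, num_heads: int, num_groups: int, recursions: int):
        super().__init__()
        self.recursions = recursions
        self.num_groups = num_groups
        self.norm = nn.LayerNorm(dim)
        # A single attention module whose weights are shared by every recursive step.
        self.attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, N, C = x.shape
        g = self.num_groups
        for _ in range(self.recursions):  # same weights applied repeatedly across depth
            # Slice the N tokens into g contiguous groups; attention cost drops
            # because each call now sees a sequence of length N / g.
            h = self.norm(x).reshape(B * g, N // g, C)
            attn_out, _ = self.attn(h, h, h)      # self-attention within each slice
            x = x + attn_out.reshape(B, N, C)     # residual connection
        return x


blk = RecursiveBlock(dim=64, num_heads=4, num_groups=2, recursions=2)
print(blk(torch.randn(1, 196, 64)).shape)  # torch.Size([1, 196, 64])
```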
Install timm using:

```bash
pip install git+https://github.com/rwightman/pytorch-image-models.git
```
Create SReT models:

```python
import torch
import SReT

model = SReT.SReT_S(pretrained=False)
print(model(torch.randn(1, 3, 224, 224)))
...
```
Load pre-trained SReT models:

```python
import torch
import SReT

model = SReT.SReT_S(pretrained=False)
model.load_state_dict(torch.load('./pre-trained/SReT_S.pth')['model'])
print(model(torch.randn(1, 3, 224, 224)))
...
```
Train SReT models with knowledge distillation (we recommend training with FKD):

```python
import torch
import SReT
import kd_loss

criterion_kd = kd_loss.KDLoss()
model = SReT.SReT_S_distill(pretrained=False)
student_outputs = model(images)
...
# We use the soft labels only for the distillation procedure, as in MEAL V2.
# Note that 'student_outputs' and 'teacher_outputs' are logits before softmax.
loss = criterion_kd(student_outputs/T, teacher_outputs/T)
...
```
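For reference, here is a minimal sketch of what a soft-label distillation loss like the `criterion_kd` call above typically computes: the KL divergence between temperature-softened student and teacher distributions. This is an assumption for illustration; the repository's actual `kd_loss.KDLoss` may differ (e.g. in reduction or `T**2` scaling).

```python
# Illustrative soft-label KD loss; NOT the repository's kd_loss.KDLoss.
import torch
import torch.nn as nn
import torch.nn.functional as F


class SoftKDLoss(nn.Module):
    """KL divergence between temperature-softened student and teacher logits."""

    def forward(self, student_logits: torch.Tensor,
                teacher_logits: torch.Tensor) -> torch.Tensor:
        # Inputs are assumed to be logits already divided by the temperature T,
        # matching `criterion_kd(student_outputs/T, teacher_outputs/T)` above.
        log_p_student = F.log_softmax(student_logits, dim=-1)
        p_teacher = F.softmax(teacher_logits, dim=-1)
        return F.kl_div(log_p_student, p_teacher, reduction="batchmean")


criterion = SoftKDLoss()
T = 4.0
s, t = torch.randn(8, 1000), torch.randn(8, 1000)  # student / teacher logits
print(criterion(s / T, t / T))  # scalar loss
```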
We currently provide the last-epoch checkpoints and will add the best-performing ones, together with more models, soon. (⋇ indicates without slice.)
| Model | FLOPs | #Params | Accuracy (%) | Weights (last) | Weights (best) | Logs | Configurations |
|---|---|---|---|---|---|---|---|
| SReT_⋇T | 1.4G | 4.8M | 76.1 | link | TBA | link | link |
| SReT_T | 1.1G | 4.8M | 76.0 | link | TBA | link | link |
| SReT_⋇LT | 1.4G | 5.0M | 76.8 | link | TBA | link | link |
| SReT_LT [16-14-1,1-1-1] | 1.2G | 5.0M | 76.6 | link | TBA | link | link |
| SReT_⋇S | 4.7G | 20.9M | 82.0 | link | TBA | link | link |
| SReT_S | 4.2G | 20.9M | 81.9 | link | TBA | link | link |
| SReT_⋇T_Distill | 1.4G | 4.8M | 77.7 | link | TBA | link | link |
| SReT_T_Distill | 1.1G | 4.8M | 77.6 | link | TBA | link | link |
| SReT_⋇LT_Distill | 1.4G | 5.0M | 77.9 | link | TBA | link | link |
| SReT_LT_Distill | 1.2G | 5.0M | 77.7 | link | TBA | link | link |
| SReT_⋇T_Distill_Finetune384 | 6.4G | 4.9M | 79.7 | link | TBA | link | link |
| SReT_⋇S_Distill_Finetune384 | 18.5G | 21.0M | 83.8 | link | TBA | link | link |
| SReT_⋇S_Distill_Finetune512 | 42.8G | 21.3M | 84.3 | link | TBA | link | link |
If you find our code helpful for your research, please cite:

```bibtex
@article{shen2021sliced,
  title={Sliced Recursive Transformer},
  author={Zhiqiang Shen and Zechun Liu and Eric Xing},
  journal={arXiv preprint arXiv:2111.05297},
  year={2021}
}
```
Contact: Zhiqiang Shen (zhiqians at andrew.cmu.edu or zhiqiangshen0214 at gmail.com)