PyTorch implementation of the Vision Transformer (ViT), a model that achieves state-of-the-art (SOTA) results in image classification using Transformer-style encoders. See the associated blog article.
- Vanilla ViT
- Hybrid ViT (with support for BiTResNets as backbone)
- Hybrid ViT (with support for AxialResNets as backbone)
- Training Scripts
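A ViT encoder consumes an image as a sequence of fixed-size patches, typically with one extra class token prepended. The sketch below only works out the resulting sequence length; the function name `vit_sequence_length` and the example image sizes are illustrative assumptions and are not part of this repository.

```python
def vit_sequence_length(image_size: int, patch_size: int = 16, class_token: bool = True) -> int:
    """Number of tokens a ViT encoder sees for a square image.

    Hypothetical helper for illustration; not taken from this repository.
    """
    if image_size <= 0 or patch_size <= 0:
        raise ValueError("image_size and patch_size must be positive")
    if image_size % patch_size != 0:
        raise ValueError("image_size must be divisible by patch_size")
    patches_per_side = image_size // patch_size
    num_patches = patches_per_side ** 2
    # One extra position is reserved for the class token, if used.
    return num_patches + (1 if class_token else 0)


if __name__ == "__main__":
    # Assumed example: a 224x224 input with 16x16 patches.
    print(vit_sequence_length(224, 16))
```

Because self-attention cost grows quadratically with sequence length, a smaller patch size or a larger input resolution makes the encoder noticeably more expensive.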
To Do:
- Training Script
- Support for linear decay
- Correct hyperparameters
- Full Axial-ViT
- Results for ImageNet-1K and ImageNet-21K
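For the linear decay item in the list above, one common form is a linear warmup followed by a linear decay of the learning rate to zero. The following is a minimal sketch of such a schedule as a plain multiplier function. The name `linear_warmup_decay`, the warmup phase, and the step counts are assumptions, not the schedule this repository will necessarily adopt.

```python
def linear_warmup_decay(step: int, warmup_steps: int, total_steps: int) -> float:
    """Learning-rate multiplier in [0, 1] for a given optimizer step.

    Hypothetical helper: ramps linearly from 0 to 1 over `warmup_steps`,
    then decays linearly back to 0 at `total_steps`.
    """
    if total_steps <= 0 or not 0 <= warmup_steps <= total_steps:
        raise ValueError("require 0 <= warmup_steps <= total_steps and total_steps > 0")
    step = max(step, 0)
    if step < warmup_steps:
        return step / warmup_steps
    decay_steps = max(1, total_steps - warmup_steps)
    # Clamp at zero so steps past total_steps do not yield a negative rate.
    return max(0.0, (total_steps - step) / decay_steps)


if __name__ == "__main__":
    # Assumed example values: 10 warmup steps out of 100 total.
    for s in (0, 5, 10, 55, 100):
        print(s, linear_warmup_decay(s, warmup_steps=10, total_steps=100))
```

A function of this shape could be bound to fixed step counts with `functools.partial` and passed as `lr_lambda` to `torch.optim.lr_scheduler.LambdaLR`, which multiplies the optimizer's base learning rate by the returned value.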
Create the environment:
conda env create -f environment.yml
Preparing the dataset:
mkdir data
cd data
ln -s path/to/dataset imagenet
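The same setup can be done from Python, which also makes it easy to check that the link resolves before starting a run. This is a sketch under stated assumptions: `link_dataset` is a hypothetical helper, not a script shipped with this repository, and it makes no assumption about the directory layout inside the dataset.

```python
from pathlib import Path


def link_dataset(source: str, root: str = "data", name: str = "imagenet") -> Path:
    """Create `root/name` as a symlink to `source`, mirroring the shell commands above."""
    src = Path(source).expanduser().resolve()
    if not src.is_dir():
        raise FileNotFoundError(f"dataset directory not found: {src}")
    root_dir = Path(root)
    root_dir.mkdir(parents=True, exist_ok=True)
    link = root_dir / name
    if link.is_symlink() or link.exists():
        # Reuse an existing link only if it already points at the requested source.
        if link.resolve() == src:
            return link
        raise FileExistsError(f"{link} already exists and does not point to {src}")
    link.symlink_to(src, target_is_directory=True)
    return link
```

Calling `link_dataset("path/to/dataset")` from the repository root would then leave `data/imagenet` pointing at the dataset, or raise an error if the source directory is missing.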
For non-distributed training:
python train.py --model ViT --name vit_logs
For distributed training:
CUDA_VISIBLE_DEVICES=0,1,2,3 python dist_train.py --model ViT --name vit_dist_logs
For testing, add the --test parameter:
python train.py --model ViT --name vit_logs --test
CUDA_VISIBLE_DEVICES=0,1,2,3 python dist_train.py --model ViT --name vit_dist_logs --test
- BiTResNet: https://github.com/google-research/big_transfer/tree/master/bit_pytorch
- AxialResNet: https://github.com/csrhddlam/axial-deeplab
- Training Scripts: https://github.com/csrhddlam/axial-deeplab
@inproceedings{
anonymous2021an,
title={An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale},
author={Anonymous},
booktitle={Submitted to International Conference on Learning Representations},
year={2021},
url={https://openreview.net/forum?id=YicbFdNTTy},
note={under review}
}