This repository hosts code for converting the original Vision Transformer models [1] (JAX) to TensorFlow.
The original models were fine-tuned on the ImageNet-1k dataset [2]. For more details
on the training protocols, please refer to [3]. The authors of [3] open-sourced about
50k different variants of Vision Transformer models in JAX. Using the
conversion.ipynb
notebook, one should be able to take any model from this pool, convert it to
TensorFlow, and use it with TensorFlow Hub and Keras.
The original model classes and weights [4] were converted using the jax2tf
tool [5].
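The full conversion workflow lives in conversion.ipynb; the snippet below is only a minimal sketch of how jax2tf is typically used to wrap a JAX inference function into a TensorFlow SavedModel. The names `vit_apply_fn` and `params` are placeholders (assumptions, not this repository's API) for the original model's apply function and its restored .npz weights.

```python
import tensorflow as tf
from jax.experimental import jax2tf

def export_to_saved_model(vit_apply_fn, params, export_dir, image_size=224):
    # Wrap the JAX inference function (with the weights closed over) as a
    # TensorFlow-compatible callable.
    tf_fn = jax2tf.convert(lambda images: vit_apply_fn(params, images))

    # Attach the converted function to a tf.Module so it can be serialized.
    module = tf.Module()
    module.serve = tf.function(
        tf_fn,
        input_signature=[tf.TensorSpec([None, image_size, image_size, 3], tf.float32)],
        autograph=False,
    )
    tf.saved_model.save(module, export_dir, signatures=module.serve)
```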
Note that TensorFlow 2.6 or later is required to use the converted models.
Find the model collection on TensorFlow Hub: https://tfhub.dev/sayakpaul/collections/vision_transformer/1.
The eight best-performing ImageNet-1k models have also been made available on TensorFlow
Hub and can be used either for off-the-shelf image classification or for transfer learning.
Please follow the model-selector.ipynb
notebook to understand how these models were chosen.
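For instance, off-the-shelf classification with one of these models might look like the sketch below. The TF-Hub handle is illustrative; pick a concrete one from the collection linked above and check its model page for the exact preprocessing it expects.

```python
import numpy as np
import tensorflow as tf
import tensorflow_hub as hub

# Illustrative handle; substitute a real one from the TF-Hub collection.
handle = "https://tfhub.dev/sayakpaul/vit_b16_classification/1"
model = tf.keras.Sequential([hub.KerasLayer(handle)])

# Dummy batch of 224x224 RGB images; normalize as the chosen model's documentation specifies.
images = np.random.rand(1, 224, 224, 3).astype("float32")
logits = model.predict(images)
print(logits.shape)  # (1, 1000) for the ImageNet-1k classifiers
```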
The table below provides a performance summary:
Model | Top-1 Accuracy (%) | Checkpoint | Misc |
---|---|---|---|
B/8 | 85.948 | B_8-i21k-300ep-lr_0.001-aug_medium2-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_224.npz | |
L/16 | 85.716 | L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_224.npz | |
B/16 | 84.018 | B_16-i21k-300ep-lr_0.001-aug_medium2-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz | |
R50-L/32 | 83.784 | R50_L_32-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_224.npz | |
R26-S/32 (light aug) | 80.944 | R26_S_32-i21k-300ep-lr_0.001-aug_light0-wd_0.03-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.03-res_224.npz | tb.dev run |
R26-S/32 (medium aug) | 80.462 | R26_S_32-i21k-300ep-lr_0.001-aug_medium2-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_224.npz | |
S/16 | 80.462 | S_16-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz | tb.dev run |
B/32 | 79.436 | B_32-i21k-300ep-lr_0.001-aug_medium1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz | |
Note that the top-1 accuracy is reported on the ImageNet-1k validation set. The
checkpoints are available at the following GCS location: gs://vit_models/augreg.
More details on these can be found in [4].
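To browse the checkpoints without leaving Python, one option (a small sketch, not part of the conversion code) is TensorFlow's gfile API, which understands gs:// paths:

```python
import tensorflow as tf

# List the AugReg checkpoints stored in the public GCS bucket.
checkpoints = tf.io.gfile.glob("gs://vit_models/augreg/*.npz")
print(f"{len(checkpoints)} checkpoints found")
print(checkpoints[:3])
```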
The following models are available as image classifiers:
- ViT-S16
- ViT-B8
- ViT-B16
- ViT-B32
- ViT-L16
- ViT-R26-S32 (light augmentation)
- ViT-R26-S32 (medium augmentation)
- ViT-R50-L32
The same models are also available as feature extractors for transfer learning:
- ViT-S16
- ViT-B8
- ViT-B16
- ViT-B32
- ViT-L16
- ViT-R26-S32 (light augmentation)
- ViT-R26-S32 (medium augmentation)
- ViT-R50-L32
- classification.ipynb: Shows how to load a Vision Transformer model from TensorFlow Hub and run image classification.
- fine_tune.ipynb: Shows how to fine-tune a Vision Transformer model from TensorFlow Hub on the tf_flowers dataset.
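A rough outline of what such fine-tuning involves is sketched below; the feature-extractor handle, preprocessing, and hyperparameters are illustrative assumptions, not the notebook's exact settings.

```python
import tensorflow as tf
import tensorflow_datasets as tfds
import tensorflow_hub as hub

IMAGE_SIZE = 224

def preprocess(image, label):
    image = tf.image.resize(image, (IMAGE_SIZE, IMAGE_SIZE))
    image = tf.cast(image, tf.float32) / 255.0  # adjust to the model's expected input range
    return image, label

train_ds = (
    tfds.load("tf_flowers", split="train[:90%]", as_supervised=True)
    .map(preprocess, num_parallel_calls=tf.data.AUTOTUNE)
    .shuffle(1024)
    .batch(32)
    .prefetch(tf.data.AUTOTUNE)
)

# Illustrative feature-extractor handle; pick one from the TF-Hub collection above.
model = tf.keras.Sequential([
    hub.KerasLayer("https://tfhub.dev/sayakpaul/vit_b16_fe/1", trainable=True),
    tf.keras.layers.Dense(5, activation="softmax"),  # tf_flowers has 5 classes
])

model.compile(
    optimizer=tf.keras.optimizers.Adam(1e-4),
    loss="sparse_categorical_crossentropy",
    metrics=["accuracy"],
)
model.fit(train_ds, epochs=3)
```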
Additionally, i1k_eval contains files for running evaluation on the ImageNet-1k validation split.
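Heavily simplified, and assuming ImageNet-1k is already prepared for TFDS locally and `model` is one of the converted classifiers loaded as above, that evaluation boils down to:

```python
import tensorflow as tf
import tensorflow_datasets as tfds

def preprocess(image, label):
    image = tf.image.resize(image, (224, 224))
    image = tf.cast(image, tf.float32) / 255.0  # match the model's expected input range
    return image, label

val_ds = (
    tfds.load("imagenet2012", split="validation", as_supervised=True)
    .map(preprocess, num_parallel_calls=tf.data.AUTOTUNE)
    .batch(64)
    .prefetch(tf.data.AUTOTUNE)
)

# Accumulate top-1 accuracy over the validation split.
top1 = tf.keras.metrics.SparseCategoricalAccuracy()
for images, labels in val_ds:
    top1.update_state(labels, model(images, training=False))
print("Top-1 accuracy:", float(top1.result()))
```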
[1] An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale by Dosovitskiy et al.
[2] ImageNet-1k
[3] How to train your ViT? Data, Augmentation, and Regularization in Vision Transformers by Steiner et al.
[4] Vision Transformer GitHub repository (google-research/vision_transformer)
[5] jax2tf tool
Thanks to the authors of Vision Transformers for their efforts in open-sourcing the models.
Thanks to the ML-GDE program for providing the GCP credits that helped me run the experiments for this project.