Vision Transformer

Authors: Wooseok Gwak

This repository implements the Vision Transformer (Alexey Dosovitskiy et al., 2021) using TensorFlow 2. The paper can be found here.

The official JAX repository is here.

Requirements

I recommend using Python 3 and a conda virtual environment.

conda create -n myenv python=3.7
conda activate myenv
conda install --yes --file requirements.txt

After creating the virtual environment, clone this repository and use the model in your own project. When you are done working, deactivate the virtual environment with conda deactivate.

Usage

import numpy as np
import tensorflow as tf
from model.model import ViT

vit = ViT(
    d_model = 50,
    mlp_dim = 100,
    num_heads = 10,
    dropout_rate = 0.1,
    num_layers = 3,
    patch_size = 32,
    num_classes = 102
)

img = np.random.randn(1, 256, 256, 3).astype("float32")  # TensorFlow expects channels-last

preds = vit(img)

Because of a dependency problem with Anaconda packages, this repository uses TensorFlow 2.3 and implements multi-head attention from scratch (the code can be found here). If you are on TensorFlow 2.5 or later, I recommend using tf.keras.layers.MultiHeadAttention instead.
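For reference, the hand-written layer boils down to scaled dot-product attention computed per head and then re-merged. A minimal NumPy sketch of that computation (the function and weight-matrix names here are hypothetical, not the repository's actual API, which is a tf.keras layer):

```python
import numpy as np

def multi_head_attention(x, wq, wk, wv, wo, num_heads):
    """Scaled dot-product multi-head self-attention over x: (batch, seq, d_model)."""
    batch, seq_len, d_model = x.shape
    depth = d_model // num_heads  # per-head dimension

    def split_heads(t):
        # (batch, seq, d_model) -> (batch, heads, seq, depth)
        return t.reshape(batch, seq_len, num_heads, depth).transpose(0, 2, 1, 3)

    q, k, v = split_heads(x @ wq), split_heads(x @ wk), split_heads(x @ wv)

    # Attention weights: softmax(q k^T / sqrt(depth)) over the key axis
    scores = q @ k.transpose(0, 1, 3, 2) / np.sqrt(depth)
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)

    # Weighted sum of values, heads merged back, then output projection
    out = (weights @ v).transpose(0, 2, 1, 3).reshape(batch, seq_len, d_model)
    return out @ wo
```

The shapes mirror what tf.keras.layers.MultiHeadAttention does internally, which is why upgrading to TensorFlow 2.5+ lets the custom layer be swapped out directly.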

Training

python train.py

train.py is sample training code that verifies the model performs the desired operation. You can modify it to train the model on a specific dataset.
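A minimal Keras training loop in the spirit of train.py might look like the sketch below. The random dataset and the stand-in model are hypothetical placeholders; substitute the ViT model and your own data:

```python
import numpy as np
import tensorflow as tf

# Hypothetical toy data: 8 random 256x256 RGB images, 102 classes
# (matching the usage example above).
x = np.random.randn(8, 256, 256, 3).astype("float32")
y = np.random.randint(0, 102, size=(8,))

# Stand-in model; in the real train.py this would be ViT(...) from model.model.
model = tf.keras.Sequential([
    tf.keras.layers.GlobalAveragePooling2D(input_shape=(256, 256, 3)),
    tf.keras.layers.Dense(102),
])

model.compile(
    optimizer=tf.keras.optimizers.Adam(1e-3),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=["accuracy"],
)

history = model.fit(x, y, batch_size=4, epochs=1, verbose=0)
```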

  • 2021.11.30 : the warning WARNING:tensorflow:Gradients do not exist for variables is not yet resolved!