
Simple and easy to understand PyTorch implementation of Vision Transformer (ViT) from scratch with detailed steps. Tested on small datasets: MNIST, FashionMNIST, SVHN, CIFAR10, and CIFAR100.


Vision Transformer from Scratch in PyTorch

Simplified from-scratch PyTorch implementation of the Vision Transformer (ViT) with detailed steps (refer to model.py).

This repo uses a smaller ViT for small-scale datasets like MNIST, CIFAR10, etc., using a smaller patch size.

Key Points:

  • The ViT used is a scaled-down version of the original ViT architecture from "An Image is Worth 16x16 Words".
  • Has only 200k-800k parameters depending upon the embedding dimension (Original ViT-Base has 86 million).
  • Works with small datasets by using a smaller patch size of 4.
  • Supported datasets: MNIST, FashionMNIST, SVHN, CIFAR10, and CIFAR100.
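
The small patch size mentioned above can be illustrated with a minimal patch-embedding sketch. This is not the code from model.py: the class name `PatchEmbedding` and its argument names are hypothetical, and using a strided `nn.Conv2d` to cut and project patches is one common approach that this repo may or may not follow.

```python
import torch
import torch.nn as nn


class PatchEmbedding(nn.Module):
    """Hypothetical sketch: split an image into non-overlapping patches
    and project each patch to an embed_dim-sized token."""

    def __init__(self, n_channels=1, patch_size=4, embed_dim=64):
        super().__init__()
        # kernel_size == stride == patch_size means every output position
        # sees exactly one patch, i.e. one linear projection per patch.
        self.proj = nn.Conv2d(n_channels, embed_dim,
                              kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        x = self.proj(x)                      # (B, embed_dim, H/P, W/P)
        return x.flatten(2).transpose(1, 2)   # (B, num_patches, embed_dim)


# MNIST-sized and CIFAR-sized dummy batches with patch size 4.
mnist_tokens = PatchEmbedding(1, 4, 64)(torch.zeros(2, 1, 28, 28))
cifar_tokens = PatchEmbedding(3, 4, 128)(torch.zeros(2, 3, 32, 32))
print(mnist_tokens.shape)  # torch.Size([2, 49, 64])
print(cifar_tokens.shape)  # torch.Size([2, 64, 128])
```

With a patch size of 16, a 28 x 28 image would yield only one patch, which is why small images need small patches to produce a usable sequence.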



Run commands (also available in scripts.sh):

| Dataset | Run command | Test Acc (%) |
|---|---|---|
| MNIST | `python main.py --dataset mnist --epochs 100` | 99.5 |
| Fashion MNIST | `python main.py --dataset fmnist` | 92.3 |
| SVHN | `python main.py --dataset svhn --n_channels 3 --image_size 32 --embed_dim 128` | 96.2 |
| CIFAR10 | `python main.py --dataset cifar10 --n_channels 3 --image_size 32 --embed_dim 128` | 86.3 (82.5 w/o RandAug) |
| CIFAR100 | `python main.py --dataset cifar100 --n_channels 3 --image_size 32 --embed_dim 128` | 59.6 (55.8 w/o RandAug) |



Transformer Config:

| Config | MNIST and FMNIST | SVHN and CIFAR |
|---|---|---|
| Input Size | 1 × 28 × 28 | 3 × 32 × 32 |
| Patch Size | 4 | 4 |
| Sequence Length | 7 × 7 = 49 | 8 × 8 = 64 |
| Embedding Size | 64 | 128 |
| Parameters | 210k | 820k |
| Num of Layers | 6 | 6 |
| Num of Heads | 4 | 4 |
| Forward Multiplier | 2 | 2 |
| Dropout | 0.1 | 0.1 |
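
The numbers in this config can be sanity-checked with a little arithmetic. The sketch below is an estimate, not the repo's code: `ViTConfig` and `estimate_params` are hypothetical names, and the count assumes a standard encoder block (Q/K/V/output projections, a two-layer MLP, two LayerNorms), a class token, learned position embeddings, a final LayerNorm, and a single linear classifier over 10 classes. The actual model.py may differ in these details, so the totals should land near the table's 210k and 820k rather than match them exactly.

```python
from dataclasses import dataclass


@dataclass
class ViTConfig:
    """Hypothetical container for the settings in the config table."""
    n_channels: int
    image_size: int
    embed_dim: int
    patch_size: int = 4
    n_layers: int = 6
    n_heads: int = 4
    forward_mul: int = 2
    n_classes: int = 10  # assumption: 10-class dataset (CIFAR100 would use 100)

    @property
    def seq_len(self) -> int:
        # Number of non-overlapping patches per image.
        return (self.image_size // self.patch_size) ** 2

    @property
    def head_dim(self) -> int:
        return self.embed_dim // self.n_heads

    def estimate_params(self) -> int:
        d, p = self.embed_dim, self.patch_size
        hidden = self.forward_mul * d
        patch_embed = self.n_channels * p * p * d + d       # projection + bias
        cls_and_pos = d + (self.seq_len + 1) * d             # class token + positions
        attention = 4 * (d * d + d)                          # Q, K, V, output
        mlp = (d * hidden + hidden) + (hidden * d + d)       # two linear layers
        norms = 2 * 2 * d                                    # two LayerNorms per block
        encoder = self.n_layers * (attention + mlp + norms)
        head = 2 * d + (d * self.n_classes + self.n_classes) # final norm + classifier
        return patch_embed + cls_and_pos + encoder + head


mnist = ViTConfig(n_channels=1, image_size=28, embed_dim=64)
cifar = ViTConfig(n_channels=3, image_size=32, embed_dim=128)
print(mnist.seq_len, mnist.head_dim, mnist.estimate_params())  # 49 16 205962
print(cifar.seq_len, cifar.head_dim, cifar.estimate_params())  # 64 32 811146
```

Nearly all of the parameters sit in the six encoder blocks, and that term grows with the square of the embedding size, which is why doubling it from 64 to 128 roughly quadruples the total.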

To train a Vision Transformer with different types of position embeddings, check out Positional Embeddings for Vision Transformers.