PyTorch-Scratch-Vision-Transformer-ViT

A simple, easy-to-understand PyTorch implementation of the Vision Transformer (ViT) built from scratch, with detailed steps. Tested on common datasets such as MNIST, CIFAR10, and more.

Primary language: Python · License: MIT

Vision Transformer from Scratch in PyTorch

A simplified, from-scratch PyTorch implementation of the Vision Transformer (ViT) with detailed steps (see model.py).

This repo uses a smaller ViT with a smaller patch size, suited to small-scale datasets such as MNIST and CIFAR10.
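The steps that model.py walks through can be sketched roughly as follows. This is a minimal sketch under stated assumptions, not the repo's code: the class and argument names (`PatchEmbedding`, `EncoderBlock`, `ViT`, `forward_mul`, etc.) are hypothetical, the defaults come from the MNIST column of the Transformer Config table, and it assumes a standard pre-norm ViT with a learnable class token, learned positional embeddings, GELU in the MLP, and a single linear classification head.

```python
import torch
import torch.nn as nn


class PatchEmbedding(nn.Module):
    """Split the image into non-overlapping patches and project each one to embed_dim.

    A Conv2d with kernel_size == stride == patch_size is equivalent to flattening
    every patch and applying one shared linear layer.
    """

    def __init__(self, n_channels, patch_size, embed_dim):
        super().__init__()
        self.proj = nn.Conv2d(n_channels, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):                       # x: (B, C, H, W)
        x = self.proj(x)                        # (B, E, H/P, W/P)
        return x.flatten(2).transpose(1, 2)     # (B, N, E), N = (H/P) * (W/P)


class EncoderBlock(nn.Module):
    """Pre-norm encoder block: x + MHSA(LN(x)), then x + MLP(LN(x))."""

    def __init__(self, embed_dim, n_heads, forward_mul, dropout):
        super().__init__()
        self.n_heads = n_heads
        self.norm1 = nn.LayerNorm(embed_dim)
        self.qkv = nn.Linear(embed_dim, 3 * embed_dim)    # Q, K, V in one projection
        self.out = nn.Linear(embed_dim, embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, forward_mul * embed_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(forward_mul * embed_dim, embed_dim),
        )
        self.dropout = nn.Dropout(dropout)

    def attention(self, x):
        B, N, E = x.shape
        h, d = self.n_heads, E // self.n_heads
        q, k, v = self.qkv(x).chunk(3, dim=-1)                  # each (B, N, E)
        # Split heads: (B, N, E) -> (B, h, N, d)
        q, k, v = (t.reshape(B, N, h, d).transpose(1, 2) for t in (q, k, v))
        weights = (q @ k.transpose(-2, -1) / d ** 0.5).softmax(dim=-1)  # (B, h, N, N)
        out = (weights @ v).transpose(1, 2).reshape(B, N, E)    # merge heads
        return self.out(out)

    def forward(self, x):
        x = x + self.dropout(self.attention(self.norm1(x)))
        x = x + self.dropout(self.mlp(self.norm2(x)))
        return x


class ViT(nn.Module):
    def __init__(self, n_channels=1, image_size=28, patch_size=4, embed_dim=64,
                 n_layers=6, n_heads=4, forward_mul=2, dropout=0.1, n_classes=10):
        super().__init__()
        n_patches = (image_size // patch_size) ** 2
        self.patch_embed = PatchEmbedding(n_channels, patch_size, embed_dim)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, n_patches + 1, embed_dim))
        nn.init.trunc_normal_(self.cls_token, std=0.02)
        nn.init.trunc_normal_(self.pos_embed, std=0.02)
        self.dropout = nn.Dropout(dropout)
        self.blocks = nn.ModuleList(
            [EncoderBlock(embed_dim, n_heads, forward_mul, dropout) for _ in range(n_layers)]
        )
        self.norm = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, n_classes)

    def forward(self, x):
        x = self.patch_embed(x)                                 # (B, N, E)
        cls = self.cls_token.expand(x.shape[0], -1, -1)         # (B, 1, E)
        x = torch.cat([cls, x], dim=1) + self.pos_embed         # (B, N + 1, E)
        x = self.dropout(x)
        for block in self.blocks:
            x = block(x)
        return self.head(self.norm(x[:, 0]))                    # classify from the class token


model = ViT()  # MNIST / FashionMNIST configuration
logits = model(torch.randn(8, 1, 28, 28))
print(logits.shape)  # torch.Size([8, 10])
```

For the SVHN/CIFAR column, the same sketch would be built as `ViT(n_channels=3, image_size=32, embed_dim=128)`.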

Key Points:

  • The ViT used here is a scaled-down version of the original ViT architecture from "An Image is Worth 16x16 Words".
  • It has only about 200k to 800k parameters, depending on the embedding dimension (the original ViT-Base has 86 million).
  • It works on small datasets by using a smaller patch size of 4.
  • Supported datasets: MNIST, FashionMNIST, SVHN, CIFAR10, and CIFAR100.
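Why the patch size matters on these datasets comes down to simple arithmetic: a square image of side S split into non-overlapping P × P patches yields (S / P)² tokens. The helper below is a hypothetical illustration, not part of the repo.

```python
def num_patches(image_size: int, patch_size: int) -> int:
    """Number of non-overlapping square patches (tokens) in a square image."""
    if image_size % patch_size != 0:
        raise ValueError(f"image_size={image_size} is not divisible by patch_size={patch_size}")
    per_side = image_size // patch_size
    return per_side * per_side


print(num_patches(28, 4))    # MNIST / FashionMNIST: 7 * 7 = 49
print(num_patches(32, 4))    # SVHN / CIFAR: 8 * 8 = 64
print(num_patches(32, 16))   # 16 x 16 patches on a 32 x 32 image: only 2 * 2 = 4
print(num_patches(224, 16))  # original ViT resolution: 14 * 14 = 196
```

With 16 × 16 patches, a 32 × 32 image collapses to just 4 tokens, and a 28 × 28 MNIST image cannot be tiled evenly at all, so a patch size of 4 keeps a useful sequence length for attention.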



Run commands (also available in scripts.sh):

| Dataset | Run command | Test Acc (%) |
|---|---|---|
| MNIST | `python main.py --dataset mnist --epochs 100` | 99.5 |
| Fashion MNIST | `python main.py --dataset fmnist` | 92.3 |
| SVHN | `python main.py --dataset svhn --n_channels 3 --image_size 32 --embed_dim 128` | 96.2 |
| CIFAR10 | `python main.py --dataset cifar10 --n_channels 3 --image_size 32 --embed_dim 128` | 86.3 (82.5 w/o RandAug) |
| CIFAR100 | `python main.py --dataset cifar100 --n_channels 3 --image_size 32 --embed_dim 128` | 59.6 (55.8 w/o RandAug) |



Transformer Config:

| Config | MNIST and FMNIST | SVHN and CIFAR |
|---|---|---|
| Input Size | 1 × 28 × 28 | 3 × 32 × 32 |
| Patch Size | 4 | 4 |
| Sequence Length | 7 × 7 = 49 | 8 × 8 = 64 |
| Embedding Size | 64 | 128 |
| Parameters | 210k | 820k |
| Num of Layers | 6 | 6 |
| Num of Heads | 4 | 4 |
| Forward Multiplier | 2 | 2 |
| Dropout | 0.1 | 0.1 |
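The Parameters row can be sanity-checked by hand. The function below is a hypothetical back-of-the-envelope count that assumes a standard ViT layout: a linear patch projection with bias, a learnable class token, learned positional embeddings for N + 1 positions, pre-norm encoder layers with biased Q/K/V, output, and MLP projections, a final LayerNorm, and a single linear head over 10 classes. The repo's model.py may differ in details (for example in the classification head), so the results land close to, but not exactly on, the 210k and 820k in the table.

```python
def vit_param_count(n_channels, patch_size, image_size, embed_dim,
                    n_layers, forward_mul, n_classes=10):
    """Count trainable parameters of a standard ViT layout (assumptions in the lead-in)."""
    E = embed_dim
    F = forward_mul * E                                  # hidden width of the feed-forward MLP
    n_patches = (image_size // patch_size) ** 2

    patch_embed = n_channels * patch_size * patch_size * E + E   # patch projection + bias
    cls_and_pos = E + (n_patches + 1) * E                        # class token + positional embeddings

    per_layer = (
        2 * E                # LayerNorm before attention (weight + bias)
        + 3 * (E * E + E)    # Q, K, V projections
        + (E * E + E)        # attention output projection
        + 2 * E              # LayerNorm before the MLP
        + (E * F + F)        # MLP expand: E -> F
        + (F * E + E)        # MLP contract: F -> E
    )
    head = 2 * E + (E * n_classes + n_classes)                   # final LayerNorm + linear classifier

    return patch_embed + cls_and_pos + n_layers * per_layer + head


print(vit_param_count(1, 4, 28, 64, 6, 2))   # MNIST / FMNIST config: 205962
print(vit_param_count(3, 4, 32, 128, 6, 2))  # SVHN / CIFAR config: 811146
```

Note that the number of heads does not appear: splitting E into more heads changes how attention is computed, not how many weights the projections hold. Almost all parameters sit in the encoder layers, roughly 12·E² per layer with a forward multiplier of 2, which is why doubling the embedding size roughly quadruples the model.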

To train a Vision Transformer with different types of positional embeddings, check out Positional Embeddings for Vision Transformers.
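As one illustrative alternative to learned positional embeddings (not necessarily what the linked project implements), here is a minimal sketch of fixed 1D sine/cosine embeddings in the style of "Attention Is All You Need", sized for the MNIST configuration (49 patches plus a class token, embedding size 64). The function name is hypothetical.

```python
import math
import torch


def sinusoidal_position_embedding(seq_len: int, embed_dim: int) -> torch.Tensor:
    """Fixed 1D sine/cosine position embeddings of shape (seq_len, embed_dim)."""
    assert embed_dim % 2 == 0, "embed_dim must be even"
    position = torch.arange(seq_len, dtype=torch.float32).unsqueeze(1)   # (seq_len, 1)
    # Frequencies 1 / 10000^(2i / embed_dim) for i = 0 .. embed_dim / 2 - 1
    div_term = torch.exp(
        torch.arange(0, embed_dim, 2, dtype=torch.float32) * (-math.log(10000.0) / embed_dim)
    )
    pe = torch.zeros(seq_len, embed_dim)
    pe[:, 0::2] = torch.sin(position * div_term)   # even dimensions
    pe[:, 1::2] = torch.cos(position * div_term)   # odd dimensions
    return pe


pe = sinusoidal_position_embedding(50, 64)  # 49 patches + 1 class token
print(pe.shape)  # torch.Size([50, 64])
```

In a model, such a table would typically be stored with `self.register_buffer("pos_embed", pe.unsqueeze(0))` instead of an `nn.Parameter`, so it is added to the tokens but never updated by the optimizer.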