Simplified From-Scratch PyTorch Implementation of Vision Transformer (ViT) with Detailed Steps (Refer to model.py)
This repo implements a smaller ViT with a reduced patch size for small-scale datasets such as MNIST and CIFAR10.
Key Points:
- Uses a scaled-down version of the original ViT architecture from An Image is Worth 16x16 Words.
- Has only 200k-800k parameters depending on the embedding dimension (the original ViT-Base has 86 million).
- Works with small datasets by using a smaller patch size of 4 (see the patch-embedding sketch after this list).
- Supported datasets: MNIST, FashionMNIST, SVHN, CIFAR10, and CIFAR100.
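With a patch size of 4, a 1x28x28 input is split into 7x7 = 49 patch tokens and a 3x32x32 input into 8x8 = 64, matching the sequence lengths in the config table below. A minimal patch-embedding sketch (class and argument names are illustrative; the actual implementation lives in model.py):

```python
# Illustrative sketch; model.py's own patch embedding may be structured differently.
import torch
import torch.nn as nn

class PatchEmbed(nn.Module):
    """Split an image into non-overlapping patches and project each to an embedding."""
    def __init__(self, in_channels=1, patch_size=4, embed_dim=64):
        super().__init__()
        # A conv with kernel_size == stride == patch_size is equivalent to flattening
        # each patch and applying a shared linear projection.
        self.proj = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        x = self.proj(x)                      # (B, embed_dim, H/4, W/4)
        return x.flatten(2).transpose(1, 2)   # (B, num_patches, embed_dim)

tokens = PatchEmbed(in_channels=1, patch_size=4, embed_dim=64)(torch.randn(1, 1, 28, 28))
print(tokens.shape)  # torch.Size([1, 49, 64]) -> sequence length 7*7 = 49
```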
Run commands (also available in scripts.sh):
Dataset | Run command | Test Acc (%) |
---|---|---|
MNIST | python main.py --dataset mnist --epochs 100 | 99.5 |
Fashion MNIST | python main.py --dataset fmnist | 92.3 |
SVHN | python main.py --dataset svhn --n_channels 3 --image_size 32 --embed_dim 128 | 96.2 |
CIFAR10 | python main.py --dataset cifar10 --n_channels 3 --image_size 32 --embed_dim 128 | 86.3 (82.5 w/o RandAug) |
CIFAR100 | python main.py --dataset cifar100 --n_channels 3 --image_size 32 --embed_dim 128 | 59.6 (55.8 w/o RandAug) |
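The CIFAR rows above report accuracy with and without RandAugment. A hedged sketch of adding RandAugment to a torchvision training transform (the repo's exact augmentation settings may differ):

```python
# Assumed augmentation pipeline for CIFAR; parameters are illustrative, not the repo's exact settings.
import torchvision.transforms as T

train_transform = T.Compose([
    T.RandomCrop(32, padding=4),
    T.RandomHorizontalFlip(),
    T.RandAugment(),   # torchvision's RandAugment (defaults: num_ops=2, magnitude=9)
    T.ToTensor(),
    T.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),  # common CIFAR10 statistics
])
```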
Config | MNIST and FMNIST | SVHN and CIFAR |
---|---|---|
Input Size | 1 X 28 X 28 | 3 X 32 X 32 |
Patch Size | 4 | 4 |
Sequence Length | 7*7 = 49 | 8*8 = 64 |
Embedding Size | 64 | 128 |
Parameters | 210k | 820k |
Num of Layers | 6 | 6 |
Num of Heads | 4 | 4 |
Forward Multiplier | 2 | 2 |
Dropout | 0.1 | 0.1 |
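The two columns of the config table translate directly into model hyperparameters. The sketch below wires them together with torch.nn.TransformerEncoder purely for illustration (model.py builds the attention and MLP blocks from scratch); it roughly reproduces the quoted parameter counts:

```python
# Hedged sketch: uses nn.TransformerEncoder instead of the repo's from-scratch blocks,
# but follows the same configuration as the table above.
import torch
import torch.nn as nn

class TinyViT(nn.Module):
    def __init__(self, in_channels=1, image_size=28, patch_size=4, embed_dim=64,
                 depth=6, n_heads=4, forward_mult=2, dropout=0.1, n_classes=10):
        super().__init__()
        n_patches = (image_size // patch_size) ** 2                  # 49 for MNIST, 64 for CIFAR
        self.patch_embed = nn.Conv2d(in_channels, embed_dim,
                                     kernel_size=patch_size, stride=patch_size)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, n_patches + 1, embed_dim))
        layer = nn.TransformerEncoderLayer(embed_dim, n_heads,
                                           dim_feedforward=forward_mult * embed_dim,
                                           dropout=dropout, batch_first=True)
        self.encoder = nn.TransformerEncoder(layer, depth)
        self.head = nn.Linear(embed_dim, n_classes)

    def forward(self, x):
        x = self.patch_embed(x).flatten(2).transpose(1, 2)           # (B, N, D)
        cls = self.cls_token.expand(x.size(0), -1, -1)
        x = torch.cat([cls, x], dim=1) + self.pos_embed
        x = self.encoder(x)
        return self.head(x[:, 0])                                    # classify from the CLS token

model = TinyViT()                                                    # MNIST/FMNIST column
print(sum(p.numel() for p in model.parameters()))                    # ~206k, in line with "210k"
```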
To train the Vision Transformer with a different type of positional embedding, check out Positional Embeddings for Vision Transformers.