This repository implements a Vision Transformer (ViT) in PyTorch, trained on a synthetic dataset of colored MNIST digits placed on texture or solid-color backgrounds.
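For orientation, here is a minimal sketch of the patch-embedding step at the core of a ViT. The module name, image size, and dimensions below are generic illustrations and are not taken from this repository's code.

```python
import torch
import torch.nn as nn

class PatchEmbedding(nn.Module):
    """Split an image into fixed-size patches and project each to an embedding.

    Generic ViT building block for illustration, not this repository's module.
    """
    def __init__(self, image_size=224, patch_size=16, in_channels=3, embed_dim=768):
        super().__init__()
        self.num_patches = (image_size // patch_size) ** 2
        # A strided convolution is equivalent to flattening non-overlapping
        # patches and applying a shared linear projection.
        self.proj = nn.Conv2d(in_channels, embed_dim,
                              kernel_size=patch_size, stride=patch_size)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, self.num_patches + 1, embed_dim))

    def forward(self, x):
        # x: (B, C, H, W) -> (B, num_patches, embed_dim)
        x = self.proj(x).flatten(2).transpose(1, 2)
        cls = self.cls_token.expand(x.shape[0], -1, -1)
        x = torch.cat([cls, x], dim=1)   # prepend the [CLS] token
        return x + self.pos_embed        # add learnable position embeddings


# Example: a 224x224 RGB image becomes 197 tokens of dimension 768
tokens = PatchEmbedding()(torch.randn(1, 3, 224, 224))
print(tokens.shape)  # torch.Size([1, 197, 768])
```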
A companion video, Building Vision Transformer, walks through the implementation.
Sample from dataset
For setting up the MNIST dataset, follow https://github.com/explainingai-code/Pytorch-VAE#data-preparation
Download the quarter-resolution RGB texture data from the ALOT homepage. If you want to train on a higher resolution, you can download that as well, but you would have to create a new imdb.json; the rest of the code should work fine as long as you create valid json files.
Download imdb.json from Drive. After the texture download, verify that the data directory has the following structure:
```
VIT-Pytorch/data/textures/{texture_number}
    *.png
VIT-Pytorch/data/train/images/{0/1/.../9}
    *.png
VIT-Pytorch/data/test/images/{0/1/.../9}
    *.png
VIT-Pytorch/data/imdb.json
```
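As a quick sanity check after arranging the data, something like the snippet below can confirm the files are in place. The exact schema of imdb.json is not documented here, so the sketch only verifies that the json parses and the expected folders contain images; any key names would be assumptions.

```python
import json
from pathlib import Path

data_root = Path("VIT-Pytorch/data")

# imdb.json should parse as valid JSON (its exact schema is not shown here).
with open(data_root / "imdb.json") as f:
    imdb = json.load(f)
print(f"imdb.json loaded, top-level type: {type(imdb).__name__}")

# Each split should contain one folder per digit class with png images inside.
for split in ("train", "test"):
    for digit_dir in sorted((data_root / split / "images").iterdir()):
        n_images = len(list(digit_dir.glob("*.png")))
        print(f"{split}/{digit_dir.name}: {n_images} images")

# Texture folders should also contain png files.
n_textures = sum(1 for _ in (data_root / "textures").glob("*/*.png"))
print(f"texture images found: {n_textures}")
```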
- Create a new conda environment with Python 3.8, then run the commands below:
```
git clone https://github.com/explainingai-code/VIT-Pytorch.git
cd VIT-Pytorch
pip install -r requirements.txt
```
- `python -m tools.train` for training the ViT
- `python -m tools.inference` for running inference, attention visualizations and positional embedding plots
- `config/default.yaml` allows you to play with different aspects of the ViT.

Outputs will be saved according to the configuration present in the yaml files.
For every run, a folder named after the `task_name` key in the config will be created.
- Best model checkpoint is saved in the `task_name` directory
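Once training has produced a checkpoint, a minimal sketch for loading it might look like the following; the yaml lookup and the checkpoint filename are assumptions for illustration, not names confirmed by this repository.

```python
import torch
import yaml

# Read the config to locate the run's output folder. The task_name key is
# mentioned above; whether it is top-level or nested inside another section
# of default.yaml is an assumption here -- adjust the lookup if needed.
with open("config/default.yaml") as f:
    config = yaml.safe_load(f)
task_dir = config.get("task_name", "default")

# NOTE: the checkpoint filename below is a placeholder for illustration;
# check the task_name directory for the file actually written by training.
checkpoint_path = f"{task_dir}/best_model_checkpoint.pth"
state_dict = torch.load(checkpoint_path, map_location="cpu")
print(f"Loaded checkpoint with {len(state_dict)} entries")
```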
During inference, the following outputs will be saved:
- Attention map visualizations for samples of the test set in `task_name/output`
- Positional embedding similarity plot in `task_name/output/position_plot.png`
Following is a sample attention map that you should get
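For readers curious how such maps are produced in general, the sketch below shows one common way to capture attention weights from an attention layer with a forward hook and read off the [CLS] token's attention over the patch grid. It is a generic illustration, not the visualization code used by `tools.inference`.

```python
import torch
import torch.nn as nn

# Generic illustration: capture attention weights from a module that returns
# them, e.g. torch.nn.MultiheadAttention with need_weights=True.
attn_maps = []

def make_hook(store):
    def hook(module, inputs, output):
        # nn.MultiheadAttention returns (attn_output, attn_weights)
        if isinstance(output, tuple) and output[1] is not None:
            store.append(output[1].detach())
    return hook

# Example: a single attention layer over 197 tokens (CLS + 196 patches).
attn = nn.MultiheadAttention(embed_dim=768, num_heads=12, batch_first=True)
attn.register_forward_hook(make_hook(attn_maps))

tokens = torch.randn(1, 197, 768)
_ = attn(tokens, tokens, tokens, need_weights=True, average_attn_weights=True)

# Attention from the CLS token (index 0) to the 196 patch tokens,
# reshaped to the 14x14 patch grid for visualization.
cls_to_patches = attn_maps[0][0, 0, 1:].reshape(14, 14)
print(cls_to_patches.shape)  # torch.Size([14, 14])
```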
Here is a positional embedding similarity plot you should get
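Such a plot is typically built from the cosine similarity between the learned position embedding of each patch and every other patch. The sketch below runs on a randomly initialized embedding table purely to show the computation; the tensor name `pos_embed` and the 14x14 grid are assumptions, not this repository's variable names.

```python
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt

# Assumed shapes for illustration: 196 patch positions (14x14 grid), dim 768.
# In a trained model these would come from the learned position embeddings.
pos_embed = torch.randn(196, 768)

# Cosine similarity of every position embedding with every other one.
normed = F.normalize(pos_embed, dim=-1)
similarity = normed @ normed.T               # (196, 196)

# Visualize the similarity of one reference patch to all patches as a 14x14 map.
ref_idx = 14 * 7 + 7                         # a patch near the image center
plt.imshow(similarity[ref_idx].reshape(14, 14).numpy(), cmap="viridis")
plt.colorbar()
plt.title("Cosine similarity of one position embedding to all others")
plt.savefig("position_similarity_example.png")
```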