Vision Transformer

Introduction

[Figure: ViT.png]

A PyTorch implementation of the Vision Transformer (ViT) network.
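For orientation, a Vision Transformer cuts an image into fixed-size square patches, flattens each patch into a vector, projects it linearly into a token, and prepends a learnable class token before the Transformer encoder. The plain-Python sketch below only illustrates the patch arithmetic and the flattening step, using nested lists instead of tensors. It is not this repository's model code: the helper names `num_tokens` and `patchify`, and the 224 x 224 input with 16 x 16 patches, are illustrative assumptions.

```python
def num_tokens(image_size: int, patch_size: int, class_token: bool = True) -> int:
    """Number of tokens a ViT feeds to its encoder for a square image."""
    if image_size % patch_size != 0:
        raise ValueError("image_size must be divisible by patch_size")
    patches = (image_size // patch_size) ** 2
    return patches + (1 if class_token else 0)


def patchify(image, patch_size: int):
    """Split an H x W x C nested list into flattened patches, in row-major order.

    Each patch becomes a flat list of patch_size * patch_size * C values,
    which a ViT would then project linearly to its embedding dimension.
    """
    height, width = len(image), len(image[0])
    if height % patch_size or width % patch_size:
        raise ValueError("image height and width must be divisible by patch_size")
    patches = []
    for top in range(0, height, patch_size):
        for left in range(0, width, patch_size):
            flat = []
            # Walk the patch row by row, appending every channel of every pixel.
            for row in image[top:top + patch_size]:
                for pixel in row[left:left + patch_size]:
                    flat.extend(pixel)
            patches.append(flat)
    return patches


if __name__ == "__main__":
    print(num_tokens(224, 16))  # 14 * 14 patches + 1 class token -> 197
    # Tiny 4 x 4 RGB image with 2 x 2 patches -> 4 patches of 2 * 2 * 3 = 12 values.
    tiny = [[[r, c, 0] for c in range(4)] for r in range(4)]
    patches = patchify(tiny, 2)
    print(len(patches), len(patches[0]))  # 4 12
```

PyTorch implementations commonly realize the flatten-and-project step as a single convolution whose kernel size and stride both equal the patch size, which is equivalent to applying one linear layer to each non-overlapping patch.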

If this works for you, please give me a star; it is very important to me. 😊

Quick start

  1. Clone this repository.
git clone https://github.com/Runist/torch_Vision_Transformer
  2. Install the requirements.
cd torch_Vision_Transformer
pip install -r requirements.txt
  3. Download and extract the flower dataset.
wget https://github.com/Runist/image-classifier-keras/releases/download/v0.2/dataset.zip
unzip dataset.zip
  4. Modify config.py.
  5. Download the pretrained weights; the URL is given in utils.py.
  6. Start training your model.
python train.py
  7. Open TensorBoard to monitor the loss, learning rate, etc. You can also view the training progress and the validation predictions.
tensorboard --logdir ./summary/log

[Figure: tensorboard.png]

  8. Get the model's predictions.
python predict.py
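On systems without `wget` or `unzip` (for example, Windows), the dataset download step can be done with the Python standard library instead. This is a minimal sketch, assuming you run it from the repository root; the function name `download_and_extract` is hypothetical and not part of this repository.

```python
import shutil
import urllib.parse
import urllib.request
import zipfile
from pathlib import Path

# URL taken from the Quick start instructions.
DATASET_URL = (
    "https://github.com/Runist/image-classifier-keras/releases/download/v0.2/dataset.zip"
)


def download_and_extract(url: str = DATASET_URL, dest: str = ".") -> Path:
    """Download a zip archive (like `wget`) and unpack it into `dest` (like `unzip`).

    Returns the path of the downloaded archive.
    """
    dest_dir = Path(dest)
    dest_dir.mkdir(parents=True, exist_ok=True)
    # Name the local file after the last component of the URL path.
    filename = Path(urllib.parse.urlparse(url).path).name or "dataset.zip"
    archive = dest_dir / filename
    # Stream the response to disk instead of holding it all in memory.
    with urllib.request.urlopen(url) as response, open(archive, "wb") as out:
        shutil.copyfileobj(response, out)
    # Extract every member of the archive into the destination folder.
    with zipfile.ZipFile(archive) as zf:
        zf.extractall(dest_dir)
    return archive
```

Calling `download_and_extract()` from the repository root saves `dataset.zip` in the current folder and unpacks it there, matching what the `wget` and `unzip` commands do.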

Train on your own dataset

Store your dataset in the following layout, with one subfolder per class under both train and validation:

├── train
│   ├── daisy
│   ├── dandelion
│   ├── roses
│   ├── sunflowers
│   └── tulips
└── validation
    ├── daisy
    ├── dandelion
    ├── roses
    ├── sunflowers
    └── tulips
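Before training, it can help to confirm that your folders match this layout. The sketch below uses only the standard library to check that `train` and `validation` exist under a dataset root, that both contain the same class folders, and to count the images in each class. The function name `check_dataset`, its default root, and the accepted image extensions are assumptions, not part of this repository; adjust them to whatever your data loader expects.

```python
from pathlib import Path

# Assumed set of image extensions; extend it if your data uses other formats.
IMAGE_EXTENSIONS = {".jpg", ".jpeg", ".png", ".bmp"}


def check_dataset(root: str = ".") -> dict:
    """Validate the train/validation layout and count images per class."""
    root_dir = Path(root)
    counts = {}
    for split in ("train", "validation"):
        split_dir = root_dir / split
        if not split_dir.is_dir():
            raise FileNotFoundError(f"missing folder: {split_dir}")
        # Each subfolder is treated as one class; count matching image files in it.
        counts[split] = {
            class_dir.name: sum(
                1
                for f in class_dir.iterdir()
                if f.is_file() and f.suffix.lower() in IMAGE_EXTENSIONS
            )
            for class_dir in sorted(split_dir.iterdir())
            if class_dir.is_dir()
        }
    if set(counts["train"]) != set(counts["validation"]):
        raise ValueError("train and validation must contain the same class folders")
    return counts
```

Calling `check_dataset()` from the folder that contains `train` and `validation` returns a nested dict of image counts per split and class, and raises an error if a split is missing or the class folders differ. Comparing the class folder names matters because loaders that derive labels from folder names, such as torchvision's `ImageFolder`, would otherwise assign different label indices in the two splits.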

Reference

Thanks to the work of the following repositories:

License

Code and datasets are released for non-commercial and research purposes only. For commercial purposes, please contact the authors.