
Fine-grained image classification with a Vision Transformer

Primary language: Jupyter Notebook. License: MIT.

Vision Transformer for Flower Classification

The dataset used in this project is a custom subset of the Oxford 102 Flowers dataset. The notebook load_data.ipynb loads the data with TensorFlow, and the size and split of the dataset can be set there. The images are then saved locally so that flower_classification.ipynb can load them with PyTorch.
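The README does not give the split sizes or the on-disk layout, so the following is only a minimal sketch of what a configurable split could look like. The helper names `split_indices` and `image_path`, the default 80/10/10 fractions, the seed, and the folder-per-class layout (the structure torchvision's `ImageFolder` expects) are all assumptions for illustration, not taken from load_data.ipynb.

```python
import random
from pathlib import PurePosixPath


def split_indices(n, fractions=(0.8, 0.1, 0.1), seed=0):
    """Hypothetical helper: shuffle indices 0..n-1 reproducibly and cut them
    into train/val/test lists according to `fractions` (assumed values)."""
    if abs(sum(fractions) - 1.0) > 1e-9:
        raise ValueError("fractions must sum to 1")
    idx = list(range(n))
    random.Random(seed).shuffle(idx)  # fixed seed -> same split every run
    n_train = int(n * fractions[0])
    n_val = int(n * fractions[1])
    return {
        "train": idx[:n_train],
        "val": idx[n_train:n_train + n_val],
        "test": idx[n_train + n_val:],  # remainder goes to the test set
    }


def image_path(root, split, label, index):
    """Hypothetical folder-per-class layout, e.g. data/train/17/000042.jpg."""
    return PurePosixPath(root) / split / str(label) / f"{index:06d}.jpg"
```

With 1,000 images and the default fractions, `split_indices(1000)` yields 800 training, 100 validation and 100 test indices; saving each image under `image_path(...)` keeps the split reproducible between the TensorFlow and PyTorch notebooks.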

The notebook flower_classification.ipynb first preprocesses and explores the data. A Vision Transformer pretrained on ImageNet-21k (Google's ViT-Small, about 22M parameters) is then fine-tuned on the dataset and reaches an accuracy of around 95% on the test set. The hyperparameters are loosely based on this paper.
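As a sanity check on the "22M parameters" figure, the sketch below counts the parameters of a ViT-Small/16 backbone. It assumes the standard configuration used by timm's `vit_small_patch16_224` (embedding dimension 384, 12 transformer blocks, MLP ratio 4, 224×224 input, 16×16 patches); the README itself only names "ViT-small". The function `vit_param_count` is hypothetical, and the classification-head size depends on the number of classes in the custom subset, which the README does not state.

```python
def vit_param_count(image_size=224, patch_size=16, in_chans=3, embed_dim=384,
                    depth=12, mlp_ratio=4, num_classes=0):
    """Count trainable parameters of a plain ViT (assumed ViT-S/16 defaults).
    num_classes=0 means backbone only, without a classification head."""
    d = embed_dim
    hidden = mlp_ratio * d
    num_patches = (image_size // patch_size) ** 2      # 14 * 14 = 196 for 224/16
    patch_embed = patch_size * patch_size * in_chans * d + d  # conv weight + bias
    cls_token = d
    pos_embed = (num_patches + 1) * d                  # +1 for the class token
    per_block = (
        2 * d                 # LayerNorm before attention (weight + bias)
        + d * 3 * d + 3 * d   # fused q/k/v projection
        + d * d + d           # attention output projection
        + 2 * d               # LayerNorm before the MLP
        + d * hidden + hidden # MLP first linear layer
        + hidden * d + d      # MLP second linear layer
    )
    final_norm = 2 * d
    head = d * num_classes + num_classes if num_classes > 0 else 0
    return patch_embed + cls_token + pos_embed + depth * per_block + final_norm + head
```

Under these assumptions `vit_param_count()` gives 21,665,664 parameters for the backbone, i.e. roughly 22M; replacing the 1,000-class ImageNet head with a smaller flower head changes the total only slightly.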

SGD with a momentum of 0.9 is used as the optimizer, starting from a learning rate of 0.001. The learning rate is then decayed with a cosine schedule with warm restarts.
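The schedule can be illustrated with a small, dependency-free sketch. It assumes that the cosine schedule follows SGDR-style warm restarts (the learning rate is annealed from 0.001 towards a minimum and then reset to 0.001 at the start of each cycle); the cycle length `t_0`, growth factor `t_mult` and floor `eta_min` are assumptions, since the README does not give them. In PyTorch the same setup would typically be `torch.optim.SGD(params, lr=0.001, momentum=0.9)` combined with `torch.optim.lr_scheduler.CosineAnnealingWarmRestarts`. If the notebook instead uses a linear warmup followed by a single cosine decay, the formula below does not apply.

```python
import math


def sgdr_lr(epoch, base_lr=0.001, eta_min=0.0, t_0=10, t_mult=1):
    """Learning rate at a (possibly fractional) epoch under cosine annealing
    with warm restarts. t_0, t_mult and eta_min are assumed values."""
    t_i, t_cur = t_0, epoch
    # Walk through cycles of length t_0, t_0 * t_mult, ... to find the current one.
    while t_cur >= t_i:
        t_cur -= t_i
        t_i *= t_mult
    return eta_min + 0.5 * (base_lr - eta_min) * (1 + math.cos(math.pi * t_cur / t_i))


def sgd_momentum_step(param, grad, velocity, lr, momentum=0.9):
    """One heavy-ball SGD step on a scalar, mirroring PyTorch's SGD update
    with dampening=0 and nesterov=False: v <- mu * v + g, p <- p - lr * v."""
    velocity = momentum * velocity + grad
    param = param - lr * velocity
    return param, velocity
```

For example, with `t_0=10` the rate falls from 0.001 at epoch 0 to about 0.0005 at epoch 5, approaches zero just before epoch 10, and jumps back to 0.001 at epoch 10 when the next cycle starts.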