Training an Image Classifier in PyTorch using Transfer Learning with Pre-Trained CNN Model Architectures.
This repo consists of two main parts:
- Jupyter Notebook that includes training, testing and inference.
- Command Line Application that can be used for training and prediction.
In this project, you'll train an image classifier to recognize 102 classes of flower species, export the trained model, and then use it for inference.
Once you complete this project, you'll have software that can be trained on any set of labelled images, not just flower species, which makes it a useful tool to integrate into any application that requires image prediction.
By the end of this project, you'll have a user-friendly command line application that anyone can use without any prior experience.
- If you use Anaconda, you can create an environment with all required packages directly from requirements.txt by using the command:
  $ conda create --name <env> --file req.txt
- If you don't, here are the required packages:
cudatoolkit==8.0
numpy==1.13.3
pandas==0.22.0
python==3.6.9
pytorch==0.4.0
torchvision==0.2.1
Image
First, make sure you have the latest pip version by running the command:
python -m pip install --upgrade pip
Then use pip to install the packages. Python itself cannot be installed through pip, so install Python 3.6.9 separately first:
pip install numpy==1.13.3
Hint: to install the CUDA toolkit, use the command:
conda install cudatoolkit=8.0 -c pytorch
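Note: cudatoolkit is distributed through conda, and -c here selects a conda channel. pip's -c option expects a constraints file instead, so this command has no direct pip equivalent.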
You can work with the flowers dataset or any other one, but note that the dataset must be labeled and divided into folders, where each folder is named after its class number.
Example
- flowers\train\52\image_04221.jpg
  - Image Name: image_04221.jpg
  - Class Number: 52
  - Dataset of: Training
- flowers\valid\1\image_06756.jpg
  - Image Name: image_06756.jpg
  - Class Number: 1
  - Dataset of: Validation
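As a rough illustration of this folder convention, the sketch below splits an image path into its dataset split, class number, and image name. The helper name describe_image_path is hypothetical and not part of this repo; it only assumes the layout root\split\class\image described above.

```python
from pathlib import PureWindowsPath


def describe_image_path(path):
    """Split '<root>/<split>/<class>/<image>' into its labelled parts.

    Hypothetical helper for illustration; not taken from train.py.
    """
    # PureWindowsPath accepts both '\\' and '/' as separators on any OS.
    parts = PureWindowsPath(path).parts
    if len(parts) < 3:
        raise ValueError("expected <split>/<class>/<image>, got {!r}".format(path))
    split, class_number, image_name = parts[-3:]
    return {
        "image_name": image_name,
        "class_number": class_number,
        "split": split,
    }


info = describe_image_path(r"flowers\train\52\image_04221.jpg")
print(info["split"], info["class_number"], info["image_name"])
```

The class number is kept as a string because it comes straight from the folder name.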
You'll also need to load a mapping from category label (class number) to category name. If you use the flowers dataset, you can find it in the file cat_to_name.json. It's a JSON object that you can read with the json module, which gives you a dictionary mapping the integer-encoded categories to the actual names of the flowers.
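A minimal sketch of reading such a mapping with the json module is shown below. The function names are hypothetical, and the sample entries are placeholders rather than the real contents of cat_to_name.json. One detail worth noting: JSON object keys are always strings, so a class number has to be converted with str() before the lookup.

```python
import json


def load_category_names(json_path):
    """Read a JSON object mapping class numbers to category names."""
    with open(json_path, "r") as f:
        return json.load(f)


def class_to_name(cat_to_name, class_number):
    # JSON object keys are always strings, so convert the class number first.
    key = str(class_number)
    return cat_to_name.get(key, "unknown class {}".format(key))


# Stand-in for the file contents; these names are placeholders, not real data.
sample = json.loads('{"1": "example flower a", "52": "example flower b"}')
print(class_to_name(sample, 52))
```

With the real file, you would call load_category_names("cat_to_name.json") instead of building the sample dictionary by hand.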
To avoid rendering problems, you can view the notebook in nbviewer.
- Training
- Prediction
Use the train.py file to train a new deep neural network on a dataset of images and save the model to a checkpoint.
- Required Arguments
  - data_dir ---> directory path of the datasets.
- Optional Arguments
  - -s or --save_dir ---> directory path in which to save the trained model. -- Default = working directory
  - -a or --arch ---> CNN model architecture to use. -- Default = vgg19
  - -l or --learning_rate ---> learning rate for the DNN. -- Default = 0.0001
  - --hidden_units ---> number of hidden units in each hidden layer (must be integers). -- Default = 1024
  - -d or --drop_prob ---> drop probability of the hidden units in the hidden layers. -- Default = 0.2
  - -e or --epochs ---> number of training epochs (must be an integer). -- Default = 20
  - -g or --gpu ---> use the GPU for training or inference.
- Basic Usage
  python train.py datasets_directory
- Other Examples
  python train.py datasets_directory -s checkpoints_directory --arch vgg16
  python train.py datasets_directory -l 0.001 --hidden_units 2048 512
  python train.py datasets_directory -d 0.1 -e 10 -g
- Supported CNN Architectures
  AlexNet, VGG11, VGG13, VGG16, VGG19, ResNet18, ResNet34, ResNet50, ResNet101, ResNet152, DenseNet121, DenseNet161, DenseNet169, DenseNet201
- Output
- While Training: Printing out current epoch, training loss, validation loss, and validation accuracy.
- Ex: Epoch: 8/8.. Training Loss: 0.599.. Validation Loss: 0.782.. Validation Accuracy: 0.809
- After Training: A checkpoint that contains the trained DNN weights, biases, and hyperparameters.
- Ex: resnet18.pth
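The training options described above could be declared with argparse roughly as follows. This is a sketch under assumptions, not the actual contents of train.py: the function name build_train_parser is hypothetical, "." stands in for the working-directory default, and -g is assumed to be a simple on/off switch because the usage examples pass it without a value.

```python
import argparse


def build_train_parser():
    """Hypothetical parser mirroring the documented train.py arguments."""
    parser = argparse.ArgumentParser(prog="train.py")
    parser.add_argument("data_dir", help="directory path of the datasets")
    # Assumption: "." represents the working-directory default.
    parser.add_argument("-s", "--save_dir", default=".",
                        help="directory in which to save the checkpoint")
    parser.add_argument("-a", "--arch", default="vgg19",
                        help="CNN model architecture")
    parser.add_argument("-l", "--learning_rate", type=float, default=0.0001)
    # nargs="+" accepts one or more integers, e.g. --hidden_units 2048 512.
    parser.add_argument("--hidden_units", type=int, nargs="+", default=[1024])
    parser.add_argument("-d", "--drop_prob", type=float, default=0.2)
    parser.add_argument("-e", "--epochs", type=int, default=20)
    # Assumption: --gpu is a flag that takes no value.
    parser.add_argument("-g", "--gpu", action="store_true")
    return parser


args = build_train_parser().parse_args(
    ["datasets_directory", "-l", "0.001", "--hidden_units", "2048", "512"]
)
print(args.learning_rate, args.hidden_units, args.epochs, args.gpu)
```

Declaring type=int on --hidden_units and --epochs makes argparse reject non-integer values before any training starts.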
Use the predict.py file to predict the class of an image using the checkpoint of any saved model, along with the probabilities of the top K most likely classes.
- Required Arguments
  - input ---> path of the flower image whose label you want to predict.
  - checkpoint ---> path of the trained DNN model.
- Optional Arguments
  - -k or --top_k ---> number of top K most likely classes to return. -- Default = 1
  - -c or --category_names ---> mapping of categories to real names.
  - -g or --gpu ---> use the GPU for training or inference.
- Basic Usage
  python predict.py input_image_path checkpoint
- Other Examples
  python predict.py input_image_path checkpoint -g --top_k 3
  python predict.py input_image_path checkpoint --category_names cat_to_name.json
  python predict.py input_image_path checkpoint -g -c cat_to_name.json -k 3
- Output
Prints the most likely image class and its associated probability.
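To illustrate how the top K classes and the optional category names fit together, here is a small sketch in plain Python. It is not the code from predict.py: the function top_k_classes is hypothetical, and the probabilities and names are made up for the example rather than produced by a model.

```python
def top_k_classes(probabilities, k=1, category_names=None):
    """Return the k most likely (label, probability) pairs, most likely first.

    probabilities: dict mapping class number (as a string) to probability.
    category_names: optional dict mapping class number to a readable name.
    """
    ranked = sorted(probabilities.items(), key=lambda item: item[1], reverse=True)
    results = []
    for class_number, prob in ranked[:k]:
        label = class_number
        if category_names is not None:
            # Fall back to the class number when no name is available.
            label = category_names.get(class_number, class_number)
        results.append((label, prob))
    return results


# Made-up values for illustration only, not real model output.
probs = {"1": 0.05, "52": 0.80, "7": 0.15}
names = {"52": "example flower b"}
for label, prob in top_k_classes(probs, k=3, category_names=names):
    print("{}: {:.3f}".format(label, prob))
```

With the default k=1 and no category names, the function returns only the single most likely class number, which matches the documented default of --top_k.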
- CS231n Convolutional Neural Networks for Visual Recognition - Transfer Learning
- PyTorch Tutorial - Transfer Learning for Computer Vision
- PyTorch Tutorial - Finetuning Torchvision Models
- Udacity - Frequently Asked Questions for Classifying Images Project
Inspired by Udacity AI Programming with Python Nanodegree.