Mobile-Gallery-Image-Classification-in-PyTorch
Multi-Class Image Classification on Mobile Gallery Images using Transfer Learning in PyTorch.
Introduction
Use the images in your mobile gallery to train an image classifier with transfer learning! :D
STEP 1: Building a Custom Dataset
The dataset I used is available at https://www.kaggle.com/n0obcoder/mobile-gallery-image-classification-data and has 6 classes:
- Cars
- Memes
- Mountains
- Selfies
- Trees
- Whataspp_Screenshots
A few sample images from the training set are shown below.
STEP 2: Data Pre-Processing and Making DataLoaders
The following transforms are applied to the images, in this order, during training and testing:
- Resizing to (224, 224)
- Random Horizontal Flips (Only applied during the training phase)
- ToTensor (to convert the images into tensors)
- Normalization (using the ImageNet per-channel mean and standard deviation)
STEP 3: Defining a Suitable Model and Making the Necessary Tweaks
Architecture: ResNet34. I replaced the last linear layer of ResNet34 with a new linear layer that has 6 output neurons, one for each class in the Mobile Gallery Image dataset described in STEP 1.
STEP 4: Transfer Learning by Freezing and Un-Freezing the Layers
I started from the pretrained weights of the selected architecture.
We freeze the pretrained filters in the early and middle layers and train only the filters in the deeper layers.
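In PyTorch, freezing comes down to setting `requires_grad = False` on the parameters that should not be updated. The sketch below, with the hypothetical helpers `set_trainable` and `count_params`, keeps trainable only the parameters whose names start with the given prefixes. In torchvision's ResNet34, the deepest residual stage is named `layer4` and the classifier head is named `fc`.

```python
import torch.nn as nn


def set_trainable(model: nn.Module, prefixes) -> nn.Module:
    """Freeze all parameters except those whose names start with one of `prefixes`.

    Include the trailing dot (e.g. "layer4.") so the prefix matches only that
    module's parameters and not another module whose name merely starts the same way.
    """
    prefixes = tuple(prefixes)
    for name, param in model.named_parameters():
        param.requires_grad = name.startswith(prefixes)
    return model


def count_params(model: nn.Module):
    """Return (trainable, total) parameter counts."""
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    total = sum(p.numel() for p in model.parameters())
    return trainable, total
```

For example, `set_trainable(model, ["layer4.", "fc."])` leaves only the deepest stage and the head trainable. Setting `requires_grad = False` only stops gradient updates. BatchNorm layers in frozen stages still update their running statistics while the model is in `train()` mode, so you may also want to call `.eval()` on those submodules if they should stay fixed.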
STEP 5: Loss Function and Optimizer
- Loss function: Cross-entropy
- Optimizer: Adam
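Cross-entropy and Adam map directly onto `nn.CrossEntropyLoss` and `torch.optim.Adam`. This sketch passes Adam only the parameters that are still trainable after freezing. The learning rate `lr=1e-3` (PyTorch's default for Adam) is an assumption, because no learning rate is stated here. `make_criterion_and_optimizer` is a hypothetical helper name.

```python
import torch
import torch.nn as nn


def make_criterion_and_optimizer(model: nn.Module, lr: float = 1e-3):
    # CrossEntropyLoss applies log-softmax internally: feed it raw logits
    # and integer class labels, not probabilities.
    criterion = nn.CrossEntropyLoss()
    # Frozen parameters (requires_grad=False) are left out of the optimizer.
    params = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.Adam(params, lr=lr)
    return criterion, optimizer
```

Because the loss expects logits, the model should not end in a softmax layer. Softmax is only needed at inference time to turn logits into confidences.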
STEP 6: Training and Validation
- Trained 'layer4' and 'fc' for 5 epochs.
- Then trained only 'fc' for 3 more epochs.
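The two training phases can be sketched as the loop below. It assumes that a fresh Adam optimizer is created at the start of each phase and that validation runs after every epoch. Neither detail is stated above, and the helper names (`run_epoch`, `two_phase_training`) and `lr=1e-3` are hypothetical. The loaders can be any iterables of `(images, labels)` batches, such as the DataLoaders from STEP 2.

```python
import torch
import torch.nn as nn


def run_epoch(model, loader, criterion, optimizer=None, device="cpu"):
    """One pass over `loader`: trains if an optimizer is given, otherwise evaluates."""
    training = optimizer is not None
    model.train(training)  # train mode for training, eval mode for validation
    total_loss, correct, seen = 0.0, 0, 0
    with torch.set_grad_enabled(training):
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            logits = model(images)
            loss = criterion(logits, labels)
            if training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            batch = labels.size(0)
            total_loss += loss.item() * batch
            correct += (logits.argmax(dim=1) == labels).sum().item()
            seen += batch
    return total_loss / seen, correct / seen  # mean loss, accuracy


def set_trainable(model, prefixes):
    """Only parameters whose names start with one of `prefixes` stay trainable."""
    for name, param in model.named_parameters():
        param.requires_grad = name.startswith(tuple(prefixes))


def two_phase_training(model, train_loader, val_loader, lr=1e-3, device="cpu"):
    model.to(device)
    criterion = nn.CrossEntropyLoss()
    # Phase 1: layer4 + fc for 5 epochs. Phase 2: only fc for 3 more epochs.
    phases = [(("layer4.", "fc."), 5), (("fc.",), 3)]
    history = []
    for phase, (prefixes, epochs) in enumerate(phases, start=1):
        set_trainable(model, prefixes)
        optimizer = torch.optim.Adam(
            [p for p in model.parameters() if p.requires_grad], lr=lr)
        for epoch in range(epochs):
            train_stats = run_epoch(model, train_loader, criterion, optimizer, device)
            val_stats = run_epoch(model, val_loader, criterion, None, device)
            history.append((phase, train_stats, val_stats))
            print(f"phase {phase} epoch {epoch + 1}/{epochs}: "
                  f"train loss {train_stats[0]:.4f} acc {train_stats[1]:.3f} | "
                  f"val loss {val_stats[0]:.4f} acc {val_stats[1]:.3f}")
    return history
```

With the ResNet34 from STEP 3, this would be called as `two_phase_training(model, train_loader, val_loader)`. Rebuilding the optimizer per phase keeps it from holding Adam state for parameters that are no longer being trained.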
STEP 7: It's Testing Time!
- Test image
- Output: This Neural Network thinks that the given image belongs to >>> Memes <<< class with confidence of 95.21%