In this notebook, I will train a Convolutional Neural Network (CNN) to classify images from the CIFAR-10 database. The CIFAR-10 dataset consists of small color images grouped into ten classes, including objects like airplanes, automobiles, birds, cats, and more.
The process will be broken down into the following steps:
- Importing essential libraries including PyTorch, NumPy, and Matplotlib for data manipulation, machine learning tasks, and visualization.
- Utilizing PyTorch to download and preprocess the CIFAR-10 dataset.
- Splitting the dataset into training and testing sets, with a portion allocated for validation.
- Creating data loaders for efficient loading of batches during training, validation, and testing.
- Visualizing a batch of training images to gain insights into the nature of the data.
- Displaying 20 images at a time, each labeled with its respective class.
- Selecting and visualizing a single image in detail by extracting the RGB channels.
- Creating subplots for each channel and annotating pixel values to understand their contribution to the overall color.
- Defining a CNN architecture named "Net" for image classification.
- The architecture includes convolutional layers, max-pooling layers, fully connected layers, and dropout layers for regularization.
- Loading necessary libraries and specifying categorical cross-entropy as the loss function.
- Checking for GPU availability and moving the model to GPU if possible.
- Choosing stochastic gradient descent (SGD) as the optimizer with a learning rate of 0.01.
- Training the CNN on the CIFAR-10 dataset for 35 epochs.
- Monitoring and printing training and validation losses after each epoch.
- Saving the model state if the validation loss decreases, ensuring the best-performing model is retained.
- Loading the model state from the file 'model_cifar.pt,' representing the model with the lowest validation loss.
- Evaluating the performance of the trained CNN on the test set.
- Calculating and printing the average test loss, test accuracy for each class, and overall test accuracy.
- Visualizing a batch of test images along with their predicted and true labels.
- Displaying images in a grid with color-coded titles (green for correct predictions, red for incorrect), providing a quick overview of the model's performance.