The objective of this project is to build a neural network from scratch to classify the handwritten digits in the MNIST dataset. MNIST contains 70,000 grayscale 28x28 images of handwritten digits (60,000 for training and 10,000 for testing) and is commonly used as a benchmark for image classification in machine learning. High accuracy on this dataset suggests that the model can perform well on similar classification tasks.
The project involves the following steps:
- Dataset Loading and Preprocessing:
  - Loading the MNIST dataset using torchvision.datasets.
  - Applying necessary transformations like converting images to tensors and normalizing them.
- Exploratory Data Analysis:
  - Visualizing a few samples from the dataset to understand its structure and the nature of the images.
- Building the Neural Network:
  - Designing a neural network architecture using torch.nn and torch.nn.functional.
  - Initializing the model, specifying the loss function, and defining the optimizer.
- Training the Model:
  - Training the neural network on the training dataset.
  - Validating the model on the validation dataset during training.
  - Recording and plotting the training and validation loss over epochs.
- Evaluating the Model:
  - Testing the model on the test dataset to compute the final accuracy.
  - Visualizing the model’s predictions against actual labels.
- Improving the Model:
  - Tweaking hyperparameters like learning rate and training the model again for better accuracy.
- Saving the Model:
  - Saving the trained model for future use.
- Sanity Checks:
  - Loading the saved model and ensuring it performs as expected.
  - Visualizing predictions on random samples from the test dataset.
  - Generating a confusion matrix to evaluate model performance across different classes.
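For the Exploratory Data Analysis step, a small helper such as the hypothetical show_samples below draws a row of images with their labels using Matplotlib. It is a sketch that runs here on random tensors shaped like an MNIST batch; with real data, pass a batch taken from the training DataLoader instead.

```python
import matplotlib

matplotlib.use("Agg")  # non-interactive backend; omit this line in a notebook
import matplotlib.pyplot as plt
import torch


def show_samples(images, labels, n=8):
    """Plot the first n images in one row, titled with their labels.

    images: tensor of shape (N, 1, 28, 28) or (N, 28, 28); labels: tensor of shape (N,).
    """
    # squeeze=False keeps `axes` a 2-D array even when n == 1.
    fig, axes = plt.subplots(1, n, figsize=(1.5 * n, 1.8), squeeze=False)
    for ax, img, label in zip(axes[0], images[:n], labels[:n]):
        ax.imshow(img.squeeze().numpy(), cmap="gray")
        ax.set_title(str(int(label)))
        ax.axis("off")
    fig.tight_layout()
    return fig


# Demo on random "images" shaped like an MNIST batch (N, 1, 28, 28).
fake_images = torch.rand(8, 1, 28, 28)
fake_labels = torch.arange(8)
fig = show_samples(fake_images, fake_labels)
# In an interactive session, plt.show() displays the figure.
```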
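One possible architecture for the Building the Neural Network step, sketched with torch.nn and torch.nn.functional. The fully connected 784-128-64-10 layout, the class name MNISTNet, the Adam optimizer, and the 1e-3 learning rate are assumptions for illustration; the project may use different layers or a different optimizer. The model returns raw logits because nn.CrossEntropyLoss applies log-softmax internally.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim


class MNISTNet(nn.Module):
    """Fully connected classifier: 784 -> 128 -> 64 -> 10 (illustrative layer sizes)."""

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(28 * 28, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, 10)  # one logit per digit class

    def forward(self, x):
        x = torch.flatten(x, start_dim=1)  # (N, 1, 28, 28) -> (N, 784)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)  # raw logits; no softmax here


model = MNISTNet()
loss_fn = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)

# Shape check on a fake batch of four images.
logits = model(torch.randn(4, 1, 28, 28))
print(logits.shape)  # torch.Size([4, 10])
```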
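The Training the Model step (train, validate after each epoch, record both losses) can be sketched as the loop below. The function names run_epoch and fit are hypothetical, and the demo uses random tensors with random labels as a stand-in for MNIST, so the printed loss values are meaningless; only the mechanics carry over.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset


def run_epoch(model, loader, loss_fn, optimizer=None):
    """One pass over `loader`: trains if an optimizer is given, otherwise only evaluates.

    Returns the mean loss per sample.
    """
    training = optimizer is not None
    model.train(training)  # switches dropout/batch-norm behavior, if any
    total_loss, total_samples = 0.0, 0
    with torch.set_grad_enabled(training):
        for images, labels in loader:
            loss = loss_fn(model(images), labels)
            if training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            # loss is a batch mean, so weight it by the batch size before averaging.
            total_loss += loss.item() * labels.size(0)
            total_samples += labels.size(0)
    return total_loss / total_samples


def fit(model, train_loader, val_loader, epochs=5, lr=1e-3):
    """Train with Adam and record per-epoch training and validation loss."""
    loss_fn = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    history = {"train": [], "val": []}
    for epoch in range(epochs):
        history["train"].append(run_epoch(model, train_loader, loss_fn, optimizer))
        history["val"].append(run_epoch(model, val_loader, loss_fn))
        print(f"epoch {epoch + 1}: train {history['train'][-1]:.4f}, val {history['val'][-1]:.4f}")
    return history


# Synthetic stand-in for MNIST (random pixels, random labels) so the sketch runs offline.
torch.manual_seed(0)
fake_x = torch.randn(256, 1, 28, 28)
fake_y = torch.randint(0, 10, (256,))
train_loader = DataLoader(TensorDataset(fake_x[:200], fake_y[:200]), batch_size=32, shuffle=True)
val_loader = DataLoader(TensorDataset(fake_x[200:], fake_y[200:]), batch_size=32)

model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 128), nn.ReLU(), nn.Linear(128, 10))
history = fit(model, train_loader, val_loader, epochs=3)
```

The two lists in history can then be plotted with Matplotlib, for example plt.plot(history["train"], label="train") followed by plt.plot(history["val"], label="val").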
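For the Improving the Model step, one straightforward way to tweak the learning rate is a small sweep: retrain a fresh model for each candidate value and keep the one with the best validation accuracy. The candidate values, the two-epoch budget, the helper names, and the random stand-in data are assumptions; the project's own tuning procedure may differ.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset


def train_and_validate(model, train_loader, val_loader, lr, epochs=2):
    """Train with Adam at learning rate `lr`; return validation accuracy afterwards."""
    loss_fn = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    for _ in range(epochs):
        model.train()
        for images, labels in train_loader:
            optimizer.zero_grad()
            loss_fn(model(images), labels).backward()
            optimizer.step()
    model.eval()
    correct = total = 0
    with torch.no_grad():
        for images, labels in val_loader:
            correct += (model(images).argmax(dim=1) == labels).sum().item()
            total += labels.size(0)
    return correct / total


def sweep_learning_rates(make_model, train_loader, val_loader, lrs=(1e-2, 1e-3, 1e-4)):
    """Retrain a fresh model per learning rate; return all scores and the best model."""
    results = {}
    best_acc, best_model = -1.0, None
    for lr in lrs:
        torch.manual_seed(0)  # same initial weights for every candidate
        model = make_model()
        acc = train_and_validate(model, train_loader, val_loader, lr)
        results[lr] = acc
        if acc > best_acc:
            best_acc, best_model = acc, model
    return results, best_model


# Random stand-in data so the sketch runs offline; accuracies here are near chance.
torch.manual_seed(0)
fake_x = torch.randn(200, 1, 28, 28)
fake_y = torch.randint(0, 10, (200,))
train_loader = DataLoader(TensorDataset(fake_x[:160], fake_y[:160]), batch_size=32, shuffle=True)
val_loader = DataLoader(TensorDataset(fake_x[160:], fake_y[160:]), batch_size=32)


def make_model():
    return nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 64), nn.ReLU(), nn.Linear(64, 10))


results, best_model = sweep_learning_rates(make_model, train_loader, val_loader)
print(results)
```

Selecting on the validation split rather than the test set keeps the final test accuracy an honest estimate of performance on unseen data.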
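The Saving the Model step and the first sanity check can be sketched by saving the model's state_dict with torch.save, loading it into a freshly built model of the same architecture, and confirming that both give the same outputs. The stand-in architecture and the file name mnist_model.pt are hypothetical; a temporary directory keeps the sketch from leaving files behind.

```python
import os
import tempfile

import torch
import torch.nn as nn


def make_model():
    # Stand-in architecture; a state_dict only loads into a model with matching layers.
    return nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 128), nn.ReLU(), nn.Linear(128, 10))


torch.manual_seed(0)
model = make_model()

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "mnist_model.pt")  # hypothetical file name
    torch.save(model.state_dict(), path)  # saves the weights, not the Python class

    restored = make_model()
    restored.load_state_dict(torch.load(path, weights_only=True))

    model.eval()
    restored.eval()
    batch = torch.randn(8, 1, 28, 28)
    with torch.no_grad():
        same = torch.allclose(model(batch), restored(batch))

print(same)  # True
```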
The project uses the following tools and libraries:
- Python: The primary programming language used for the project.
- PyTorch: Used for building and training the neural network.
- Torchvision: Used for loading the MNIST dataset and applying transformations.
- Matplotlib: For plotting loss curves and visualizing images.
- Seaborn: For plotting the confusion matrix.
- Scikit-learn: For generating the confusion matrix.
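As a sketch of how the scikit-learn piece fits in, the snippet below collects predictions by taking argmax over the model's logits and builds a 10x10 confusion matrix with sklearn.metrics.confusion_matrix; overall accuracy is the matrix trace divided by its total. The helper names predict and summarize are hypothetical, and the demo uses a few hand-written labels rather than real model output.

```python
import numpy as np
import torch
from sklearn.metrics import confusion_matrix


def predict(model, loader):
    """Return (true_labels, predicted_labels) as NumPy arrays."""
    model.eval()
    y_true, y_pred = [], []
    with torch.no_grad():
        for images, labels in loader:
            y_pred.append(model(images).argmax(dim=1))  # class with the largest logit
            y_true.append(labels)
    return torch.cat(y_true).numpy(), torch.cat(y_pred).numpy()


def summarize(y_true, y_pred):
    """Overall accuracy and a 10x10 confusion matrix (rows: true digit, columns: predicted)."""
    cm = confusion_matrix(y_true, y_pred, labels=list(range(10)))
    accuracy = np.trace(cm) / cm.sum()  # correct predictions lie on the diagonal
    return accuracy, cm


# Hand-written demo: one true "2" is predicted as "3".
y_true = np.array([0, 1, 2, 2, 3])
y_pred = np.array([0, 1, 2, 3, 3])
accuracy, cm = summarize(y_true, y_pred)
print(accuracy)  # 0.8
print(cm[2, 3])  # 1
```

To plot the matrix with Seaborn, sns.heatmap(cm, annot=True, fmt="d") (after import seaborn as sns) draws it with counts in each cell; off-diagonal cells show which digits are confused with which.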
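To run the project, make sure the required libraries are installed. With conda, note that the PyTorch package is named pytorch rather than torch:
conda install -c conda-forge pytorch torchvision matplotlib seaborn scikit-learn
Alternatively, with pip:
pip install torch torchvision matplotlib seaborn scikit-learn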
Then, execute the notebook or script to train the model and evaluate its performance on the MNIST dataset.
This project provides a complete workflow for building, training, and evaluating a neural network on the MNIST dataset. The final model achieves high accuracy on the test set, demonstrating its effectiveness in classifying handwritten digits. The project also includes steps to improve the model, save it, and perform sanity checks to ensure its reliability.