This project provides an end-to-end pipeline for training an image classifier, running inference on images, and converting the trained model into a format suitable for mobile devices (Android) using TFLite.
The pipeline supports two main operations: training a model and running inference. It also includes a utility to convert the trained PyTorch model to TFLite for use on mobile devices.
- Python 3.10
- Other dependencies can be installed via the `requirements.txt` file:

```
pip install -r requirements.txt
```

You need a dataset for training your image classifier. The images should be stored in directories corresponding to their class labels. For example:
```
dataset/
├── reddit/
│   ├── image1.jpg
│   ├── image2.jpg
├── twitter/
│   ├── image1.jpg
│   ├── image2.jpg
└── other/
    ├── image1.jpg
    └── image2.jpg
```
Make sure the dataset directory follows this structure, where each subdirectory corresponds to a specific class.
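A directory laid out this way maps directly onto torchvision's `ImageFolder` dataset. The snippet below is only an illustrative sketch of how such a folder can be loaded; the transforms and paths are assumptions, not necessarily what this project uses internally.

```python
# Illustrative only: load a class-per-subdirectory dataset with torchvision.
# The resize/normalization values assume ImageNet-style preprocessing.
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])

dataset = datasets.ImageFolder("dataset/", transform=transform)
loader = DataLoader(dataset, batch_size=8, shuffle=True)
print(dataset.classes)  # class names inferred from the subdirectory names
```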
This pipeline supports two pre-trained models:
- ResNet18
- MobileNetV3-Large
Both models are fine-tuned for your specific dataset by replacing the final fully connected layer to match the number of classes in the dataset. During training, the model weights are updated to fit the provided dataset.
The `class_names.txt` file should contain the class names, one per line. The order of class names is important, as it determines the model's output labels.

Example `class_names.txt`:

```
reddit
twitter
other
```
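Reading the file in order is enough to map model outputs back to labels. This is a minimal sketch, assuming a plain text file; the path is an example.

```python
# Minimal sketch: read class names, preserving file order.
# Index i in this list corresponds to output logit i of the model.
with open("class_names.txt") as f:
    class_names = [line.strip() for line in f if line.strip()]

print(class_names)  # e.g. ['reddit', 'twitter', 'other']
```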
You can run the pipeline in three modes: `train`, `inference`, and `convert`. Use the `--mode` argument to select the mode.
To train a new model or continue training an existing one, use the following command:
```
python main.py --mode train --dataset /path/to/your/dataset --epochs 100 --batch_size 8 --lr 0.001 --checkpoint checkpoint/screenshot_model --class_names /path/to/class_names.txt --model_type resnet18
```

- `--mode train`: Specifies that you want to train the model.
- `--dataset`: Path to the directory containing your dataset.
- `--class_names`: Path to the text file containing class names.
- `--epochs`: Number of epochs to train the model (optional).
- `--batch_size`: The batch size used during training (optional).
- `--lr`: The learning rate for the optimizer (optional).
- `--checkpoint`: Filename (without extension) to save the trained model's checkpoint (optional).
- `--plot_file`: Path to the file where the loss and accuracy plot will be saved (optional).
- `--model_type`: Select `resnet18` (default) or `mobilenetv3` to train a different model.
To train MobileNetV3-Large instead of ResNet18:

```
python main.py --mode train --dataset /path/to/your/dataset --epochs 100 --batch_size 8 --lr 0.001 --checkpoint checkpoint/mobilenet_model --class_names /path/to/class_names.txt --model_type mobilenetv3
```

- The model will be trained, and the loss and accuracy plot will be saved to `results/loss_accuracy_plot.png`.
- The trained model will be saved to the checkpoint file passed via `--checkpoint` (e.g. `checkpoint/screenshot_model`).
To use a trained model to classify an image, use the following command:
```
python main.py --mode inference --model /path/to/model_checkpoint --input /path/to/image.jpg --class_names /path/to/class_names.txt --model_type resnet18
```

- `--mode inference`: Specifies that you want to run inference.
- `--model`: Path to the model checkpoint file (trained model).
- `--input`: Path to the image file you want to classify.
- `--class_names`: Path to the text file containing class names.
- `--model_type`: Select `resnet18` (default) or `mobilenetv3`.
To run inference with a MobileNetV3 checkpoint:

```
python main.py --mode inference --model /path/to/mobilenet_checkpoint --input /path/to/image.jpg --class_names /path/to/class_names.txt --model_type mobilenetv3
```

- The predicted class for the input image is printed to the console.
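If you want to see what inference looks like in code rather than via the CLI, the sketch below shows one plausible flow. It assumes the checkpoint holds a plain `state_dict` for a ResNet18 with a resized final layer and ImageNet-style preprocessing; the real `main.py` may load and preprocess differently, and the paths are examples.

```python
# Hedged sketch of programmatic inference; checkpoint format and
# preprocessing are assumptions, not taken from main.py.
import torch
from PIL import Image
from torchvision import models, transforms

with open("class_names.txt") as f:
    class_names = [line.strip() for line in f if line.strip()]

model = models.resnet18()
model.fc = torch.nn.Linear(model.fc.in_features, len(class_names))
model.load_state_dict(torch.load("checkpoint/screenshot_model.pt", map_location="cpu"))
model.eval()

transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])
image = transform(Image.open("image.jpg").convert("RGB")).unsqueeze(0)

with torch.no_grad():
    predicted = model(image).argmax(dim=1).item()
print(class_names[predicted])
```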
To convert the trained PyTorch model to TFLite format for deployment on mobile devices, use the `convert.py` script.
For a ResNet18 checkpoint:

```
python convert.py --model checkpoint/screenshot_best.pt --output models/screenshot.tflite --num_classes 3 --model_type resnet
```

For a MobileNetV3 checkpoint:

```
python convert.py --model checkpoint/screenshot_mobile_best.pt --output models/screenshot_mobile.tflite --num_classes 3 --model_type mobilenet
```

- A TFLite model file will be generated and saved at the path given by `--output` (e.g. `models/screenshot.tflite`).
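One common route from PyTorch to TFLite goes through ONNX and a TensorFlow SavedModel. The sketch below shows that route under those assumptions; `convert.py` may implement the conversion differently, and the paths, input shape, and class count here are illustrative.

```python
# Hedged sketch of a typical PyTorch -> ONNX -> TensorFlow -> TFLite conversion.
# Requires the onnx, onnx-tf, and tensorflow packages; not copied from convert.py.
import onnx
import tensorflow as tf
import torch
from onnx_tf.backend import prepare
from torchvision import models

model = models.resnet18()
model.fc = torch.nn.Linear(model.fc.in_features, 3)  # 3 classes, as in the example commands
model.load_state_dict(torch.load("checkpoint/screenshot_best.pt", map_location="cpu"))
model.eval()

dummy_input = torch.randn(1, 3, 224, 224)            # assumed input shape
torch.onnx.export(model, dummy_input, "model.onnx")

prepare(onnx.load("model.onnx")).export_graph("tf_saved_model")

converter = tf.lite.TFLiteConverter.from_saved_model("tf_saved_model")
with open("models/screenshot.tflite", "wb") as f:
    f.write(converter.convert())
```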
This pipeline supports two models:
- ResNet18: A deep CNN model with residual connections, commonly used for image classification.
- MobileNetV3-Large: A lightweight, efficient model optimized for mobile devices.
Both models are pre-trained on ImageNet and fine-tuned by replacing the last fully connected layer to match the number of output classes in the dataset.
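In torchvision terms, that head replacement usually looks like the following. The sketch assumes torchvision's standard model definitions and a three-class dataset, and is not necessarily the exact code used here.

```python
# Sketch of replacing the classification head for both supported architectures.
from torch import nn
from torchvision import models

num_classes = 3  # example value; matches the three-class dataset above

resnet = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
resnet.fc = nn.Linear(resnet.fc.in_features, num_classes)

mobilenet = models.mobilenet_v3_large(weights=models.MobileNet_V3_Large_Weights.DEFAULT)
mobilenet.classifier[-1] = nn.Linear(mobilenet.classifier[-1].in_features, num_classes)
```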
- Optimizer: The model is trained using the Adam optimizer.
- Loss Function: The CrossEntropyLoss criterion is used for classification.
- Learning Rate Scheduler: `ReduceLROnPlateau` dynamically adjusts the learning rate when the validation loss plateaus.
- Early Stopping: Training stops if the validation loss does not improve for a specified number of epochs (`patience`).
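The sketch below shows how these pieces typically fit together in a training loop. Hyperparameter values, paths, and the `train_one_epoch`/`evaluate` helpers are illustrative placeholders, not the actual functions in `main.py`.

```python
# Hedged sketch: Adam + CrossEntropyLoss + ReduceLROnPlateau + early stopping.
import torch
from torchvision import models

model = models.resnet18()
model.fc = torch.nn.Linear(model.fc.in_features, 3)

criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", patience=3)

best_val_loss, patience, bad_epochs = float("inf"), 10, 0
for epoch in range(100):
    train_one_epoch(model, criterion, optimizer)   # hypothetical helper
    val_loss = evaluate(model, criterion)          # hypothetical helper
    scheduler.step(val_loss)                       # lower LR when validation loss plateaus
    if val_loss < best_val_loss:
        best_val_loss, bad_epochs = val_loss, 0
        torch.save(model.state_dict(), "checkpoint/screenshot_model.pt")
    else:
        bad_epochs += 1
        if bad_epochs >= patience:                 # early stopping
            break
```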
The dataset should consist of images organized into class directories. Each subdirectory represents a class, and the images inside each subdirectory belong to that class.
Class names should be provided in a text file (`class_names.txt`) where each line corresponds to a class name. The order of the class names in the file should match the order used during model training.