Real vs Simulated Marine Image Classification

Overview

This project classifies real and simulated marine environment images using a deep learning approach. The workflow includes data collection, preprocessing, training a ResNet-18 model, and evaluating its performance. Key metrics such as accuracy, precision, recall, and F1 score are analyzed. Misclassified images are also logged for further insights.


Dataset

  • Synthetic Images: The SimuShips Dataset, stored in the data/synthetic folder. Total 9471 images.
  • Real Images:
    • Marine Obstacle Detection Dataset (MODD): Used for training and validation. Stored in subfolders under data/real/{}/images. Dataset link. Total 4454 images.
    • MaSTr1325 Dataset: Used for testing only. Stored in data/test/real/MaSTr1325_images_512x384. Dataset link. Total 1325 images.
  • The data is split into training and validation sets with an 80-20 ratio.
  • Due to the imbalance between the number of synthetic and real images, I chose to use Weighted Cross-Entropy Loss (see the sketch below).
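
A minimal sketch of how the class weights and the 80-20 split could be set up in PyTorch. The inverse-frequency weighting shown here is one common choice, and `full_dataset` is a hypothetical combined dataset object, not necessarily the project's exact code:

import torch
from torch.utils.data import random_split

# Class counts taken from the README: 0 = synthetic (9471 SimuShips), 1 = real (4454 MODD).
class_counts = torch.tensor([9471.0, 4454.0])

# Inverse-frequency weights so the minority (real) class contributes more to the loss.
class_weights = class_counts.sum() / (2.0 * class_counts)
criterion = torch.nn.CrossEntropyLoss(weight=class_weights)

# 80-20 train/validation split (`full_dataset` is a hypothetical combined dataset).
train_size = int(0.8 * len(full_dataset))
train_set, val_set = random_split(full_dataset, [train_size, len(full_dataset) - train_size])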

Approach

Data Preprocessing

  1. Images are resized to fit within 224x224 pixels while maintaining the original aspect ratio.
  2. Images are then padded to a uniform 224x224 (see the transform sketch after this list).
  3. Normalization is applied using ImageNet statistics (mean = [0.485, 0.456, 0.406], std = [0.229, 0.224, 0.225]).
  4. Class labels:
    • 0: Synthetic images
    • 1: Real images
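
A minimal sketch of the preprocessing steps above using torchvision; the ResizeWithPadding helper is an illustrative implementation of "resize then pad", not necessarily the project's exact transform:

import torchvision.transforms as T
import torchvision.transforms.functional as TF

class ResizeWithPadding:
    """Resize so the longer side is `size` pixels, then pad to a square size x size."""
    def __init__(self, size=224, fill=0):
        self.size = size
        self.fill = fill

    def __call__(self, img):
        w, h = img.size                       # PIL images report (width, height)
        scale = self.size / max(w, h)
        new_w, new_h = int(round(w * scale)), int(round(h * scale))
        img = TF.resize(img, (new_h, new_w))  # functional resize expects (height, width)
        pad_left = (self.size - new_w) // 2
        pad_top = (self.size - new_h) // 2
        pad_right = self.size - new_w - pad_left
        pad_bottom = self.size - new_h - pad_top
        return TF.pad(img, [pad_left, pad_top, pad_right, pad_bottom], fill=self.fill)

transform = T.Compose([
    ResizeWithPadding(224),
    T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),  # ImageNet stats
])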

Model Architecture

  • Model: Pre-trained ResNet-18 from torchvision.

    • Pre-trained on the ImageNet-1k dataset. This dataset includes classes related to ships and boats, making it a suitable base model for this task.
      • Relevant ImageNet classes include container ships, speedboats, and ocean liners.
        • 510: 'container ship, containership, container vessel'
        • 628: 'liner, ocean liner'
        • 724: 'pirate, pirate ship'
        • 814: 'speedboat'
        • and other watercraft-related classes
      • Reference for ResNet-18: PyTorch ResNet-18 Documentation.
  • Modifications:

    • The final fully connected layer is replaced to classify between two classes: real vs. synthetic (see the sketch below).
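
A minimal sketch of the model setup; the `weights` argument assumes a recent torchvision release (older versions use `pretrained=True` instead):

import torch.nn as nn
from torchvision import models

# Load ResNet-18 with ImageNet-1k weights and swap the 1000-class head for a 2-class one.
model = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
model.fc = nn.Linear(model.fc.in_features, 2)  # 0: synthetic, 1: real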

Training

  • Loss Function: Weighted Cross-Entropy Loss to address class imbalance.
  • Optimizer: Adam with a learning rate of 0.001.
  • Early Stopping: Training stops if validation accuracy does not improve for a set number of epochs (see the sketch after this list).
  • Configurable parameters:
    • --batch_size: Batch size (default: 32)
    • --epochs: Number of training epochs (default: 10)
    • --learning_rate: Learning rate for the optimizer (default: 0.001)
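
A sketch of how the optimizer and early stopping could fit together under these defaults; `train_one_epoch`, `evaluate_accuracy`, the patience value, and the checkpoint filename are illustrative assumptions, and `criterion` is the weighted loss from the Dataset section:

import torch

optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

best_val_acc = 0.0
epochs_without_improvement = 0
patience = 3  # epochs to wait for improvement; the actual value is an assumption

for epoch in range(10):  # matches the --epochs default
    train_one_epoch(model, train_loader, criterion, optimizer)  # hypothetical helper
    val_acc = evaluate_accuracy(model, val_loader)               # hypothetical helper
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        epochs_without_improvement = 0
        torch.save(model.state_dict(), "models/best_model.pth")  # checkpoint the best model
    else:
        epochs_without_improvement += 1
        if epochs_without_improvement >= patience:
            break  # early stopping: no validation improvement for `patience` epochs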

Evaluation

  • Metrics:
    • Accuracy
    • Precision
    • Recall
    • F1 Score
  • Logging:
    • Misclassified image paths are saved for further analysis (see the sketch below).
    • Training and validation metrics are also stored in logs/train_{date}_{time}/training.txt during training.
    • All results are visualized using Weights & Biases (wandb).
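
A sketch of how the metrics and misclassified-path collection could be computed; it assumes scikit-learn is installed and that the DataLoader yields (image, label, path) tuples, which is an assumption about the project's Dataset class:

import torch
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score

def evaluate(model, loader, device="cpu"):
    """Return accuracy/precision/recall/F1 and the paths of misclassified images."""
    model.eval()
    y_true, y_pred, misclassified = [], [], []
    with torch.no_grad():
        for images, labels, paths in loader:  # assumes the loader also returns file paths
            preds = model(images.to(device)).argmax(dim=1).cpu()
            y_true += labels.tolist()
            y_pred += preds.tolist()
            misclassified += [p for p, t, pr in zip(paths, labels, preds) if t != pr]
    metrics = {
        "accuracy": accuracy_score(y_true, y_pred),
        "precision": precision_score(y_true, y_pred),
        "recall": recall_score(y_true, y_pred),
        "f1": f1_score(y_true, y_pred),
    }
    return metrics, misclassified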

Results

A training experiment was run on 2024-12-02. The best model is saved at models/best_model_20241202_021403.pth.

  1. Final Metrics:

    • Metrics on both test and validation data:
      • Accuracy: 100%
      • Precision: 1.00
      • Recall: 1.00
      • F1 Score: 1.00
  2. Visualization:

Wandb plots

  • Training and validation loss trends.
  • Misclassified image logs for every training epoch are saved in logs/train_{date}_{time}/failed_val_epoch{}.txt.
  • Misclassified image logs from testing are saved in logs/failed_test_model_{model_name}.txt.

Usage

Prerequisites

  • Python 3.9
  • Docker or Conda for environment setup

Setup

Using Docker

  1. Build the Docker image:
    docker build -t real-vs-simulated .
  2. Run the container:
    docker run -it real-vs-simulated

Using Conda

  1. Create the environment:
    conda env create -f conda.yaml
  2. Activate the environment:
    conda activate dev

Running the Project

Training

Run the training script from the src folder:

python train.py --dataset_path ../data --batch_size 32 --epochs 10 --learning_rate 0.001

Evaluation

Run the evaluation script:

python test.py --dataset_path ../data/test --model_path ../models/best_model_20241202_021403.pth

Analysis and Visualization

Training Metrics

  • wandb is used for real-time monitoring and visualization of the following (see the logging sketch below):
    • Loss
    • Accuracy
    • Precision
    • Recall
    • F1 Score
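
A sketch of the per-epoch wandb logging; the project name and metric keys are assumptions, `metrics` is the dict from the evaluation sketch, and `train_loss`/`epoch` come from the training loop:

import wandb

wandb.init(project="real-vs-simulated")  # project name is an assumption

# Log once per epoch so wandb plots each metric over time.
wandb.log({
    "epoch": epoch,
    "train/loss": train_loss,
    "val/accuracy": metrics["accuracy"],
    "val/precision": metrics["precision"],
    "val/recall": metrics["recall"],
    "val/f1": metrics["f1"],
})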

Misclassified Images

  • Paths of misclassified images are saved in logs/failed_test_model_<model_name>.txt (see the sketch below).
  • Observation:
    • The model may be overfitting, potentially learning camera-specific intrinsics/extrinsics rather than generalizable features. More reliable performance could likely be achieved with datasets in which real and simulated images share similar camera properties.
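
A sketch of how those paths could be written out using the naming pattern above (`model_name` and `misclassified` are assumed to come from the evaluation step):

from pathlib import Path

log_path = Path("logs") / f"failed_test_model_{model_name}.txt"
log_path.parent.mkdir(parents=True, exist_ok=True)    # create logs/ if missing
log_path.write_text("\n".join(misclassified) + "\n")  # one misclassified path per line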