
ScatterPlotCNN

This repository contains a solution to the PicCollage ML intern interview question.

Files

  1. dataset.py: Dataset class used to load images from a folder
  2. model.py: The ScatterPlotCNN model
  3. train.py: Training script
  4. inference.py: Inference script used after training
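dataset.py itself is not reproduced here. As a rough, stdlib-only sketch of what such a class does, the hypothetical ScatterPlotFolder below pairs image files in a folder with one numeric target each, read from a CSV file. The <id>.png naming, the CSV file, and its id and corr columns are assumptions, not details taken from this repository. A real dataset class would typically also decode each image into an array or tensor in __getitem__, which this sketch omits.

```python
import csv
import os


class ScatterPlotFolder:
    """Hypothetical stand-in for the Dataset class in dataset.py.

    Assumes images are stored as <image_dir>/<id>.png and that a CSV file
    with columns "id" and "corr" holds one target value per image.
    """

    def __init__(self, image_dir, label_csv):
        self.image_dir = image_dir
        self.samples = []
        with open(label_csv, newline="") as f:
            for row in csv.DictReader(f):
                path = os.path.join(image_dir, row["id"] + ".png")
                # Skip labels whose image file is missing on disk.
                if os.path.exists(path):
                    self.samples.append((path, float(row["corr"])))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, index):
        # Returns (image path, target); a real dataset would load the image here.
        return self.samples[index]
```

Keeping only paths in memory and loading pixels lazily in __getitem__ is the usual pattern for image folders, since the full dataset may not fit in RAM.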

Usage

First, place the correlation_assignment dataset folder in this directory.

Training

Run python train.py to train the ScatterPlotCNN model.

optional arguments:
  -h, --help            show this help message and exit
  --data_dir PATH TO THE DATASET
  --save_dir PATH TO SAVE THE MODEL
  --lr LEARNING RATE
  --batch_size BATCH SIZE
  --epochs NUMBER OF EPOCHS
  --subset
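The options above can be reproduced with argparse roughly as follows. This is a sketch, not the actual code in train.py: the defaults and types are hypothetical placeholders because the help text does not show them, and --subset is treated as an on/off flag only because no value appears after it in the listing.

```python
import argparse


def build_train_parser():
    """Sketch of a parser accepting the options listed for train.py.

    All defaults below are hypothetical placeholders, not the repository's values.
    """
    parser = argparse.ArgumentParser(description="Train the ScatterPlotCNN model")
    parser.add_argument("--data_dir", default="correlation_assignment",
                        help="path to the dataset")
    parser.add_argument("--save_dir", default="checkpoints",
                        help="path to save the model")
    parser.add_argument("--lr", type=float, default=1e-3, help="learning rate")
    parser.add_argument("--batch_size", type=int, default=32, help="batch size")
    parser.add_argument("--epochs", type=int, default=10, help="number of epochs")
    # No value is shown for --subset in the help output, so it is assumed to be a flag.
    parser.add_argument("--subset", action="store_true",
                        help="assumed: train on only part of the dataset")
    return parser
```

A typical invocation would then look like python train.py --data_dir correlation_assignment --save_dir checkpoints --epochs 20, where the values are examples only.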

Inference

Run python inference.py to apply a trained ScatterPlotCNN model to an image.

optional arguments:
  -h, --help            show this help message and exit
  --img_path IMG PATH
  --model_path MODEL PATH
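Similarly, the two inference options might be parsed and checked before any model is loaded, as in the sketch below. Making both options required and failing early on missing files are design choices of this sketch; inference.py may instead provide defaults. The helper name parse_inference_args is hypothetical.

```python
import argparse
import os


def build_inference_parser():
    """Sketch of a parser for the options listed for inference.py."""
    parser = argparse.ArgumentParser(description="Run a trained ScatterPlotCNN on one image")
    parser.add_argument("--img_path", required=True, help="path to the input image")
    parser.add_argument("--model_path", required=True, help="path to the saved model")
    return parser


def parse_inference_args(argv=None):
    """Parse arguments and fail early if either input file does not exist."""
    args = build_inference_parser().parse_args(argv)
    for path in (args.img_path, args.model_path):
        if not os.path.isfile(path):
            raise FileNotFoundError(path)
    return args
```

Checking the paths up front gives a clear error message before the comparatively slow step of loading model weights.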