This repository contains the code for the PicCollage ML intern call question.
dataset.py: Dataset class used to load images from a folder
model.py: The ScatterPlotCNN model
train.py: Training script
inference.py: Inference script used after training
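The folder-loading idea behind dataset.py can be illustrated with a minimal, dependency-free sketch. The class name ImageFolderIndex and the accepted extensions are hypothetical assumptions, not the actual contents of dataset.py. It only indexes image files in a folder; a real Dataset class would also decode each image (for example with PIL) and pair it with its target, which the README does not describe.

```python
import os
from typing import List


class ImageFolderIndex:
    """Hypothetical sketch: indexes the image files found in one folder.

    A real dataset.py class would also decode each image and return its
    target; both steps are omitted because the README does not describe them.
    """

    EXTENSIONS = (".png", ".jpg", ".jpeg")  # assumed image formats

    def __init__(self, folder: str) -> None:
        # Sort so the index -> file mapping is deterministic across runs.
        self.paths: List[str] = sorted(
            os.path.join(folder, name)
            for name in os.listdir(folder)
            if name.lower().endswith(self.EXTENSIONS)
        )

    def __len__(self) -> int:
        # Number of image files found in the folder.
        return len(self.paths)

    def __getitem__(self, idx: int) -> str:
        # Returns the file path; a real dataset would load the image here.
        return self.paths[idx]
```

Sorting the file list keeps indices stable between runs, which matters when a training script splits the data by index.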
First, put correlation_assignment into this directory.
Run python train.py to train the ScatterPlotCNN model.
optional arguments:
-h, --help show this help message and exit
--data_dir PATH TO THE DATASET
--save_dir PATH TO SAVE THE MODEL
--lr LEARNING RATE
--batch_size BATCH SIZE
--epochs NUMBER OF EPOCHS
--subset
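The train.py options listed above map naturally onto Python's argparse module. The sketch below is a hypothetical reconstruction, not the actual train.py: the types, every default value, and treating --subset as a boolean switch are all assumptions, since the help text gives none of them.

```python
import argparse


def build_train_parser() -> argparse.ArgumentParser:
    """Hypothetical sketch of the train.py CLI; types and defaults are assumptions."""
    parser = argparse.ArgumentParser(description="Train the ScatterPlotCNN model")
    # Paths: where to read the dataset and where to write the trained model (assumed defaults).
    parser.add_argument("--data_dir", type=str, default="correlation_assignment",
                        help="PATH TO THE DATASET")
    parser.add_argument("--save_dir", type=str, default="checkpoints",
                        help="PATH TO SAVE THE MODEL")
    # Optimisation hyperparameters (assumed defaults).
    parser.add_argument("--lr", type=float, default=1e-3, help="LEARNING RATE")
    parser.add_argument("--batch_size", type=int, default=32, help="BATCH SIZE")
    parser.add_argument("--epochs", type=int, default=10, help="NUMBER OF EPOCHS")
    # Assumed to be an on/off flag; the README does not describe --subset.
    parser.add_argument("--subset", action="store_true")
    return parser


if __name__ == "__main__":
    # Usage example with an explicit argument list instead of sys.argv.
    args = build_train_parser().parse_args(["--lr", "0.0005", "--epochs", "3", "--subset"])
    print(args.lr, args.batch_size, args.epochs, args.subset)
```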
Run python inference.py to run a trained ScatterPlotCNN model on an image.
optional arguments:
-h, --help show this help message and exit
--img_path IMG PATH
--model_path MODEL PATH
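For inference, both paths have to point at existing files before the model can be loaded. The sketch below is hypothetical: making the two options required and the helper check_paths are assumptions, not part of the documented inference.py. It shows one way to declare the options and fail early with a clear message.

```python
import argparse
import os
from typing import List


def build_inference_parser() -> argparse.ArgumentParser:
    """Hypothetical sketch of the inference.py CLI; required=True is an assumption."""
    parser = argparse.ArgumentParser(description="Run a trained ScatterPlotCNN on one image")
    parser.add_argument("--img_path", type=str, required=True, help="IMG PATH")
    parser.add_argument("--model_path", type=str, required=True, help="MODEL PATH")
    return parser


def check_paths(args: argparse.Namespace) -> List[str]:
    """Hypothetical helper: returns the given paths that are not existing files."""
    return [p for p in (args.img_path, args.model_path) if not os.path.isfile(p)]


if __name__ == "__main__":
    args = build_inference_parser().parse_args()
    missing = check_paths(args)
    if missing:
        # Exit with a readable error instead of a traceback from the model loader.
        raise SystemExit(f"File(s) not found: {', '.join(missing)}")
```

Checking paths up front keeps a typo in --model_path from surfacing as a less obvious error deep inside model loading.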