Pytorch project template for experiments and application (Description ...)
First, install dependencies
# Clone project
git clone https://github.com/ByeongGil-Jung/PytorchLightning-Project-Template.git
# Install project
cd PytorchLightning-Project-Template
pip install -e .
pip install -r requirements.txt
If you want to modify configurations, move below directory and modify it.
# Move configuration directory
cd config/hyperparameters
Next, navigate to any file and run it.
# Run module (example: mnist as your main contribution)
python main.py --model "fc" --data "mnist" --stage "fit" --tqdm_env "script"
This project is setup as a package which means you can now easily import any file into any other file like so:
from config.factory import HyperparameterFactory
from dataset.factory import DataModuleFactory
from domain.metadata import ModelMetadata
from model.factory import ModelFactory
from trainer.base import TrainerBase
# Arguments
model_name = "fc"
data_name = "mnist"
stage = "fit"
tqdm_env = "script"
model_metadata = ModelMetadata(model_name=model_name, information=None)
# Arguments controller
hyperparameter_factory = HyperparameterFactory.create(data_name=data_name, model_name=model_name)
datamodule_params = hyperparameter_factory.datamodule_params.to_dict()
trainer_params = hyperparameter_factory.trainer_params.to_dict()
model_params = hyperparameter_factory.model_params.to_dict()
# DataModule controller
datamodule = DataModuleFactory.create(data_name=data_name)
datamodule = datamodule(**datamodule_params)
datamodule.prepare_data()
datamodule.setup(stage=stage)
train_loader = datamodule.train_dataloader()
val_loader = datamodule.val_dataloader()
test_loader = datamodule.test_dataloader()
# Trainer controller
trainer = TrainerBase(model_metadata=model_metadata, **trainer_params)
# Model controller
model = ModelFactory.create(model_name=model_name, model_params=model_params)
# Find the optimal learning rate
if trainer.is_auto_lr_find:
trainer.lr_find(model=model, train_loader=train_loader, val_loader=val_loader)
# Training & Validation
if stage == "fit" or stage == "whole":
trainer.fit(model=model, train_dataloader=train_loader, val_dataloaders=val_loader)
# Testing
if stage == "test" or stage == "whole":
trainer.test(model=model, test_dataloaders=test_loader)
# Save figures
if trainer.is_auto_lr_find:
fig = trainer.lr_finder.plot(suggest=True)
fig.savefig(fname=model_metadata.model_file_metadata.optimal_lr_plot_path)
@article{Byeonggil Jung,
title={Title},
author={Team},
journal={Location},
year={Year}
}