This project provides a flexible template for PyTorch-based machine learning experiments. It includes configuration management, logging with Weights & Biases (wandb), hyperparameter optimization with Optuna, and a modular structure for easy customization and experimentation.
- `config.py`: Defines the `RunConfig` and `OptimizeConfig` classes for managing experiment configurations and optimization settings.
- `main.py`: The entry point of the project, handling command-line arguments and experiment execution.
- `model.py`: Contains the model architecture (currently an MLP).
- `util.py`: Utility functions for data loading, device selection, training, and analysis.
- `configs/run_template.yaml`: Template for run configuration.
- `configs/optimize_template.yaml`: Template for optimization configuration.
- `analyze.py`: Script for analyzing completed runs and optimizations, utilizing functions from `util.py`.
1. Clone the repository:

   ```shell
   git clone https://github.com/yourusername/pytorch_template.git
   cd pytorch_template
   ```

2. Install the required packages:

   ```shell
   # Use pip
   pip install torch wandb rich beaupy polars numpy optuna matplotlib scienceplots

   # Or use uv with requirements.txt (recommended)
   uv pip sync requirements.txt

   # Or use uv (fresh install)
   uv pip install -U torch wandb rich beaupy polars numpy optuna matplotlib scienceplots
   ```

3. (Optional) Set up a Weights & Biases account for experiment tracking.

4. Configure your experiment by modifying `configs/run_template.yaml` or creating a new YAML file based on it.

5. (Optional) Configure hyperparameter optimization by modifying `configs/optimize_template.yaml` or creating a new YAML file based on it.

6. Run the experiment:

   ```shell
   python main.py --run_config path/to/run_config.yaml [--optimize_config path/to/optimize_config.yaml]
   ```

   If `--optimize_config` is provided, the script performs hyperparameter optimization using Optuna.

7. Analyze the results:

   ```shell
   python analyze.py
   ```
- `project`: Project name for wandb logging
- `device`: Device to run on (e.g., `cpu`, `cuda:0`)
- `net`: Model class to use
- `optimizer`: Optimizer class
- `scheduler`: Learning rate scheduler class
- `epochs`: Number of training epochs
- `batch_size`: Batch size for training
- `seeds`: List of random seeds for multiple runs
- `net_config`: Model-specific configuration
- `optimizer_config`: Optimizer-specific configuration
- `scheduler_config`: Scheduler-specific configuration
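A run configuration covering these fields might look like the following sketch. The concrete values and class paths are illustrative assumptions; see `configs/run_template.yaml` for the actual template shipped with the project.

```yaml
project: pytorch_template
device: cuda:0
net: model.MLP
optimizer: torch.optim.AdamW
scheduler: torch.optim.lr_scheduler.CosineAnnealingLR
epochs: 50
batch_size: 128
seeds: [42, 43, 44]          # one full run per seed
net_config:
  nodes: 64                  # hidden width (example key)
  layers: 4                  # hidden depth (example key)
optimizer_config:
  lr: 1.0e-3
scheduler_config:
  T_max: 50
  eta_min: 1.0e-5
```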
- `study_name`: Name of the optimization study
- `trials`: Number of optimization trials
- `seed`: Random seed for optimization
- `metric`: Metric to optimize
- `direction`: Direction of optimization (`minimize` or `maximize`)
- `sampler`: Optuna sampler configuration
- `pruner`: (Optional) Pruner configuration
- `search_space`: Definition of the hyperparameter search space
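A matching optimization configuration could look like the sketch below. The `search_space` schema (the `type`/`min`/`max`/`log` keys) is an assumption for illustration; consult `configs/optimize_template.yaml` for the exact format the project expects.

```yaml
study_name: mlp_optimize
trials: 20
seed: 42
metric: val_loss
direction: minimize
sampler:
  name: optuna.samplers.TPESampler
search_space:
  optimizer_config:
    lr:
      type: float      # sample a float hyperparameter
      min: 1.0e-4
      max: 1.0e-2
      log: true        # sample on a log scale
```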
- Custom model: Modify or add models in `model.py`. Models should accept an `hparams` argument as a dictionary, with keys matching the `net_config` parameters in the run configuration YAML file.
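The model convention above can be sketched as follows. This is a minimal example, not the template's actual `MLP`; the `nodes` and `layers` keys are assumed names that would correspond to `net_config` entries in the run YAML.

```python
import torch
import torch.nn as nn

class MLP(nn.Module):
    """A small MLP whose shape is driven entirely by the `hparams` dict,
    mirroring the `net_config` section of the run configuration."""

    def __init__(self, hparams: dict, input_dim: int = 1, output_dim: int = 1):
        super().__init__()
        nodes = hparams["nodes"]    # width of each hidden layer (assumed key)
        layers = hparams["layers"]  # number of hidden layers (assumed key)
        blocks = [nn.Linear(input_dim, nodes), nn.GELU()]
        for _ in range(layers - 1):
            blocks += [nn.Linear(nodes, nodes), nn.GELU()]
        blocks.append(nn.Linear(nodes, output_dim))
        self.net = nn.Sequential(*blocks)

    def forward(self, x):
        return self.net(x)
```

A `net_config` of `{nodes: 64, layers: 4}` in YAML then arrives here as `MLP({"nodes": 64, "layers": 4})`.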
- Custom data: Modify the `load_data` function in `util.py`. The current example uses cosine regression. `load_data` should return train and validation datasets compatible with PyTorch's `DataLoader`.
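A minimal `load_data` in this spirit might look like the following. The signature and dataset sizes are illustrative assumptions; only the contract matters: return two datasets that a `DataLoader` can consume.

```python
import torch
from torch.utils.data import TensorDataset

def load_data(n_train: int = 1000, n_val: int = 200):
    """Toy cosine-regression data: x in [0, 2*pi], y = cos(x)."""
    def make(n: int) -> TensorDataset:
        x = torch.linspace(0, 2 * torch.pi, n).unsqueeze(1)  # shape (n, 1)
        return TensorDataset(x, torch.cos(x))
    return make(n_train), make(n_val)
```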
- Custom training: Customize the `Trainer` class in `util.py` by modifying the `step`, `train_epoch`, `val_epoch`, and `train` methods to suit your task. Ensure that `train` returns `val_loss` (or a custom metric) so hyperparameter optimization has a value to optimize.
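The `Trainer` pattern described above can be sketched as follows. Method names follow the README; everything else (constructor arguments, loss handling) is an illustrative assumption, not the template's actual implementation.

```python
import torch

class Trainer:
    """Minimal train/validate loop; `train` returns the final validation
    loss so Optuna has a scalar to optimize."""

    def __init__(self, model, optimizer, scheduler, criterion, device="cpu"):
        self.model = model.to(device)
        self.optimizer = optimizer
        self.scheduler = scheduler  # may be None
        self.criterion = criterion
        self.device = device

    def step(self, x):
        """Forward pass for one batch."""
        return self.model(x.to(self.device))

    def train_epoch(self, loader):
        self.model.train()
        total = 0.0
        for x, y in loader:
            self.optimizer.zero_grad()
            loss = self.criterion(self.step(x), y.to(self.device))
            loss.backward()
            self.optimizer.step()
            total += loss.item()
        return total / len(loader)

    def val_epoch(self, loader):
        self.model.eval()
        total = 0.0
        with torch.no_grad():
            for x, y in loader:
                total += self.criterion(self.step(x), y.to(self.device)).item()
        return total / len(loader)

    def train(self, dl_train, dl_val, epochs):
        val_loss = float("inf")
        for _ in range(epochs):
            self.train_epoch(dl_train)
            val_loss = self.val_epoch(dl_val)
            if self.scheduler is not None:
                self.scheduler.step()
        return val_loss  # the metric consumed by hyperparameter optimization
```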
- Configurable experiments using YAML files
- Integration with Weights & Biases for experiment tracking
- Hyperparameter optimization using Optuna
- Support for multiple random seeds
- Flexible model architecture (currently MLP)
- Device selection (CPU/CUDA)
- Learning rate scheduling
- Analysis tools for completed runs and optimizations
The `analyze.py` script utilizes functions from `util.py` to analyze completed runs and optimizations. Key functions include:
- `select_group`: Select a run group for analysis
- `select_seed`: Select a specific seed from a run group
- `select_device`: Choose a device for analysis
- `load_model`: Load a trained model and its configuration
- `load_study`: Load an Optuna study
- `load_best_model`: Load the best model from an optimization study
These functions are defined in `util.py` and used within `analyze.py`.
To use the analysis tools:
1. Run the `analyze.py` script:

   ```shell
   python analyze.py
   ```

2. Follow the prompts to select the project, run group, and seed (if applicable).

3. The script will load the selected model and perform basic analysis, such as calculating the validation loss.

4. You can extend the `main()` function in `analyze.py` to add custom analysis as needed, utilizing the utility functions from `util.py`.
Contributions are welcome! Please feel free to submit a Pull Request.
This project is provided as a template and is intended to be freely used, modified, and distributed. Users of this template are encouraged to choose a license that best suits their specific project needs.
For the template itself:
- You are free to use, modify, and distribute this template.
- No attribution is required, although it is appreciated.
- The template is provided "as is", without warranty of any kind.
When using this template for your own project, please remember to:
- Remove this license section or replace it with your chosen license.
- Ensure all dependencies and libraries used in your project comply with their respective licenses.
For more information on choosing a license, visit choosealicense.com.
PFL (Predicted Final Loss) Pruner
The PFL pruner is a custom pruner that speeds up hyperparameter search by stopping unpromising trials early. It maintains the top k trials by validation loss and prunes a trial when its predicted final loss is worse than the worst saved PFL.
- Maintains top k trials based on validation loss
- Predicts final loss using loss history
- Supports multiple random seeds
- Compatible with Optuna's pruning interface
In your `optimize_template.yaml`, configure the pruner under the `pruner` section:

```yaml
pruner:
  name: pruner.PFLPruner
  kwargs:
    n_startup_trials: 10  # Number of trials to run before pruning starts
    n_warmup_epochs: 10   # Number of epochs to run before pruning can occur
    top_k: 10             # Number of best trials to maintain
    target_epoch: 50      # Target epoch for final loss prediction
```

- `n_startup_trials`: Number of trials to run before pruning starts
- `n_warmup_epochs`: Number of epochs to wait before pruning can occur within each trial
- `top_k`: Number of best trials to maintain for comparison
- `target_epoch`: Target epoch number used for final loss prediction
- For the first `n_startup_trials` trials, all trials run without pruning
- Within each trial, no pruning occurs during the first `n_warmup_epochs` epochs
- After warmup:
  - The pruner maintains a list of the top k trials based on validation loss
  - For each trial, it predicts the final loss using the loss history
  - If a trial's predicted final loss is worse than all saved trials, it is pruned
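One plausible way to realize the "predicted final loss" idea is to fit a line to the loss history in log-space and extrapolate it to `target_epoch`. This is a sketch of the concept, not necessarily the exact model used inside `pruner.PFLPruner`.

```python
import numpy as np

def predict_final_loss(loss_history, target_epoch):
    """Extrapolate validation loss to `target_epoch` by fitting
    log(loss) = slope * epoch + intercept, i.e. assuming roughly
    exponential loss decay. Illustrative, not the library's code."""
    epochs = np.arange(1, len(loss_history) + 1)
    slope, intercept = np.polyfit(epochs, np.log(loss_history), 1)
    return float(np.exp(slope * target_epoch + intercept))
```

A pruning decision then reduces to comparing this prediction against the worst PFL among the saved top-k trials: if the prediction is worse, the trial is stopped.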