./initialize
This is a starter template for machine learning projects in PyTorch.
The core of this library lives over here.
Train a ResNet18 model on CIFAR10:
runml train configs/image_demo.yaml
Train an RL PPO model on BipedalWalker:
runml train configs/rl_demo.yaml
Launch a Slurm job (requires setting the SLURM_PARTITION environment variable):
runml launch configs/image_demo.yaml launcher.name=slurm launcher.num_nodes=1 launcher.gpus_per_node=1
A new project is broken down into five parts:
- Task: Defines the dataset and calls the model on a sample. Similar to a LightningModule.
- Model: Just a PyTorch nn.Module
- Trainer: Defines the main training loop, and optionally how to distribute training when using multiple GPUs
- Optimizer: Just a PyTorch optim.Optimizer
- LR Scheduler: Just a PyTorch optim.LRScheduler
Most projects should just have to implement the Task and Model, and use a default trainer, optimizer and learning rate scheduler. Running the training command above will log the location of each component.
New tasks, models, trainers, optimizers and learning rate schedulers are added using the same API, although each should implement different things. For example, to create a new model, make a new file under ml/models and add the following code:
from dataclasses import dataclass

from ml.core.config import conf_field
from ml.core.registry import register_model
from ml.models.base import BaseModel, BaseModelConfig


@dataclass
class NewModelConfig(BaseModelConfig):
    some_param: int = conf_field(10)


@register_model("new_model", NewModelConfig)
class NewModel(BaseModel[NewModelConfig]):
    def forward(self, x):
        return x + self.config.some_param
The framework automatically searches all of the files in ml/models to populate the model registry. In your config file, you can then reference the registered model using whatever key you chose:
model:
  name: new_model
Similar APIs exist for tasks, trainers, optimizers and learning rate schedulers. Try running the demo config to get a sense for how these pieces fit together.
This repository implements some features which I find useful when starting ML projects.
This template makes it easy to add custom C++ extensions to your PyTorch project. The demo includes a custom TorchScript-compatible nucleus sampling function, although more complex extensions are possible.
This template automatically runs black, isort, pylint and mypy against your repository as a GitHub action. You can enable push-blocking until these tests pass.
The training loop is pretty well optimized, but it is easy to do something in a task implementation that hurts performance. The template adds a lot of timers which make it easy to spot likely training slowdowns, or you can run the full profiler if you want a more detailed breakdown.
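As a rough illustration of the timer idea, the sketch below wraps sections of a loop in a context manager and records how long each one takes. The names `timed` and `timings` are hypothetical and are not the template's API; the loop bodies are stand-ins for data loading and a training step.

```python
import time
from collections import defaultdict
from contextlib import contextmanager
from typing import DefaultDict, Iterator, List

# Hypothetical store of elapsed seconds, keyed by a section label.
timings: DefaultDict[str, List[float]] = defaultdict(list)


@contextmanager
def timed(label: str) -> Iterator[None]:
    """Records the wall-clock time spent inside the `with` block."""
    start = time.perf_counter()
    try:
        yield
    finally:
        timings[label].append(time.perf_counter() - start)


for _ in range(3):
    with timed("get_batch"):
        time.sleep(0.01)  # Stand-in for loading a batch.
    with timed("train_step"):
        sum(i * i for i in range(1000))  # Stand-in for a forward/backward pass.

for label, values in timings.items():
    mean_ms = sum(values) / len(values) * 1e3
    print(f"{label}: mean {mean_ms:.2f} ms over {len(values)} calls")
```

If one label dominates the totals (for example, batch loading taking far longer than the training step), that is usually the first place to look before reaching for the full profiler.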
By default, models are run using torch.compile. To disable this behavior and use eager mode execution, set TORCH_COMPILE=0. If you try to launch a Slurm job with this flag set, it will show a warning.
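The TORCH_COMPILE switch described above can be sketched as a small environment check. This is an assumption about the behavior, not the template's code: `compile_enabled` and `maybe_compile` are hypothetical helpers, and treating any value other than "0" as enabled is a guess at the parsing rule. The compile function is passed in so the sketch runs without PyTorch installed.

```python
import os
from typing import Callable, Mapping, Optional, TypeVar

T = TypeVar("T")


def compile_enabled(env: Optional[Mapping[str, str]] = None) -> bool:
    """Returns False only when TORCH_COMPILE is set to "0" (assumed rule)."""
    env = os.environ if env is None else env
    return env.get("TORCH_COMPILE", "1") != "0"


def maybe_compile(
    model: T,
    compile_fn: Callable[[T], T],
    env: Optional[Mapping[str, str]] = None,
) -> T:
    """Applies compile_fn unless compilation is disabled, else returns the model as-is."""
    return compile_fn(model) if compile_enabled(env) else model


# Stand-in "compile" step so the example runs without PyTorch.
def tag(model: str) -> str:
    return f"compiled({model})"


print(maybe_compile("model", tag, env={}))                      # prints compiled(model)
print(maybe_compile("model", tag, env={"TORCH_COMPILE": "0"}))  # prints model
```

With PyTorch 2.x installed, `torch.compile` could be passed as `compile_fn` and an `nn.Module` as `model`; with `TORCH_COMPILE=0` the module would then be returned unchanged and run in eager mode.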