A clean and scalable template to kickstart your deep learning project (inspired heavily by lightning-hydra-template)
Click on Use this template to initialize a new repository.
To launch Composer with Hydra, simply run
python run.py
or, for multi-GPU jobs:
composer -n <n> run.py
The default configs train a ResNet model on CIFAR-10. Switching to training ImageNet with ResNet-50 and adding wandb logging is as simple as overriding the defaults on the command line:
composer -n <n> run.py model=resnet.yaml dataset=streaming_imagenet.yaml logger=wandb.yaml
.
|-- configs
| |-- algorithms <-- Algorithm configs
| |-- callbacks <-- Callback configs
| |-- dataset <-- Dataset configs
| |-- experiment <-- Experiment configs
| |-- logger <-- Logger configs
| |-- model <-- Model configs
| |-- optimizer <-- Optimizer configs
| |-- scheduler <-- Scheduler configs
| |-- trainer <-- Trainer configs
| └-- config.yaml <-- Main config file
|-- logs <-- Logs generated by Hydra
|-- notebooks <-- Scratch Notebooks
|-- src
| |-- models <-- model source code
| └-- train.py <-- Composer entry point
|-- tests <-- Tests of any kind
|-- LICENSE
|-- README.md
|-- requirements.txt
└-- run.py
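For reference, here is a minimal sketch of what run.py might look like (the train function name is an assumption; see src/train.py for the actual Composer entry point):

import hydra
from omegaconf import DictConfig


@hydra.main(config_path="configs", config_name="config", version_base=None)
def main(cfg: DictConfig) -> None:
    from src.train import train  # hypothetical name for the entry point in src/train.py

    return train(cfg)


if __name__ == "__main__":
    main()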
Hydra configs work in a hierarchical fashion. The config that ultimately gets created is composed from config.yaml:
defaults:
  - _self_ # <-- configs defined in this yaml. They are overwritten by all following defaults
  - model: cifar_resnet.yaml
  - dataset: streaming_cifar10.yaml
  - logger: null
  - optimizer: default.yaml
  - scheduler: default.yaml
  - callbacks:
      - lr_monitor.yaml
      - speed_monitor.yaml
      - checkpoint.yaml
  - trainer: default.yaml
  - algorithms: null
  - experiment: null

seed: 42
name: hydra-test-run
Values in yamls loaded last take precedence over values loaded first. Because experiment is loaded last, everything configured in configs/experiment/
is the ultimate source of truth. Values can also be overridden or added on the command line; those take ultimate precedence.
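For example, an experiment config such as configs/experiment/resnet_mild.yaml might look like the following sketch (the specific overrides here are illustrative assumptions, not the template's actual recipe; the # @package _global_ directive lets an experiment file set keys anywhere in the config tree):

# @package _global_
# hypothetical sketch of configs/experiment/resnet_mild.yaml
defaults:
  - override /algorithms: mild.yaml # <-- assumed algorithms group config

name: resnet-mild
trainer:
  max_duration: 45ep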
✨ProTip✨ To see the final composed config passed to the trainer without running the job, add --cfg job.
# Keys defined in the config can be overridden on the command line
python run.py experiment=resnet_mild trainer.max_duration=45
# Keys NOT defined in the config can be added with `+`.
# This will cause an error if the key IS already defined.
python run.py experiment=resnet_mild +trainer.grad_clip_norm=1.5
# Adding `++` will override the key if defined, or append it if not present.
python run.py experiment=resnet_mild ++trainer.grad_clip_norm=1.5
Yahp also uses yaml to configure objects for training. The difference is that with Hydra, code doesn't have to be added to the composer/yahp registry to be used with Composer. Let's compare the yahp and hydra configs for adding algorithms in the yahp-based mild ResNet recipe.
Yahp:
algorithms:
  blurpool: # <-- names are specific keys which need to match a yahp initializer in the yahp registry
    blur_first: true
    min_channels: 16
    replace_convs: true
    replace_maxpools: true
  channels_last: {}
  ema:
    half_life: 100ba
    train_with_ema_weights: false
    update_interval: 20ba
  label_smoothing:
    smoothing: 0.08
  progressive_resizing:
    delay_fraction: 0.4
    finetune_fraction: 0.2
    initial_scale: 0.5
    mode: resize
    resize_targets: false
    size_increment: 4
Hydra:
algorithms:
  blurpool: # <-- names don't need to match anything, but can be referenced
    _target_: composer.algorithms.BlurPool # <-- objects are instantiated from the import path given in _target_
    blur_first: true # <-- any kwargs of composer.algorithms.BlurPool can be filled in
    min_channels: 16
    replace_convs: true
    replace_maxpools: true
  channels_last:
    _target_: composer.algorithms.ChannelsLast
  label_smoothing:
    _target_: composer.algorithms.LabelSmoothing
    smoothing: 0.1
  ema:
    _target_: composer.algorithms.EMA
    half_life: 100ba
    train_with_ema_weights: false
    update_interval: 20ba
  progressive_resizing:
    _target_: composer.algorithms.ProgressiveResizing
    delay_fraction: 0.4
    finetune_fraction: 0.2
    initial_scale: 0.5
    mode: resize
    resize_targets: false
    size_increment: 4
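Under the hood, Hydra turns each of these entries into a live object via hydra.utils.instantiate: the _target_ string is imported and called with the remaining keys as kwargs. A minimal sketch of how training code might build the algorithm list (loading the file standalone here is a hypothetical simplification; in this template the algorithms group is composed into the main config at launch):

from hydra.utils import instantiate
from omegaconf import OmegaConf

# Hypothetical path to an algorithms group config like the one above.
cfg = OmegaConf.load("configs/algorithms/mild.yaml")

# Each entry's _target_ is imported and called with the remaining keys as kwargs.
algorithms = [instantiate(algo) for algo in cfg.algorithms.values()]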
For example, the template's custom ResNet9 model lives in src/models/resnet9.py:

# adapted from https://raw.githubusercontent.com/matthias-wright/cifar10-resnet/master/model.py
# under the MIT license
from typing import Callable, Optional

import torch
import torch.nn as nn

# _ResidualBlock (the default residual block) is assumed to be defined
# alongside this class in src/models/resnet9.py.


class ResNet9(nn.Module):
    """A 9-layer residual network, excluding BatchNorms and activation functions, as
    described in this blog post:
    https://myrtle.ai/learn/how-to-train-your-resnet-4-architecture/

    Args:
        num_classes: number of classes for the final classifier layer
        residual_factory: a callable that returns a residual block;
            defaults to the original ResNet9 residual block, but can be
            used to specify a custom one
    """

    def __init__(self, num_classes: int, residual_factory: Optional[Callable] = None):
        super().__init__()
        residual_factory = residual_factory or _ResidualBlock
        self.body = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(num_features=64, momentum=0.9),
            nn.ReLU(),
            nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(num_features=128, momentum=0.9),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            residual_factory(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1),
            nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(num_features=256, momentum=0.9),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(num_features=256, momentum=0.9),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            residual_factory(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        self.fc = nn.Linear(in_features=1024, out_features=num_classes, bias=True)

    def forward(self, x: torch.Tensor):  # type: ignore
        out = self.body(x)
        # flatten all non-batch dimensions before the classifier
        out = out.reshape(-1, out.shape[1] * out.shape[2] * out.shape[3])
        out = self.fc(out)
        return out
Here we instantiate our custom PyTorch ResNet9 model and pass it as the module argument to the ComposerClassifier constructor. That's it!
_target_: composer.models.ComposerClassifier # <-- composer model wrapper
module:
  _target_: src.models.resnet9.ResNet9 # <-- local path to your code
  num_classes: 10
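When composed and instantiated, this config is equivalent to the following Python (a sketch of the instantiation Hydra performs, assuming ComposerClassifier takes the wrapped module as its first argument):

from composer.models import ComposerClassifier

from src.models.resnet9 import ResNet9

model = ComposerClassifier(module=ResNet9(num_classes=10))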