Modula is a deep learning framework designed for graceful scaling. The user defines a compound module (i.e. neural network) in Modula by arbitrarily composing atom and bond modules (e.g. linear layers and nonlinearities). Modula then automatically normalizes weight updates in the modular norm corresponding to this compound. This leads to automatic learning rate transfer across width, depth and possibly other architectural dimensions. Modula is built on top of PyTorch.
Modula is an experimental framework based on our research paper: Scalable Optimization in the Modular Norm. Use at your own risk.
Install modula via pip:
pip install modula
Next, let's download the Shakespeare data:
pip install datasets
python examples/data/shakespeare.py
And finally, let's train a GPT:
python examples/train-gpt.py
This runs on CPU and should get train loss: 1.65 and test loss: 1.80 after 2000 iterations.
The following figures are all for 10k steps of training GPT on OpenWebText.
The left two panels of this first figure show that the learning rate of modular-normalized-Adam transfers across width and depth. The right two panels show scaling performance when learning rate is tuned at the scale marked by the gray dotted line. Modular normalization fixes width and depth scaling for both SGD and Adam.
Our GPT implementation differs slightly from standard nanoGPT since we designed it to scale better. To check our implementation is not losing performance compared to standard nanoGPT, in the next figure we compare three setups:
- our direct reimplementation of nanoGPT with Adam (column 1);
- our custom GPT implementation with Adam and without modular normalization (column 2);
- our custom GPT implementation with Adam and with modular normalization (column 3).
Notice that our GPT implementation transfers learning rate better than nanoGPT, even without modular normalization (column 2 versus column 1). We also noticed other interesting behaviours: for example, our GPT implementation with modular normalization transfers learning rate quite well across context length.
Let's start by building an MLP and initializing its weights:
from modula.atom import Linear
from modula.bond import ReLU
mlp = Linear(10,10000) @ ReLU() @ Linear(10000, 1000)
weights = mlp.initialize(device="cpu")
Now let's fit this MLP to some random data:
from torch import randn, no_grad
data, target = randn(1000), randn(10)
for step in range(steps:=20):
    output = mlp(data, weights)
    loss = (target - output).square().mean()
    loss.backward()
    with no_grad():
        mlp.normalize(grad := weights.grad())     # normalize the gradient in the modular norm
        weights -= 0.1 * grad
        weights.zero_grad()
        mlp.regularize(weights, strength = 0.01)  # regularize the weight vector
    print(step, loss.item())
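To build intuition for why the update is normalized before it is applied, here is a minimal pure-Python sketch of normalized gradient descent. It is not Modula code: it uses the Euclidean norm as a stand-in for the modular norm, and the helper names `normalize` and `normalized_step` are hypothetical. The point it illustrates is that, after normalization, the step length is set by the learning rate alone, however large or small the raw gradient is.

```python
import math

def normalize(grad):
    """Rescale grad to unit Euclidean norm (a stand-in for the modular norm)."""
    norm = math.sqrt(sum(g * g for g in grad))
    if norm == 0.0:
        return list(grad)  # leave a zero gradient untouched
    return [g / norm for g in grad]

def normalized_step(weights, grad, lr):
    """Take one step of length lr in the direction of -grad."""
    unit = normalize(grad)
    return [w - lr * u for w, u in zip(weights, unit)]

weights = [1.0, -2.0]
# Two gradients with the same direction but very different magnitudes
small = normalized_step(weights, [3e-4, 4e-4], lr=0.1)
large = normalized_step(weights, [3e+4, 4e+4], lr=0.1)
print(small, large)  # the two updated weight lists agree up to rounding
```

In Modula the norm is chosen per architecture rather than fixed to the Euclidean norm, which is what the call to mlp.normalize in the loop above takes care of.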
Modula provides two useful abstractions: the Vector class and the Module class.
The Vector class is used to store the weights of the neural net. For instance, in the previous example the line weights = mlp.initialize(device="cpu") creates a Vector called weights. And grad = weights.grad() stores the gradient of weights as a Vector called grad. The point of all this is that you can do operations on Vector objects like:
weights -= 0.1 * grad
This allows you to write optimization algorithms without doing for loops over lists of tensors everywhere.
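As a rough illustration of the idea, the sketch below shows how a container class can hide the per-tensor loop behind arithmetic operators. `ToyVector` is a hypothetical pure-Python class written for this example; it is not Modula's Vector and makes no claim about how that class is implemented.

```python
class ToyVector:
    """A list of per-layer weight lists that supports whole-vector arithmetic."""

    def __init__(self, tensors):
        self.tensors = [list(t) for t in tensors]

    def __mul__(self, scalar):
        # Scale every entry of every layer
        return ToyVector([[scalar * x for x in t] for t in self.tensors])

    __rmul__ = __mul__  # allows 0.1 * vector as well as vector * 0.1

    def __isub__(self, other):
        # In-place update, applied layer by layer
        for t, o in zip(self.tensors, other.tensors):
            for i, x in enumerate(o):
                t[i] -= x
        return self

weights = ToyVector([[1.0, 2.0], [3.0]])
grad = ToyVector([[10.0, 10.0], [10.0]])
weights -= 0.1 * grad  # one line, no explicit loop over layers
print(weights.tensors)
```

The loops still exist, but they live in one place inside the class, so the optimizer code reads like the math.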
The meat of Modula is found in the Module class. A Module m must have six attributes. Two numbers:
m.mass: float # sets the proportion of feature learning m contributes to any supermodule
m.sensitivity: float # estimates the sensitivity of m to input perturbations
and four methods:
m.forward(x: Tensor, w: Vector) -> Tensor # maps an input and a weight vector to an output
m.initialize() -> Vector # randomly samples a weight vector
m.normalize(w: Vector) # scales vector w to have unit modular norm
m.regularize(w: Vector, strength: float) # regularizes vector w in-place
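The six attributes above can be written down as a structural interface. The sketch below does this with `typing.Protocol` and checks a toy class against it. Both `ModuleLike` and `ToyScale` are hypothetical names invented for this example, `Any` stands in for the Tensor and Vector types, and the toy `normalize` and `regularize` bodies are simple placeholders rather than Modula's behaviour.

```python
from typing import Any, Protocol, runtime_checkable

@runtime_checkable
class ModuleLike(Protocol):
    """Structural interface mirroring the two numbers and four methods above."""
    mass: float
    sensitivity: float

    def forward(self, x: Any, w: Any) -> Any: ...
    def initialize(self) -> Any: ...
    def normalize(self, w: Any) -> None: ...
    def regularize(self, w: Any, strength: float) -> None: ...

class ToyScale:
    """A hypothetical one-weight module computing forward(x, w) = w[0] * x."""
    mass = 1.0
    sensitivity = 1.0

    def forward(self, x, w):
        return w[0] * x

    def initialize(self):
        return [1.0]

    def normalize(self, w):
        # Rescale in place to unit absolute value (placeholder for a real norm)
        if w[0] != 0.0:
            w[0] /= abs(w[0])

    def regularize(self, w, strength):
        # Shrink the weight in place (placeholder weight decay)
        w[0] *= 1.0 - strength

m = ToyScale()
w = m.initialize()
print(isinstance(m, ModuleLike), m.forward(3.0, w))
```

The runtime check only verifies that the six names are present; it does not validate signatures or semantics.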
There are three kinds of modules in Modula:
- Atoms are modules that have weights and where the attributes are hand-declared, e.g. modula.atom.Linear;
- Bonds are modules without weights and where the attributes are hand-declared, e.g. modula.bond.GELU;
- Compounds are modules built by combining atoms and bonds; their attributes are inferred automatically, e.g. modula.compound.GPT.
We provide the following basic operations for building compounds:
M_2 @ M_1 # composes module M_2 with module M_1
(M_1, M_2) # acts as a tuple module in any further composition
M_1 + M_2 # returns the module sum
a * M # multiplies module M by scalar a
M ** L # returns the Lth iterate of module M, i.e. M @ M @ ... @ M
So, for example, the following residualize function takes a block module block and a depth L and returns a resnet with this block:
from modula.bond import Identity
residualize = lambda block, L : ((1 - 1/L) * Identity() + 1/L * block) ** L
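To make the meaning of these operators concrete, the following sketch re-creates them on plain Python functions, ignoring weights entirely. `ToyModule` is a hypothetical class written for this example and is not how Modula implements its modules. It only shows what function the residualize expression computes: L repetitions of x -> (1 - 1/L) * x + (1/L) * block(x).

```python
class ToyModule:
    """Wraps a plain function and overloads @, +, * and ** as described above."""

    def __init__(self, fn):
        self.fn = fn

    def __call__(self, x):
        return self.fn(x)

    def __matmul__(self, other):
        # self @ other: apply other first, then self
        return ToyModule(lambda x: self(other(x)))

    def __add__(self, other):
        # Module sum: feed the same input to both and add the outputs
        return ToyModule(lambda x: self(x) + other(x))

    def __rmul__(self, scalar):
        # a * M: scale the output of M by a
        return ToyModule(lambda x: scalar * self(x))

    def __pow__(self, depth):
        # M ** L: compose M with itself L times
        result = ToyModule(lambda x: x)
        for _ in range(depth):
            result = self @ result
        return result

identity = ToyModule(lambda x: x)
residualize = lambda block, L: ((1 - 1/L) * identity + 1/L * block) ** L

double = ToyModule(lambda x: 2 * x)
net = residualize(double, 4)
print(net(1.0))  # each of the 4 layers maps x to 0.75 * x + 0.25 * (2 * x)
```

Note that `1/L * block` works because Python falls back to the right operand's `__rmul__` when a float does not know how to multiply by a `ToyModule`.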
The point of all this is that you can build a complicated compound module m, and all module attributes will be automatically inferred. Then during training, you can call m.normalize on the Adam or SGD updates, and the learning rate will automatically transfer when scaling the architecture.
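As a rough picture of what automatic inference can look like, the sketch below propagates only the two numbers, mass and sensitivity, through the building operations. The combination rules in it are assumptions made for this example (masses add; sensitivities multiply under composition, add under a module sum, and scale by |a| under scalar multiplication), as are the attribute values given to the identity and the block. `ToyAttrs` is hypothetical; consult the paper and modula/abstract.py for the actual definitions.

```python
class ToyAttrs:
    """Tracks only (mass, sensitivity); carries no weights and computes nothing."""

    def __init__(self, mass, sensitivity):
        self.mass = mass
        self.sensitivity = sensitivity

    def __matmul__(self, other):
        # Assumed composition rule: masses add, sensitivities multiply
        return ToyAttrs(self.mass + other.mass, self.sensitivity * other.sensitivity)

    def __add__(self, other):
        # Assumed sum rule: masses add, sensitivities add
        return ToyAttrs(self.mass + other.mass, self.sensitivity + other.sensitivity)

    def __rmul__(self, scalar):
        # Assumed scaling rule: mass unchanged, sensitivity scaled by |scalar|
        return ToyAttrs(self.mass, abs(scalar) * self.sensitivity)

    def __pow__(self, depth):
        # Repeated composition, starting from a massless unit-sensitivity module
        result = ToyAttrs(0.0, 1.0)
        for _ in range(depth):
            result = self @ result
        return result

identity = ToyAttrs(mass=0.0, sensitivity=1.0)  # assumed attributes of an identity bond
block = ToyAttrs(mass=1.0, sensitivity=1.0)     # hypothetical block

residualize = lambda block, L: ((1 - 1/L) * identity + 1/L * block) ** L

net = residualize(block, 4)
print(net.mass, net.sensitivity)
```

Under these assumed rules, each residual layer has sensitivity (1 - 1/L) + 1/L = 1, so the sensitivity of the whole stack stays at 1 as the depth grows.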
.
├── assets # figures, logos and such
│ └── ...
├── examples
│ ├── hello-world.py # simple training loop
│ ├── gradient-accumulation.py # gradient accumulation for large batch training
│ └── multi-gpu.py # multi GPU training with torch.distributed
├── modula
│ ├── __init__.py
│ ├── abstract.py # basic definitions: composition & concatenation, etc.
│ ├── atom.py # modules with weights: linear, conv2d etc.
│ ├── bond.py # modules without weights: ReLU, FunctionalAttention, etc.
│ ├── compound.py # derived modules: GPT, ResNet, etc.
│ └── vector.py # class for storing weight vectors
├── paper # code associated with the arXiv paper
│ └── ...
├── LICENSE # MIT license
├── README.md # this file
└── setup.py # pip package stuff
If Modula is useful in your research, consider citing our paper:
@article{modula,
author = {Tim Large and Yang Liu and Minyoung Huh and Hyojin Bahng and Phillip Isola and Jeremy Bernstein},
title = {Scalable Optimization in the Modular Norm},
journal = {arXiv:2405.14813},
year = 2024
}
The design of Modula was influenced by μP, autobound, AGD and PyTorch itself.
Modula is released under an MIT license.