How to implement gradient accumulation in BrainPy
CloudyDory opened this issue · 5 comments
When training a large model, it is easy to get an out-of-memory error even when the batch size is small. In PyTorch we can overcome this issue with gradient accumulation: we split the batch into mini-batches, compute the loss and gradients for each mini-batch, accumulate them, and finally perform a single gradient-descent update.
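For reference, the PyTorch pattern I have in mind looks roughly like this (a minimal self-contained sketch; the toy model, data shapes, and mini-batch count are just illustrative):

```python
import torch
import torch.nn as nn

# Toy setup for illustration only.
model = nn.Linear(28, 10)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

full_x = torch.randn(64, 28)                  # one large batch
full_y = torch.randint(0, 10, (64,))
n_mini = 4                                    # split it into 4 mini-batches

optimizer.zero_grad()
for xs, ys in zip(full_x.chunk(n_mini), full_y.chunk(n_mini)):
  loss = criterion(model(xs), ys) / n_mini    # scale so the sum matches the full-batch loss
  loss.backward()                             # gradients accumulate in the .grad buffers
optimizer.step()                              # a single update for the whole batch
```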
I find the following training example in BrainPy's documentation, and I have two questions on modifying it to perform gradient accumulation:
- Is it correct to gather the `grads` output over several calls to the `grad_fun` (probably using `bm.for_loop`, since we need to jit the function?), take the average of the contents of `grads`, and then run `opt.update(grads)` once? (See the sketch after the example code below.)
- The `loss_fun()` reconstructs the `bp.DSTrainer()` and recompiles the model every time we call it. This seems rather inefficient. Is it possible to avoid recompiling the model during training?
```python
import brainpy as bp
import brainpy.math as bm

# define the model
model = ANNModel(28, 100, 10)

# define the loss function
def loss_fun(inputs, targets):
  runner = bp.DSTrainer(model, progress_bar=False, numpy_mon_after_run=False)
  predicts = runner.predict(inputs, reset_state=True)
  predicts = bm.max(predicts, axis=1)
  loss = bp.losses.cross_entropy_loss(predicts, targets)
  acc = bm.mean(predicts.argmax(-1) == targets)
  return loss, acc

# define the gradient function which computes the gradients of the trainable weights
grad_fun = bm.grad(loss_fun,
                   grad_vars=model.train_vars().unique(),
                   has_aux=True,
                   return_value=True)

# define the optimizer we need
opt = bp.optim.Adam(lr=1e-3, train_vars=model.train_vars().unique())

# training function
@bm.jit
def train(xs, ys):
  grads, loss, acc = grad_fun(xs, ys)
  opt.update(grads)
  return loss, acc
```
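Concretely, what I have in mind for the first question is something like the following (an untested sketch that reuses `model`, `grad_fun`, and `opt` from above; the mini-batch count and the slicing along a new leading axis are my assumptions):

```python
import jax

# Untested sketch: call grad_fun once per mini-batch, average the gradient
# pytrees, and apply a single optimizer update. The Python for-loop is
# unrolled at trace time because the whole function is jitted.
@bm.jit
def train_accumulated(xs, ys):
  # xs: [n_mini, n_time_steps, n_batch, features], ys: [n_mini, n_batch]
  n_mini = xs.shape[0]
  grads_sum = jax.tree_map(bm.zeros_like, model.train_vars().unique())
  losses, accs = [], []
  for i in range(n_mini):
    grads, loss, acc = grad_fun(xs[i], ys[i])
    grads_sum = jax.tree_map(bm.add, grads_sum, grads)   # accumulate gradients
    losses.append(loss)
    accs.append(acc)
  grads_mean = jax.tree_map(lambda g: g / n_mini, grads_sum)  # average over mini-batches
  opt.update(grads_mean)
  return sum(losses) / n_mini, sum(accs) / n_mini
```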
Yes, it is easy to implement the accumulation behavior like this:
```python
import jax
import numpy as np
import brainpy as bp
import brainpy.math as bm

# define the model
model = ANNModel(28, 100, 10)
opt = bp.optim.Adam(lr=1e-3, train_vars=model.train_vars().unique())

def step_run(i, inp):
  bp.share.save(i=i, t=i * bm.get_dt())
  out = model(inp)
  return out

# define the loss function
def loss_fun(inputs, targets):
  model.reset()
  indices = np.arange(inputs.shape[0])  # sequence length
  predicts = bm.for_loop(step_run, (indices, inputs))
  predicts = bm.max(predicts, axis=1)
  loss = bp.losses.cross_entropy_loss(predicts, targets)
  acc = bm.mean(predicts.argmax(-1) == targets)
  return loss, acc

def grad_fun(last_grad, input_target):
  inputs, targets = input_target
  grad_f = bm.grad(loss_fun,
                   grad_vars=model.train_vars().unique(),
                   has_aux=True,
                   return_value=True)
  grads, loss, acc = grad_f(inputs, targets)
  new_grad = jax.tree_map(bm.add, last_grad, grads)  # accumulate gradients
  return new_grad, (loss, acc)

# training function
@bm.jit
def train(xs, ys):
  # xs: [N_mini_batch, N_time_steps, N_batch, features]
  grads = jax.tree_map(bm.zeros_like, model.train_vars().unique())
  grads, (losses, accs) = bm.scan(grad_fun, grads, (xs, ys))
  loss = losses.mean()
  acc = accs.mean()
  opt.update(grads)
  return loss, acc
```
This is just a prototype; I have not actually run it. The key is to use `brainpy.math.scan`, which is similar to `jax.lax.scan` but works with BrainPy's Variable system.
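For example, a full batch could be split into mini-batches along a new leading axis and passed to `train()` in one call (a hypothetical usage sketch; the shapes, mini-batch count, and random data are only illustrative):

```python
# Hypothetical usage of the prototype above: one accumulated update per call.
n_mini, n_time, n_batch, n_feat = 4, 28, 32, 28
xs = bm.random.rand(n_mini, n_time, n_batch, n_feat)   # [N_mini_batch, N_time_steps, N_batch, features]
ys = bm.random.randint(0, 10, (n_mini, n_batch))        # integer class labels per mini-batch
loss, acc = train(xs, ys)
print(float(loss), float(acc))
```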
> The `loss_fun()` reconstructs the `bp.DSTrainer()` and recompiles the model every time we call it. This seems rather inefficient. Is it possible to avoid recompiling the model during training?
Actually, in this case `brainpy.DSTrainer` does not recompile, because the outermost function is `train()`, which has already been jitted.
Thank you very much for the suggestion! I will see if the code works.
The code works after taking the fixes in #602. Besides, we can also take the average of the gradients (instead of their sum) to make the training process behave like normal batch gradient descent, although simply summing the gradients also seems to work well (at least in my cases).
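For instance, the averaging can be done by dividing the accumulated gradients by the number of mini-batches right before the optimizer update (an untested variant of the `train()` function above):

```python
# Average instead of sum: scale the accumulated gradients by 1 / n_mini
# before the optimizer step.
@bm.jit
def train(xs, ys):
  # xs: [N_mini_batch, N_time_steps, N_batch, features]
  n_mini = xs.shape[0]
  grads = jax.tree_map(bm.zeros_like, model.train_vars().unique())
  grads, (losses, accs) = bm.scan(grad_fun, grads, (xs, ys))
  grads = jax.tree_map(lambda g: g / n_mini, grads)  # average over mini-batches
  opt.update(grads)
  return losses.mean(), accs.mean()
```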
Thanks for sharing the training experience.