brainpy/BrainPy

How to implement gradient accumulation in BrainPy

CloudyDory opened this issue · 5 comments

When training a large model, it is easy to get an out-of-memory error even when the batch size is small. In PyTorch we can overcome this issue with gradient accumulation: we split the batch into mini-batches, compute the loss for each mini-batch, accumulate the gradients, and finally perform a single optimizer step.
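
For reference, here is a minimal sketch of this pattern in PyTorch (the model, loss function, and mini_batches below are placeholders, not taken from the example that follows):

import torch

# hypothetical model, optimizer, and loss; shown only to illustrate the pattern
model = torch.nn.Linear(28, 10)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
loss_fn = torch.nn.CrossEntropyLoss()

def train_step(mini_batches):
  optimizer.zero_grad()
  n = len(mini_batches)
  for inputs, targets in mini_batches:
    loss = loss_fn(model(inputs), targets)
    # backward() adds to the existing .grad buffers, so gradients accumulate;
    # dividing by n makes the update equivalent to one large-batch step
    (loss / n).backward()
  optimizer.step()  # a single parameter update using the accumulated gradients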

I found the following training example in BrainPy's documentation, and I have two questions about modifying it to perform gradient accumulation:

  1. Is it correct to gather the grads output over several calls to grad_fun (probably using bm.for_loop, since we need to jit the function), take the average of the accumulated grads, and then run opt.update(grads) once?
  2. loss_fun() constructs a new bp.DSTrainer() and recompiles the model every time it is called. This seems rather inefficient. Is it possible to avoid recompiling the model during training?
# define the model
model = ANNModel(28, 100, 10)

# define the loss function
def loss_fun(inputs, targets):
  runner = bp.DSTrainer(model, progress_bar=False, numpy_mon_after_run=False)
  predicts = runner.predict(inputs, reset_state=True)
  predicts = bm.max(predicts, axis=1)
  loss = bp.losses.cross_entropy_loss(predicts, targets)
  acc = bm.mean(predicts.argmax(-1) == targets)
  return loss, acc

# define the gradient function which computes the gradients of the trainable weights
grad_fun = bm.grad(loss_fun,
                   grad_vars=model.train_vars().unique(),
                   has_aux=True,
                   return_value=True)

# define the optimizer we need
opt = bp.optim.Adam(lr=1e-3, train_vars=model.train_vars().unique())

# training function
@bm.jit
def train(xs, ys):
  grads, loss, acc = grad_fun(xs, ys)
  opt.update(grads)
  return loss, acc

Yes, it is easy to implement the accumulation behavior like this:

import jax
import numpy as np

import brainpy as bp
import brainpy.math as bm

# define the model
model = ANNModel(28, 100, 10)
opt = bp.optim.Adam(lr=1e-3, train_vars=model.train_vars().unique())


def step_run(i, inp):
  bp.share.save(i=i, t=i * bm.get_dt())
  out = model(inp)
  return out


# define the loss function
def loss_fun(inputs, targets):
  model.reset()
  indices = np.arange(inputs.shape[0])  # sequence length
  predicts = bm.for_loop(step_run, (indices, inputs))

  predicts = bm.max(predicts, axis=0)  # max over time; bm.for_loop stacks its outputs as [n_time, n_batch, n_out]
  loss = bp.losses.cross_entropy_loss(predicts, targets)
  acc = bm.mean(predicts.argmax(-1) == targets)
  return loss, acc


def grad_fun(last_grad, input_target):
  inputs, targets = input_target
  grad_f = bm.grad(loss_fun,
                   grad_vars=model.train_vars().unique(),
                   has_aux=True,
                   return_value=True)
  grads, loss, acc = grad_f(inputs, targets)
  new_grad = jax.tree_map(bm.add, last_grad, grads)  # accumulate gradients
  return new_grad, (loss, acc)


# training function
@bm.jit
def train(xs, ys):
  # xs: [N_mini_batch, N_time_steps, N_batch, features]
  grads = jax.tree_map(bm.zeros_like, model.train_vars().unique())
  grads, (losses, acces) = bm.scan(grad_fun, grads, (xs, ys))
  loss = losses.mean()
  acc = acces.mean()
  opt.update(grads)
  return loss, acc

This is just a prototype; I have not actually run it. The key is to use brainpy.math.scan, which is similar to jax.lax.scan but can be applied to BrainPy's Variable system.
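
As a toy illustration of the carry-and-stack semantics (a hypothetical cumulative-sum example, assuming bm.scan follows the same (carry, x) -> (new_carry, y) convention as jax.lax.scan, which is how it is used above):

import brainpy.math as bm

def step(carry, x):
  new_carry = carry + x         # the carry is threaded through every iteration
  return new_carry, new_carry   # the second output is stacked along the scanned axis

xs = bm.arange(5.)
final, cumsum = bm.scan(step, 0., xs)
# final  -> 10.0 (the last carry)
# cumsum -> [0., 1., 3., 6., 10.] (one stacked output per element of xs)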

loss_fun() constructs a new bp.DSTrainer() and recompiles the model every time it is called. This seems rather inefficient. Is it possible to avoid recompiling the model during training?

Actually, in this case, brainpy.DSTrainer does not recompile, because the outermost function is train(), which has been jitted.
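
A small sketch of why this holds (not DSTrainer-specific; just a plain jitted function with a hypothetical Python-level counter): object construction inside a jitted function happens at trace time, and tracing occurs only once per input shape/dtype signature.

import brainpy.math as bm

n_traces = 0

@bm.jit
def f(x):
  global n_traces
  n_traces += 1        # a Python side effect: runs only while f is being traced
  return x * 2.

f(bm.ones(3))
f(bm.ones(3))          # same shapes and dtypes -> cached, no re-tracing
print(n_traces)        # 1: anything constructed inside f ran only once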

Thank you very much for the suggestion! I will see if the code works.

The code works after applying the fixes in #602. Besides, we can also take the average of the gradients (instead of the sum) to make the training process similar to normal batch gradient descent, but simply summing the gradients also seems to work well (at least in my case).
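
For example, the averaging can be done right before the update by dividing the accumulated gradients by the number of mini-batches (a hypothetical variant of the train() function above, reusing model, grad_fun, and opt):

import jax
import brainpy.math as bm

@bm.jit
def train(xs, ys):
  # xs: [N_mini_batch, N_time_steps, N_batch, features]
  grads = jax.tree_map(bm.zeros_like, model.train_vars().unique())
  grads, (losses, acces) = bm.scan(grad_fun, grads, (xs, ys))
  # divide the summed gradients by the number of mini-batches so the update
  # approximates an ordinary full-batch gradient step
  grads = jax.tree_map(lambda g: g / xs.shape[0], grads)
  opt.update(grads)
  return losses.mean(), acces.mean()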

Thanks for sharing the training experience.