Convert a BrainPy model to process batched inputs with `jax.vmap`
For a model written to process a single input, is it possible to convert it to process batched inputs simply by using `jax.vmap`? Or do we have to re-write the model to process batched data?

The code section looks like this:
```python
import functools

import jax
import numpy as np
import brainpy as bp
import brainpy.math as bm

# `model`, `cfg`, and `blank_img` are defined elsewhere in my setup.

# define the optimizer we need
opt = bp.optim.Adam(lr=1e-3, train_vars=model.train_vars().unique())


def step_run(i, x_single):
    '''
    Inputs:
        x_single: [height, width]
    '''
    # present the stimulus only within the stimulation window
    x = bm.where(bm.logical_and(cfg['stim_start_timepoint'] <= i, i < cfg['stim_end_timepoint']),
                 x_single, blank_img)
    out = model.step_run(i, x)  # [n_neuron]
    return out


def loss_fun(x_single, y_single):
    '''
    Inputs:
        x_single: [height, width]
        y_single: [1]
    '''
    model.reset_state()
    indices = np.arange(cfg['total_timepoint'])  # sequence length
    spike_out = bm.for_loop(functools.partial(step_run, x_single=x_single), indices)  # [length, n_neuron]
    frate_out = bm.sum(spike_out, axis=0) + 1.0e-6  # [n_neuron]
    predicts = bm.log(frate_out / bm.sum(frate_out)).unsqueeze(0)  # log-probabilities, [batch=1, n_neuron]
    # the negative sign is added manually because bp.losses.nll_loss does not do so
    loss = bp.losses.nll_loss(-predicts, y_single)  # scalar
    acc = bm.mean(predicts.argmax(-1) == y_single)  # scalar
    return loss, acc


grad_f = jax.vmap(bm.grad(loss_fun, grad_vars=model.train_vars().unique(),
                          has_aux=True, return_value=True))


@bm.jit
def train(x_batch, y_batch):
    '''
    Inputs:
        x_batch: [batch, height, width]
        y_batch: [batch, 1]
    '''
    grads, losses, accs = grad_f(x_batch, y_batch)  # PyTree of gradients, [batch], [batch]
    grads_mean = jax.tree_map(lambda x: bm.sum(x, axis=0), grads)
    loss = losses.mean()  # scalar
    acc = accs.mean()  # scalar
    opt.update(grads_mean)
    return loss, acc
```
It currently raises the following error:
```
UnexpectedTracerError: Encountered an unexpected tracer. A function transformed by JAX had a side effect, allowing for a reference to an intermediate value with type float32[50000] wrapped in a BatchTracer to escape the scope of the transformation.
JAX transformations require that functions explicitly return their outputs, and disallow saving intermediate values to global state.
```
I found a previous issue (#206) mentioning this. Is it still not possible to use `jax.vmap` with BrainPy models?
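For reference, here is a minimal standalone snippet (unrelated to my model, with made-up names) that reproduces this class of error, where a tracer escapes the transformed function through global state:

```python
import jax
import jax.numpy as jnp

leaked = []  # global state that the transformed function writes into

def f(x):
    leaked.append(x)  # side effect: the BatchTracer for `x` escapes here
    return x * 2

jax.vmap(f)(jnp.ones((4, 3)))  # runs fine, but `leaked[0]` now holds a tracer
leaked[0] + 1                  # raises UnexpectedTracerError when the leaked tracer is used
```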
Thanks for opening this great question. Actually, the object-oriented style in BrainPy does not support a general mapping transformation with `vmap` and `pmap`. But we can easily customize such a mapping for a specific problem. Here I will give you an example.
The key of BrainPy's `Variable` system is to find out all the variables used in an object and then transform the object into a pure function, so that it can be compiled by JAX's functional transformations. Existing BrainPy transformations such as `brainpy.math.jit` and `brainpy.math.scan` already hide these processes. However, for a new transformation, users can follow the same two steps themselves.
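As a minimal sketch of what this variable tracking looks like in practice (assuming a recent BrainPy version with automatic variable inference in `bm.jit`):

```python
import jax.numpy as jnp
import brainpy.math as bm

v = bm.Variable(bm.zeros(3))   # a stateful BrainPy Variable

@bm.jit
def f(x):
    v.value += x               # the write to `v` is tracked automatically
    return v.value.sum()

print(f(jnp.ones(3)))  # 3.0
print(f(jnp.ones(3)))  # 6.0 -- the Variable's state persists across calls
```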
In your case, you want to `vmap` the gradient function to get per-sample gradients. So the weights must not be batched (they are shared across the batch), while all states or variables should be batched, and the outputs should also be batched. Therefore, we can customize this transformation as:
```python
import jax
import brainpy.math as bm
from functools import wraps


def vmap_grad_fun(f, *inputs):
    # Step 1: find out all variables #
    # ------------------------------ #
    # shape evaluation, without spending any actual FLOPs
    all_vars, _ = bm.eval_shape(f, *inputs)
    # separate the variables into two groups: trainable weights and other states
    weights, states = all_vars.separate_by_instance(bm.TrainVar)

    # Step 2: transform the object into a function compatible with jax.vmap #
    # ---------------------------------------------------------------------- #
    @wraps(f)
    def new_fun(ws, ss, ins):
        # A. assign the shared weights and the per-batch states to the model
        for key in ws: weights[key].value = ws[key]
        for key in ss: states[key].value = ss[key]
        # B. run the function
        outputs = f(*ins)
        # C. return the outputs of this batch element
        return outputs

    ori_weights, ori_states = weights.dict_data(), states.dict_data()
    # replicate the states along a new leading batch dimension
    batch_size = inputs[0].shape[0]
    batched_states = jax.tree_map(lambda x: bm.repeat(bm.expand_dims(x, 0), batch_size, axis=0),
                                  ori_states)
    # batch over the states and inputs, while broadcasting the weights
    batched_outs = jax.vmap(new_fun, in_axes=(None, 0, 0), out_axes=0)(
        ori_weights, batched_states, inputs)
    del batched_states

    # restore the original weights and states
    for key in ori_weights: weights[key].value = ori_weights[key]
    for key in ori_states: states[key].value = ori_states[key]

    # Step 3: return the batched outputs
    return batched_outs
```
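Applied to your snippet, the helper could be used roughly like this (untested, with `loss_fun`, `model`, `x_batch`, and `y_batch` taken from your code):

```python
# build the single-sample gradient function as in your snippet ...
grad_fun = bm.grad(loss_fun,
                   grad_vars=model.train_vars().unique(),
                   has_aux=True,
                   return_value=True)
# ... then batch it with the customized transformation instead of a bare jax.vmap
grads, losses, accs = vmap_grad_fun(grad_fun, x_batch, y_batch)
```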
I hope this example can help you achieve the desired transformation.