Gradient accumulation generates JAX leaking error in BrainPy 2.5.0
Closed this issue · 3 comments
After upgrading to BrainPy 2.5.0, I found that training with gradient accumulation no longer works.
We can use logistic regression as an example:
import numpy as np
import jax
import brainpy as bp
import brainpy.math as bm
bm.clear_buffer_memory()
bm.set(float_=bm.float32)
bm.set_platform('cpu')
#%% Network definition
class Network(bp.DynSysGroup):
    def __init__(self):
        super().__init__()
        self.weight = bm.TrainVar(bm.random.randn(2))

    def update(self, data):
        out = bm.sum(self.weight * data)
        return out
#%% Create network and fake data
print('Creating network... ')
with bm.training_environment():
    model = Network()
    optimizer = bp.optim.Adam(lr=1e-1, train_vars=model.train_vars().unique())
print('Creating data... ')
train_data = np.concatenate([np.random.randn(100, 2) + np.array([[-1, -1]]),
                             np.random.randn(100, 2) + np.array([[ 1,  1]])], axis=0)  # [batch, 2]
train_label = bm.concatenate([bm.zeros(100, dtype=bm.int32),
                              bm.ones(100, dtype=bm.int32)], axis=0)  # [batch]
#%% Training functions
def loss_fun(x_single, y_single):
    '''
    Inputs:
        x_single: [feature]
        y_single: scalar
    '''
    predict = model.update(x_single)  # scalar
    loss = bp.losses.binary_logistic_loss(predict, y_single)  # scalar
    acc = bm.mean(bm.int32(predict >= 0.0) == y_single)  # scalar
    return loss, acc
grad_f = bm.grad(loss_fun, grad_vars=model.train_vars().unique(), has_aux=True, return_value=True)
def grad_fun(last_grad, x_y_single):
    '''
    Inputs:
        last_grad: PyTree of gradients of each trainable parameter.
        x_y_single: tuple of ([feature], scalar), a single training sample.
    '''
    x_single, y_single = x_y_single  # [feature], scalar
    grads, loss, acc = grad_f(x_single, y_single)  # PyTree of gradients, scalar, scalar
    new_grad = jax.tree_map(lambda x, y: bm.TrainVar(bm.add(x, y)), last_grad, grads, is_leaf=bm.is_bp_array)  # accumulate gradients
    return new_grad, (loss, acc)
@bm.jit
def train(x_batch, y_batch):
    '''
    Inputs:
        x_batch: [batch, feature]
        y_batch: [batch]
    '''
    train_vars = model.train_vars().unique()
    # Gradient accumulation
    grads = jax.tree_map(bm.zeros_like, train_vars)
    grads, (losses, acces) = bm.scan(grad_fun, grads, (x_batch, y_batch))  # PyTree of gradients, [batch], [batch]
    optimizer.update(grads)
    loss = losses.mean()  # scalar
    acc = acces.mean()  # scalar
    return loss, acc
#%% Start training
print('Start training...')
train_epochs = 15
train_loss = bm.zeros(train_epochs, dtype=bm.float_)
train_acc = bm.zeros(train_epochs, dtype=bm.float_)
for e in range(train_epochs):
    train_loss[e], train_acc[e] = train(train_data, train_label)
    print("Epoch {}, train_loss={:.3f}, train_acc={:.2f}%".format(e, train_loss[e], train_acc[e] * 100.0))
print('Done!')
On BrainPy 2.4.6.post5, the above code trains normally, but on BrainPy 2.5.0 it raises the following error:
UnexpectedTracerError: Encountered an unexpected tracer. A function transformed by JAX had a side effect, allowing for a reference to an intermediate value with type float32[2] wrapped in a DynamicJaxprTracer to escape the scope of the transformation.
JAX transformations require that functions explicitly return their outputs, and disallow saving intermediate values to global state.
The function being traced when the value leaked was fun2scan at /home/xxx/miniconda3/envs/brainpy2.5/lib/python3.11/site-packages/brainpy/_src/math/object_transform/controls.py:929 traced for scan.
------------------------------
The leaked intermediate value was created on line /home/xxx/project/test_train_bug.py:62:53 (grad_fun.<locals>.<lambda>).
------------------------------
When the value was created, the final 5 stack frames (most recent last) excluding JAX-internal frames were:
------------------------------
/tmp/ipykernel_571020/323626455.py:1 (<module>)
/home/xxx/project/test_train_bug.py:90:34 (<module>)
/home/xxx/project/test_train_bug.py:76:29 (train)
/home/xxx/project/test_train_bug.py:62:15 (grad_fun)
/home/xxx/project/test_train_bug.py:62:53 (grad_fun.<locals>.<lambda>)
------------------------------
To catch the leak earlier, try setting the environment variable JAX_CHECK_TRACER_LEAKS or using the `jax.checking_leaks` context manager.
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.UnexpectedTracerError
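The general pattern this message describes can be reproduced in a few lines of plain JAX; the snippet below is an illustration only (it is not part of the script above and uses made-up names):

import jax
import jax.numpy as jnp

stash = []  # Python-level state that outlives the trace

@jax.jit
def f(x):
    y = x * 2.0
    stash.append(y)   # side effect: the tracer for `y` escapes the transformation
    return y.sum()

f(jnp.ones(3))        # runs, but stash[0] now holds a leaked tracer
# stash[0] + 1.0      # touching it later raises jax.errors.UnexpectedTracerError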
Environment (BrainPy 2.5.0):
- Ubuntu 22.04
- Python 3.11
- brainpy 2.5.0
- brainpylib 0.2.6 (with cuda12 support)
- jax and jaxlib 0.4.24 (with cuda support)
- taichi 1.7.0
Environment (BrainPy 2.4.6.post5):
- Ubuntu 22.04
- Python 3.11
- brainpy 2.4.6.post5
- brainpylib 0.2.4 (with cuda12 support)
- jax and jaxlib 0.4.23 (with cuda support)
- taichi 1.7.0
Thanks for the report!
The problem can be fixed by changing the line
new_grad = jax.tree_map(lambda x, y: bm.TrainVar(bm.add(x, y)), last_grad, grads, is_leaf=bm.is_bp_array) # accumulate gradients
into
new_grad = jax.tree_map(bm.add, last_grad, grads) # accumulate gradients
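For completeness, grad_fun with only that line changed (everything else in the script stays as posted):

def grad_fun(last_grad, x_y_single):
    '''
    Inputs:
        last_grad: PyTree of gradients of each trainable parameter.
        x_y_single: tuple of ([feature], scalar), a single training sample.
    '''
    x_single, y_single = x_y_single  # [feature], scalar
    grads, loss, acc = grad_f(x_single, y_single)  # PyTree of gradients, scalar, scalar
    new_grad = jax.tree_map(bm.add, last_grad, grads)  # accumulate gradients as plain arrays
    return new_grad, (loss, acc)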
Please let me know whether the changes fix the error.
Thank you very much for the reply, it fixes the error. Could you briefly explain why it happens?
The cause of this error is somewhat unintuitive; it comes down to how variable tracing works in BrainPy. I do not encourage you to dig into this error. 😂😂
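(A plausible reading, based only on the error message above rather than on BrainPy internals: inside bm.scan every value is a JAX tracer, and wrapping the accumulated gradient in a new bm.TrainVar stores that tracer in a stateful Variable object that outlives the trace, which is exactly the "saving intermediate values to global state" pattern JAX forbids. Accumulating with plain bm.add keeps the gradients as ordinary arrays threaded through scan's carry, so nothing escapes.)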