Questions about backpropagation through delay variables
CloudyDory opened this issue · 6 comments
In #626 we mentioned that the `rotation` update method of delay variables does not support autograd. However, I have tested it in training and found that the parameters can be trained normally. Is there a misunderstanding in that issue?
import functools
import numpy as np
import jax
import brainpy as bp
import brainpy.math as bm
# Global simulation / BrainPy settings for this script.
snn_latency = 20  # number of time steps simulated per sample (the length of the for_loop in loss_fun)
dt = 1.0  # simulation time step, passed to bm.set_dt below
bm.clear_buffer_memory()  # NOTE(review): presumably releases buffers left from earlier runs in this process — confirm
bm.set(float_=bm.float32)  # default floating-point dtype for BrainPy arrays
bm.set_platform('cpu')  # run on CPU
bm.set_dt(dt)
#%% Network definition
class Network(bp.DynSysGroup):
    """Two LIF neurons driven by a dense layer and read out through a delay line.

    On every step the input is projected by ``weight``/``bias`` into the
    neurons. The emitted spikes are pushed into a ``LengthDelay`` buffer that
    uses the ``'rotation'`` update method. The spikes produced ``delay_len``
    steps earlier are returned.
    """

    def __init__(self):
        super().__init__()
        self.neu = bp.dyn.Lif(size=2, V_rest=0.0, V_reset=0.0, V_th=1.0,
                              spk_fun=bm.surrogate.Arctan())
        self.delay_len = 2
        self.spike_buffer = bm.LengthDelay(self.neu.spike,
                                           delay_len=self.delay_len,
                                           update_method='rotation')
        self.weight = bm.TrainVar(bm.random.randn(2, 2))
        self.bias = bm.TrainVar(bm.random.randn(2))

    def reset_state(self, *args):
        """Reset the neurons and refill the delay buffer (extra args are ignored)."""
        self.neu.reset_state(self.neu.mode)
        self.spike_buffer.reset(self.neu.spike, delay_len=self.delay_len)

    def update(self, data):
        """Advance one step and return the spikes emitted ``delay_len`` steps ago."""
        drive = self.weight @ data + self.bias
        fired = self.neu(drive)  # [batch, 2]
        self.spike_buffer.update(fired)
        return self.spike_buffer.retrieve(self.delay_len)  # [batch, 2]
#%% Create network and fake data
print('Creating network... ')
with bm.training_environment():
    # Model and optimizer are built inside BrainPy's training environment.
    model = Network()
    optimizer = bp.optim.Adam(lr=1.0, train_vars=model.train_vars().unique())

print('Creating data... ')
# Two standard-normal clusters of 100 points each, shifted to (-1, -1) and
# (1, 1), labelled 0 and 1 respectively.
train_data = np.concatenate(
    [np.random.randn(100, 2) + offset
     for offset in (np.array([[-1, -1]]), np.array([[1, 1]]))],
    axis=0)  # [batch, 2]
train_label = bm.concatenate(
    [bm.zeros(100, dtype=bm.int32), bm.ones(100, dtype=bm.int32)],
    axis=0)  # [batch]
#%% Training functions
def loss_fun(x_single, y_single):
    """Simulate the network on one sample for ``snn_latency`` steps and score it.

    Inputs:
      x_single: [feature]
      y_single: [1]

    Returns:
      (loss, acc): scalar NLL of the rate-coded prediction, and the accuracy
      for this sample.
    """
    indices = np.arange(snn_latency)  # one index per simulated time step
    model.reset_state()
    spike = bm.for_loop(functools.partial(model.step_run, data=x_single), indices)  # [length, batch=1, 2], float32
    firerate = bm.sum(spike, axis=0) + 1.0e-6  # [batch=1, 2]; epsilon keeps log() finite for silent neurons
    # Normalise over the class axis of each row (not over the whole array) so every
    # row is a proper distribution even if the model is ever run with batch > 1.
    predict = bm.log(firerate / bm.sum(firerate, axis=-1, keepdims=True))  # log-probabilities, [batch=1, n_class]
    # NOTE(review): -predict is passed on the assumption that brainpy's nll_loss
    # gathers input[target] without negating it — confirm against the brainpy docs.
    loss = bp.losses.nll_loss(-predict, y_single)  # scalar
    acc = bm.mean(predict.argmax(-1) == y_single)  # scalar
    return loss, acc


# Gradient of loss_fun w.r.t. the trainable variables; with has_aux and
# return_value the call returns (grads, loss, acc).
grad_f = bm.grad(loss_fun, grad_vars=model.train_vars().unique(), has_aux=True, return_value=True)
def grad_fun(last_grad, x_y_single):
    """Scan body: add one sample's gradients to the running gradient sum.

    Inputs:
      last_grad: PyTree of gradients of each trainable parameter.
      x_y_single: tuple of ([feature], scalar), a single training sample.

    Returns:
      (new_grad, (loss, acc)): the updated gradient sum (the scan carry) and
      this sample's metrics (stacked along the batch by the scan).
    """
    x_single, y_single = x_y_single  # [feature], scalar
    grads, loss, acc = grad_f(x_single, y_single[None])  # label reshaped to [1]
    # jax.tree_map is deprecated and removed in newer JAX releases;
    # jax.tree_util.tree_map is the stable, backward-compatible spelling.
    new_grad = jax.tree_util.tree_map(bm.add, last_grad, grads)  # accumulate gradients
    return new_grad, (loss, acc)
@bm.jit
def train(x_batch, y_batch):
    """Accumulate gradients over a whole batch and apply one optimizer step.

    Inputs:
      x_batch: [batch, feature]
      y_batch: [batch]

    Returns:
      (loss, acc): batch-mean loss and accuracy.
    """
    train_vars = model.train_vars().unique()
    # Gradient accumulation: start from zeros shaped like each trainable variable.
    # (jax.tree_util.tree_map replaces the deprecated jax.tree_map alias.)
    grads = jax.tree_util.tree_map(bm.zeros_like, train_vars)
    grads, (losses, acces) = bm.scan(grad_fun, grads, (x_batch, y_batch))  # PyTree of gradients, [batch], [batch]
    # Note: per-sample gradients are summed, not averaged, before the update.
    optimizer.update(grads)
    loss = losses.mean()  # scalar
    acc = acces.mean()  # scalar
    return loss, acc
#%% Start training
print('Start training...')
train_epochs = 10
train_loss = bm.zeros(train_epochs, dtype=bm.float_)
train_acc = bm.zeros(train_epochs, dtype=bm.float_)
for epoch in range(train_epochs):
    # To debug eagerly, wrap the call below in `with jax.disable_jit():`.
    epoch_loss, epoch_acc = train(train_data, train_label)
    train_loss[epoch] = epoch_loss
    train_acc[epoch] = epoch_acc
    print(f"Epoch {epoch}, train_loss={train_loss[epoch]:.3f}, train_acc={train_acc[epoch] * 100.0:.2f}%")
print('Done!')
Outputs:
Creating network...
Creating data...
Start training...
Epoch 0, train_loss=4.572, train_acc=50.00%
Epoch 1, train_loss=1.373, train_acc=77.00%
Epoch 2, train_loss=0.757, train_acc=87.00%
Epoch 3, train_loss=0.767, train_acc=89.50%
Epoch 4, train_loss=0.636, train_acc=91.00%
Epoch 5, train_loss=0.642, train_acc=91.00%
Epoch 6, train_loss=0.568, train_acc=89.50%
Epoch 7, train_loss=0.570, train_acc=89.50%
Epoch 8, train_loss=0.576, train_acc=90.00%
Epoch 9, train_loss=0.582, train_acc=91.00%
Done!
Thanks for the report. The `rotation`
mode may have been fixed some time ago, but I will check whether the gradients are correct.
Thanks for the report. The
`rotation`
mode may have been fixed some time ago, but I will check whether the gradients are correct.
This is also what I would like to know. How can I check gradients in BrainPy?
I wrote a simple script to check whether the gradients computed with the `rotation` and `concat` update methods are the same. They are.
import functools
import jax
import numpy as np
import brainpy as bp
import brainpy.math as bm
snn_latency = 20  # number of time steps simulated per sample (the length of the for_loop in loss_fun)
dt = 1.0  # simulation time step, passed to bm.set_dt below
bm.clear_buffer_memory()  # NOTE(review): presumably releases buffers left from earlier runs — confirm
# Default float dtype, plus training mode set globally for models built afterwards
# (this replaces the `with bm.training_environment():` block of the first script).
bm.set(float_=bm.float32, mode=bm.training_mode)
bm.set_platform('cpu')  # run on CPU
bm.set_dt(dt)
# %% Network definition
class Network(bp.DynSysGroup):
    """Two LIF neurons driven by a dense layer and read out through a delay line.

    Args:
      method: ``update_method`` handed to ``bm.LengthDelay`` (this script
        compares ``'rotation'`` with ``'concat'``).
    """

    def __init__(self, method):
        super().__init__()
        self.neu = bp.dyn.Lif(size=2, V_rest=0.0, V_reset=0.0, V_th=1.0,
                              spk_fun=bm.surrogate.Arctan())
        self.delay_len = 2
        self.spike_buffer = bm.LengthDelay(self.neu.spike,
                                           delay_len=self.delay_len,
                                           update_method=method)
        self.weight = bm.TrainVar(bm.random.randn(2, 2))
        self.bias = bm.TrainVar(bm.random.randn(2))

    def reset_state(self, *args):
        """Reset the neurons and refill the delay buffer (extra args are ignored)."""
        self.neu.reset_state(self.neu.mode)
        self.spike_buffer.reset(self.neu.spike, delay_len=self.delay_len)

    def update(self, data):
        """Advance one step and return the spikes emitted ``delay_len`` steps ago."""
        drive = self.weight @ data + self.bias
        fired = self.neu(drive)  # [batch, 2]
        self.spike_buffer.update(fired)
        return self.spike_buffer.retrieve(self.delay_len)  # [batch, 2]
def train1(method='rotation'):
    """Build a fresh model/optimizer pair and return its jitted training step.

    Args:
      method: ``update_method`` forwarded to ``bm.LengthDelay`` through
        ``Network`` (this script uses 'rotation' and 'concat').

    Returns:
      ``train(x_batch, y_batch)``: a jitted function that sums per-sample
      gradients over the batch, applies one Adam step, and returns the summed
      gradients (same tree structure as ``model.train_vars().unique()``).
    """
    # %% Create network and fake data
    model = Network(method)
    optimizer = bp.optim.Adam(lr=1.0, train_vars=model.train_vars().unique())

    # %% Training functions
    def loss_fun(x_single, y_single):
        """Simulate one sample for ``snn_latency`` steps and return (loss, acc).

        Inputs:
          x_single: [feature]
          y_single: [1]
        """
        indices = np.arange(snn_latency)  # one index per simulated time step
        model.reset_state()
        spike = bm.for_loop(functools.partial(model.step_run, data=x_single), indices)  # [length, batch=1, 2], float32
        firerate = bm.sum(spike, axis=0) + 1.0e-6  # [batch=1, 2]; epsilon keeps log() finite
        # Normalise per row over the class axis so each sample is a distribution
        # even if batch > 1 (identical result for batch == 1).
        predict = bm.log(firerate / bm.sum(firerate, axis=-1, keepdims=True))  # log-probabilities, [batch=1, n_class]
        # NOTE(review): -predict assumes brainpy's nll_loss gathers input[target]
        # without negating it — confirm against the brainpy docs.
        loss = bp.losses.nll_loss(-predict, y_single)  # scalar
        acc = bm.mean(predict.argmax(-1) == y_single)  # scalar
        return loss, acc

    # Build the gradient transform once per model instead of inside the scan body.
    grad_f = bm.grad(loss_fun, grad_vars=model.train_vars().unique(), has_aux=True, return_value=True)

    def grad_fun(last_grad, x_y_single):
        """Scan body: add one sample's gradients to the running sum.

        Inputs:
          last_grad: PyTree of gradients of each trainable parameter.
          x_y_single: tuple of ([feature], scalar), a single training sample.
        """
        x_single, y_single = x_y_single  # [feature], scalar
        grads, loss, acc = grad_f(x_single, y_single[None])  # PyTree of gradients, scalar, scalar
        # jax.tree_map is deprecated (removed in newer JAX); use tree_util.tree_map.
        new_grad = jax.tree_util.tree_map(bm.add, last_grad, grads)  # accumulate gradients
        return new_grad, (loss, acc)

    @bm.jit
    def train(x_batch, y_batch):
        """Sum gradients over the batch, apply one optimizer step, return the sum.

        Inputs:
          x_batch: [batch, feature]
          y_batch: [batch]
        """
        train_vars = model.train_vars().unique()
        # Gradient accumulation, starting from zeros shaped like each variable.
        grads = jax.tree_util.tree_map(bm.zeros_like, train_vars)
        grads, _ = bm.scan(grad_fun, grads, (x_batch, y_batch))  # per-sample (loss, acc) unused here
        optimizer.update(grads)
        return grads

    return train
train_data = np.concatenate([np.random.randn(100, 2) + np.array([[-1, -1]]),
                             np.random.randn(100, 2) + np.array([[1, 1]])], axis=0)  # [batch, 2]
train_label = bm.concatenate([bm.zeros(100, dtype=bm.int32),
                              bm.ones(100, dtype=bm.int32)], axis=0)  # [batch]

# Build both variants from the same RNG seed so they start from identical
# parameters; only the delay update method differs. Clearing the name cache
# presumably gives both models the same auto-generated variable names, which
# lets tree_map pair up the two gradient PyTrees below.
bm.random.seed(0)
bm.clear_name_cache()
f1 = train1('rotation')
bm.random.seed(0)
bm.clear_name_cache()
f2 = train1('concat')

# Each call also applies an optimizer step, so gradients that match every epoch
# mean the two models stay in lockstep. To debug eagerly, wrap the calls in
# `with jax.disable_jit():`.
for e in range(10):
    grad1 = f1(train_data, train_label)
    grad2 = f2(train_data, train_label)
    # jax.tree_util.tree_map replaces the deprecated jax.tree_map alias.
    print(jax.tree_util.tree_map(bm.allclose, grad1, grad2))
Hi, what I actually want to know is where gradients are stored in BrainPy. For example, in PyTorch each trainable parameter has a `grad`
field that stores its gradient values. Is there a similar field in BrainPy variables?
Gradients are not stored in a fixed place; they are only returned when the gradient function is called. In the following example, the gradients are returned as `grads`:
# "grad_vars" specify the target to compute gradients
grad_f = bm.grad(loss_fun, grad_vars=model.train_vars().unique(), has_aux=True, return_value=True)
# "grads" return as the function output
grads, loss, acc = grad_f(x_single, y_single[None])
Thank you very much for the information!