brainpy/BrainPy

Questions about backpropagation through delay variables

CloudyDory opened this issue · 6 comments

In #626 we mentioned that the rotation update method of delay variables does not support automatic differentiation. However, I have tested it in training and found that the parameters can be trained normally. Is there a misunderstanding about this issue?

import functools
import numpy as np
import jax
import brainpy as bp
import brainpy.math as bm

# Simulation constants: number of time steps each sample is run for (used as
# the sequence length in loss_fun) and the integration time step.
snn_latency = 20
dt = 1.0

# Global BrainPy configuration: float32 as default float dtype, CPU backend,
# and time step dt. These calls mutate process-wide settings, so they
# presumably need to run before the model below is constructed.
# (clear_buffer_memory presumably frees previously allocated device buffers.)
bm.clear_buffer_memory()
bm.set(float_=bm.float32)
bm.set_platform('cpu')
bm.set_dt(dt)

#%% Network definition
class Network(bp.DynSysGroup):
    """Two LIF neurons whose spikes are read back through a delay buffer.

    Each step feeds ``weight @ data + bias`` into the neurons, pushes the
    resulting spikes into a ``bm.LengthDelay`` buffer that uses the
    'rotation' update method, and returns the spikes from ``delay_len``
    steps back.
    """
    def __init__(self):
        super().__init__()
        # LIF population of size 2; the arctan surrogate supplies a gradient
        # for the otherwise non-differentiable spike function.
        self.neu = bp.dyn.Lif(size=2, V_rest=0.0, V_reset=0.0, V_th=1.0, spk_fun=bm.surrogate.Arctan())
        self.delay_len = 2
        # Delay buffer initialised from the neuron's spike variable; 'rotation'
        # is the update method whose differentiability this issue asks about.
        self.spike_buffer = bm.LengthDelay(self.neu.spike, delay_len=self.delay_len, update_method='rotation')
        # Trainable input projection (2x2) and bias (2,).
        self.weight = bm.TrainVar(bm.random.randn(2,2))
        self.bias = bm.TrainVar(bm.random.randn(2))
    
    def reset_state(self, *args):
        """Reset the neuron state and re-initialise the delay buffer.

        Extra positional arguments are accepted but ignored.
        """
        self.neu.reset_state(self.neu.mode)
        self.spike_buffer.reset(self.neu.spike, delay_len=self.delay_len)
    
    def update(self, data):
        """Advance one time step and return the spikes delayed by ``delay_len``.

        Args:
            data: input features, multiplied by the 2x2 ``weight``; presumably
                shape (2,) — TODO confirm against callers.
        """
        spike = self.neu(self.weight @ data + self.bias)  # [batch, 2] 
        # Push the current spikes first, then read back: swapping these two
        # calls would presumably shift the effective delay by one step.
        self.spike_buffer.update(spike)
        spike_delay = self.spike_buffer.retrieve(self.delay_len)  # [batch, 2] 
        return spike_delay

#%% Create network and fake data
print('Creating network... ')
# Build the model and optimizer inside bm.training_environment(), presumably so
# the network is created in training mode (needed for gradient-based training).
with bm.training_environment():
    model = Network()
    optimizer = bp.optim.Adam(lr=1.0, train_vars=model.train_vars().unique())

print('Creating data... ')
# Synthetic two-class data set: 100 Gaussian points around (-1, -1) labelled 0
# and 100 points around (1, 1) labelled 1.
# NOTE(review): np.random is not seeded here, so the data differ between runs.
train_data = np.concatenate([np.random.randn(100, 2) + np.array([[-1,-1]]),
                             np.random.randn(100, 2) + np.array([[ 1, 1]])], axis=0)  # [batch, 2]
train_label = bm.concatenate([bm.zeros(100, dtype=bm.int32), 
                              bm.ones(100, dtype=bm.int32)], axis=0)  # [batch]

#%% Training functions
def loss_fun(x_single, y_single):
    '''
    Compute the classification loss and accuracy for a single sample.

    The model is reset and run for ``snn_latency`` steps on the same input.
    The delayed spike counts of the two output neurons are then normalised
    into class probabilities.

    Inputs:
        x_single: [feature]
        y_single: [1]
    Returns:
        (loss, acc): scalar NLL loss and scalar accuracy.
    '''
    indices = np.arange(snn_latency)  # sequence length

    model.reset_state()
    spike = bm.for_loop(functools.partial(model.step_run, data=x_single), indices)  # [length, batch=1, 2], float32
    # Spike counts per output neuron; the epsilon keeps log() finite when a
    # neuron never fires.
    firerate = bm.sum(spike, axis=0) + 1.0e-6  # [batch=1, 2]
    # Normalise over the class axis only, so every sample gets its own
    # distribution (identical to a global sum when batch == 1).
    predict = bm.log(firerate / bm.sum(firerate, axis=-1, keepdims=True))  # log-probabilities, [batch=1, n_class]

    # NOTE(review): -predict is passed on the assumption that bp.losses.nll_loss
    # returns input[target] without negating it (the logged losses are positive).
    # Confirm against the BrainPy version in use.
    loss = bp.losses.nll_loss(-predict, y_single)  # scalar
    acc = bm.mean(predict.argmax(-1) == y_single)  # scalar
    return loss, acc

# Gradient function w.r.t. every unique trainable variable of the model.
# With has_aux=True and return_value=True it is called as
# ``grads, loss, acc = grad_f(x, y)``: gradients, loss value, auxiliary output.
grad_f = bm.grad(loss_fun, grad_vars=model.train_vars().unique(), has_aux=True, return_value=True)

def grad_fun(last_grad, x_y_single):
    '''
    Scan body: compute the gradients for one sample and add them to the running sum.

    Inputs:
        last_grad: PyTree of gradients of each trainable parameter.
        x_y_single: tuple of ([feature], scalar), a single training sample.
    Returns:
        (new_grad, (loss, acc)): the updated gradient sum (the scan carry) and
        this sample's loss and accuracy (stacked by the scan).
    '''
    x_single, y_single = x_y_single  # [feature], scalar
    # y_single[None] turns the scalar label into shape [1], as loss_fun expects.
    grads, loss, acc = grad_f(x_single, y_single[None])  # PyTree of gradients, scalar, scalar
    # jax.tree_map is deprecated in newer JAX releases; jax.tree_util.tree_map
    # is the same function and works on all versions.
    new_grad = jax.tree_util.tree_map(bm.add, last_grad, grads)  # accumulate gradients
    return new_grad, (loss, acc)

@bm.jit
def train(x_batch, y_batch):
    '''
    Run one optimisation step on a whole batch.

    Per-sample gradients are accumulated with ``bm.scan`` over the batch and
    then applied once by the optimizer.

    Inputs:
        x_batch: [batch, feature]
        y_batch: [batch]
    Returns:
        (loss, acc): batch-mean loss and batch-mean accuracy (scalars).
    '''
    train_vars = model.train_vars().unique()

    # Gradient accumulation, starting from zeros shaped like each trainable variable.
    # jax.tree_util.tree_map replaces the deprecated jax.tree_map alias.
    grads = jax.tree_util.tree_map(bm.zeros_like, train_vars)
    grads, (losses, acces) = bm.scan(grad_fun, grads, (x_batch, y_batch))  # PyTree of gradients, [batch], [batch]
    # NOTE: gradients are summed, not averaged, over the batch.
    optimizer.update(grads)

    loss = losses.mean()  # scalar
    acc = acces.mean()    # scalar
    return loss, acc

#%% Start training
print('Start training...')
train_epochs = 10
# Per-epoch history of the batch-mean loss and accuracy.
train_loss = bm.zeros(train_epochs, dtype=bm.float_)
train_acc = bm.zeros(train_epochs, dtype=bm.float_)

for e in range(train_epochs):
    # One full-batch optimisation step per epoch. To debug eagerly, run this
    # call inside ``with jax.disable_jit():``.
    train_loss[e], train_acc[e] = train(train_data, train_label)
    print(f"Epoch {e}, train_loss={train_loss[e]:.3f}, train_acc={train_acc[e]*100.0:.2f}%")

print('Done!')

Outputs:

Creating network... 
Creating data... 
Start training...
Epoch 0, train_loss=4.572, train_acc=50.00%
Epoch 1, train_loss=1.373, train_acc=77.00%
Epoch 2, train_loss=0.757, train_acc=87.00%
Epoch 3, train_loss=0.767, train_acc=89.50%
Epoch 4, train_loss=0.636, train_acc=91.00%
Epoch 5, train_loss=0.642, train_acc=91.00%
Epoch 6, train_loss=0.568, train_acc=89.50%
Epoch 7, train_loss=0.570, train_acc=89.50%
Epoch 8, train_loss=0.576, train_acc=90.00%
Epoch 9, train_loss=0.582, train_acc=91.00%
Done!

Thanks for the report. The rotation mode may have been fixed some time ago, but I will check whether the gradients are correct.

Thanks for the report. The rotation mode may have been fixed some time ago, but I will check whether the gradients are correct.

This is also what I would like to know: how can I check gradients in BrainPy?

I wrote a simple script that checks whether the gradients computed with the 'rotation' and 'concat' update methods are the same. They are.

import functools

import jax
import numpy as np

import brainpy as bp
import brainpy.math as bm

# Simulation constants: number of time steps each sample is run for (used as
# the sequence length in loss_fun) and the integration time step.
snn_latency = 20
dt = 1.0

# Global BrainPy configuration: float32 default dtype and training mode as the
# default mode, CPU backend, and time step dt. These are process-wide settings,
# so they presumably need to run before any network is built.
# (clear_buffer_memory presumably frees previously allocated device buffers.)
bm.clear_buffer_memory()
bm.set(float_=bm.float32, mode=bm.training_mode)
bm.set_platform('cpu')
bm.set_dt(dt)


# %% Network definition
class Network(bp.DynSysGroup):
  """Two LIF neurons read out through a delay buffer with a selectable update method.

  Args:
    method: ``update_method`` forwarded to ``bm.LengthDelay``; this script
      builds one network with 'rotation' and one with 'concat'.
  """
  def __init__(self, method):
    super().__init__()
    # LIF population of size 2 with an arctan surrogate gradient for spikes.
    self.neu = bp.dyn.Lif(size=2, V_rest=0.0, V_reset=0.0, V_th=1.0, spk_fun=bm.surrogate.Arctan())
    self.delay_len = 2
    # Delay buffer initialised from the neuron's spike variable.
    self.spike_buffer = bm.LengthDelay(self.neu.spike, delay_len=self.delay_len, update_method=method)
    # Trainable input projection (2x2) and bias (2,).
    self.weight = bm.TrainVar(bm.random.randn(2, 2))
    self.bias = bm.TrainVar(bm.random.randn(2))

  def reset_state(self, *args):
    """Reset the neuron state and re-initialise the delay buffer (extra args are ignored)."""
    self.neu.reset_state(self.neu.mode)
    self.spike_buffer.reset(self.neu.spike, delay_len=self.delay_len)

  def update(self, data):
    """Advance one time step and return the spikes delayed by ``delay_len``.

    Args:
      data: input features, multiplied by the 2x2 ``weight``; presumably
        shape (2,) — TODO confirm against callers.
    """
    spike = self.neu(self.weight @ data + self.bias)  # [batch, 2]
    # Push the current spikes first, then read back: swapping these two calls
    # would presumably shift the effective delay by one step.
    self.spike_buffer.update(spike)
    spike_delay = self.spike_buffer.retrieve(self.delay_len)  # [batch, 2]
    return spike_delay


def train1(method='rotation'):
  '''
  Build a fresh ``Network`` with the given delay update method and return a
  jitted training step for it.

  Args:
    method: ``update_method`` passed to ``bm.LengthDelay`` (e.g. 'rotation'
      or 'concat').

  Returns:
    A ``bm.jit``-compiled function ``train(x_batch, y_batch)``. Each call
    performs one optimizer update and returns the summed gradient PyTree, so
    the gradients of different update methods can be compared directly.
  '''
  # %% Create network and fake data
  model = Network(method)
  optimizer = bp.optim.Adam(lr=1.0, train_vars=model.train_vars().unique())

  # %% Training functions
  def loss_fun(x_single, y_single):
    '''
    Inputs:
        x_single: [feature]
        y_single: [1]
    Returns:
        (loss, acc): scalar loss and scalar accuracy.
    '''
    indices = np.arange(snn_latency)  # sequence length

    model.reset_state()
    spike = bm.for_loop(functools.partial(model.step_run, data=x_single), indices)  # [length, batch=1, 2], float32
    # Spike counts; the epsilon keeps log() finite when a neuron never fires.
    firerate = bm.sum(spike, axis=0) + 1.0e-6  # [batch=1, 2]
    # Normalise over the class axis only (identical to a global sum when batch == 1).
    predict = bm.log(firerate / bm.sum(firerate, axis=-1, keepdims=True))  # log-probabilities, [batch=1, n_class]

    loss = bp.losses.nll_loss(-predict, y_single)  # scalar
    acc = bm.mean(predict.argmax(-1) == y_single)  # scalar
    return loss, acc

  # Build the gradient transform once, instead of re-creating it inside every
  # grad_fun call.
  grad_f = bm.grad(loss_fun, grad_vars=model.train_vars().unique(), has_aux=True, return_value=True)

  def grad_fun(last_grad, x_y_single):
    '''
    Inputs:
        last_grad: PyTree of gradients of each trainable parameter.
        x_y_single: tuple of ([feature], scalar), a single training sample.
    Returns:
        (new_grad, (loss, acc)): accumulated gradients and per-sample metrics.
    '''
    x_single, y_single = x_y_single  # [feature], scalar
    grads, loss, acc = grad_f(x_single, y_single[None])  # PyTree of gradients, scalar, scalar
    # jax.tree_util.tree_map replaces the deprecated jax.tree_map alias.
    new_grad = jax.tree_util.tree_map(bm.add, last_grad, grads)  # accumulate gradients
    return new_grad, (loss, acc)

  @bm.jit
  def train(x_batch, y_batch):
    '''
    Inputs:
        x_batch: [batch, feature]
        y_batch: [batch]
    Returns:
        PyTree of gradients summed over the batch (the optimizer has already
        applied them).
    '''
    train_vars = model.train_vars().unique()
    # Gradient accumulation
    grads = jax.tree_util.tree_map(bm.zeros_like, train_vars)
    # Per-sample losses/accuracies are not needed for the gradient comparison.
    grads, _ = bm.scan(grad_fun, grads, (x_batch, y_batch))  # PyTree of gradients
    optimizer.update(grads)
    return grads

  return train


# Two Gaussian clusters (100 samples each) centred at (-1, -1) and (1, 1),
# labelled 0 and 1.
train_data = np.concatenate([np.random.randn(100, 2) + np.array([[-1, -1]]),
                             np.random.randn(100, 2) + np.array([[1, 1]])], axis=0)  # [batch, 2]
train_label = bm.concatenate([bm.zeros(100, dtype=bm.int32),
                              bm.ones(100, dtype=bm.int32)], axis=0)  # [batch]

# Re-seed before building each network so both start from the same random
# weights and differ only in the delay update method. The name cache is
# cleared presumably so both get the same auto-generated object names.
bm.random.seed(0)
bm.clear_name_cache()
f1 = train1('rotation')

bm.random.seed(0)
bm.clear_name_cache()
f2 = train1('concat')

# Each call performs one optimizer step and returns that step's gradients.
# All-True output means both update methods produce matching gradients.
# To debug eagerly, wrap the calls in ``with jax.disable_jit():``.
for _ in range(10):
  grad1 = f1(train_data, train_label)
  grad2 = f2(train_data, train_label)
  print(jax.tree_util.tree_map(bm.allclose, grad1, grad2))

Hi, what I actually want to know is where gradients are stored in BrainPy. For example, in PyTorch, trainable parameters have a `grad` field that stores the gradient values. Is there a similar field on BrainPy variables?

Gradients are not stored in a fixed place; they are only returned when the gradient function is called. In the following example, the gradients are returned as `grads`:

# "grad_vars" specifies the variables to differentiate with respect to.
grad_f = bm.grad(loss_fun, grad_vars=model.train_vars().unique(), has_aux=True, return_value=True)
# Gradients are not attached to the variables: they are returned as "grads",
# followed by the loss value (return_value=True) and the auxiliary output acc (has_aux=True).
grads, loss, acc = grad_f(x_single, y_single[None])

Thank you very much for the information!