TypeError when running a simulation
PikaPei opened this issue · 2 comments
Hello!
I am new to BrainPy and find it a great simulation tool!
But when I played with it, I met some errors and couldn't find solutions in the documentation.
I want to make two variables, var1
and var2
, each receiving spike inputs from distinct SpikeTimeGroups
.
Each variable shows simple exponential decay dynamics with different time constants.
I'm not sure if I'm doing something wrong, and I would appreciate any advice.
Here is my code and my BrainPy version is 2.4.6.post5
,
import brainpy as bp
import brainpy.math as bm
import matplotlib.pyplot as plt
class Model(bp.NeuGroup):
def __init__(self, size, var1_tau=500, var2_tau=1000):
super().__init__(size=size)
self.var1_pre = bp.neurons.SpikeTimeGroup(1, times=[100], indices=[0])
self.var2_pre = bp.neurons.SpikeTimeGroup(1, times=[200], indices=[0])
self.var1_tau = var1_tau
self.var1 = bm.Variable(bm.zeros(self.num))
self.var2_tau = var2_tau
self.var2 = bm.Variable(bm.zeros(self.num))
self.integral = bp.odeint(bp.JointEq(self.dvar1, self.dvar2), method="exp_auto")
def dvar1(self, var1, t):
dvar1dt = -var1 / self.var1_tau
return dvar1dt
def dvar2(self, var2, t):
dvar2dt = -var2 / self.var2_tau
return dvar2dt
def update(self):
t = bp.share["t"]
dt = bp.share["dt"]
self.var1_pre.update()
self.var2_pre.update()
self.var1.value = self.integral(self.var1, t, dt=dt) + self.var1_pre.spike
self.var2.value = self.integral(self.var2, t, dt=dt) + self.var2_pre.spike
def run(self, duration):
self.runner = bp.DSRunner(
self,
monitors=["var1", "var2"],
)
self.runner.run(duration)
if __name__ == "__main__":
model = Model(1)
model.run(1000)
plt.plot(model.runner.mon.ts, model.runner.mon.var1)
plt.plot(model.runner.mon.ts, model.runner.mon.var2)
plt.show()
The error is:
Traceback (most recent call last):
File "/Users/pei/project/comp-neuro-brainpy/test.py", line 49, in <module>
model.run(1000)
File "/Users/pei/project/comp-neuro-brainpy/test.py", line 44, in run
self.runner.run(duration)
File "/Users/pei/.pyenv/versions/mambaforge-22.9.0-2/envs/brainpy-env/lib/python3.11/site-packages/brainpy/_src/runners.py", line 512, in run
return self.predict(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/pei/.pyenv/versions/mambaforge-22.9.0-2/envs/brainpy-env/lib/python3.11/site-packages/brainpy/_src/runners.py", line 485, in predict
outputs, hists = self._predict(indices, *inputs, shared_args=shared_args)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/pei/.pyenv/versions/mambaforge-22.9.0-2/envs/brainpy-env/lib/python3.11/site-packages/brainpy/_src/runners.py", line 539, in _predict
outs_and_mons = self._fun_predict(indices, *xs, shared_args=shared_args)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/pei/.pyenv/versions/mambaforge-22.9.0-2/envs/brainpy-env/lib/python3.11/site-packages/brainpy/_src/runners.py", line 662, in _fun_predict
return bm.for_loop(self._step_func_predict,
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/pei/.pyenv/versions/mambaforge-22.9.0-2/envs/brainpy-env/lib/python3.11/site-packages/brainpy/_src/math/object_transform/controls.py", line 877, in for_loop
rets = jax.eval_shape(transform, operands)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/pei/.pyenv/versions/mambaforge-22.9.0-2/envs/brainpy-env/lib/python3.11/site-packages/brainpy/_src/math/object_transform/controls.py", line 730, in call
return jax.lax.scan(f=fun2scan,
^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/pei/.pyenv/versions/mambaforge-22.9.0-2/envs/brainpy-env/lib/python3.11/site-packages/brainpy/_src/math/object_transform/controls.py", line 721, in fun2scan
results = body_fun(*x, **unroll_kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/pei/.pyenv/versions/mambaforge-22.9.0-2/envs/brainpy-env/lib/python3.11/site-packages/brainpy/_src/runners.py", line 628, in _step_func_predict
out = self.target(*x)
^^^^^^^^^^^^^^^
File "/Users/pei/.pyenv/versions/mambaforge-22.9.0-2/envs/brainpy-env/lib/python3.11/site-packages/brainpy/_src/dynsys.py", line 378, in __call__
ret = self.update(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/pei/.pyenv/versions/mambaforge-22.9.0-2/envs/brainpy-env/lib/python3.11/site-packages/brainpy/_src/dynsys.py", line 330, in _compatible_update
return update_fun(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/pei/project/comp-neuro-brainpy/test.py", line 35, in update
self.var1.value = self.integral(self.var1, t, dt=dt) + self.var1_pre.spike
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/pei/.pyenv/versions/mambaforge-22.9.0-2/envs/brainpy-env/lib/python3.11/site-packages/brainpy/_src/integrators/ode/base.py", line 114, in __call__
new_vars = self.integral(**kwargs)
^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/pei/.pyenv/versions/mambaforge-22.9.0-2/envs/brainpy-env/lib/python3.11/site-packages/brainpy/_src/integrators/ode/exponential.py", line 332, in integral_func
r = f_integral(params_in[vps[0]], **{arg: params_in[arg] for arg in vps[1:] if arg in params_in})
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/pei/.pyenv/versions/mambaforge-22.9.0-2/envs/brainpy-env/lib/python3.11/site-packages/brainpy/_src/integrators/ode/exponential.py", line 360, in integral
linear, derivative = value_and_grad(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/pei/.pyenv/versions/mambaforge-22.9.0-2/envs/brainpy-env/lib/python3.11/site-packages/brainpy/_src/math/object_transform/autograd.py", line 209, in __call__
rets = self._transform(
^^^^^^^^^^^^^^^^
File "/Users/pei/.pyenv/versions/mambaforge-22.9.0-2/envs/brainpy-env/lib/python3.11/site-packages/brainpy/_src/math/object_transform/autograd.py", line 771, in grad_fun
y, vjp_fn, aux = _vjp(f_partial, *dyn_args, has_aux=True)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/pei/.pyenv/versions/mambaforge-22.9.0-2/envs/brainpy-env/lib/python3.11/site-packages/brainpy/_src/math/object_transform/autograd.py", line 150, in _f_grad_without_aux_to_transform
output = self.target(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
TypeError: Model.dvar1() missing 1 required positional argument: 't'
--------------------
For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
Thank you!
Thanks for the question. The error is caused because you are using the joint equation while giving the parameters independently.
self.integral = bp.odeint(bp.JointEq(self.dvar1, self.dvar2), method="exp_auto")
One way to solve this issue is modifying your update
function as:
self.var1.value, self.var2.value = self.integral(self.var1, self.var2, t, dt=dt)
self.var1 += self.var1_pre.spike
self.var2 += self.var2_pre.spike
I see. It works now!
Thank you for the helpful answer!