"TypeError: true_fun and false_fun output must have same type structure" in "brainpy.math.cond" function
CloudyDory opened this issue · 2 comments
- Check for duplicate issues.
- Provide a complete example of how to reproduce the bug, wrapped in triple backticks like this:
- If applicable, include full error messages/tracebacks.
I am trying to implement a custom short-term synaptic depression model in BrainPy. The logic I want is:
# Intended logic. This plain Python `if` needs a concrete bool, so it cannot be
# used when `spike` (and therefore bm.any(spike)) is a traced value.
if bm.any(spike):
    x, x_old, d = func(spike, t, x, x_old, d)
However, BrainPy is built on JAX, and when `spike` is traced (for example inside a JIT-compiled function) its value is not known at trace time, so it cannot be used in Python control flow. I therefore tried the following workaround:
import brainpy.math as bm

t = 10.0
x = bm.Variable(bm.zeros(5, dtype=bm.float32))      # time of the latest spike
x_old = bm.Variable(bm.zeros(5, dtype=bm.float32))  # time of the previous spike
d = bm.Variable(bm.ones(5, dtype=bm.float32))
spike = bm.Variable(bm.random.randn(5) > 0)


def do_nothing1(spike, t, x, x_old, d):
    """False branch: return the state unchanged."""
    return x, x_old, d


def func1(spike, t, x, x_old, d):
    """True branch: update the state of the neurons that spiked at time ``t``."""
    # Capture the previous spike time before x is overwritten with t.
    x_old = bm.where(spike, x, x_old)
    x = bm.where(spike, t, x)
    d_next = 1.0 - (1.0 - d * 0.9) * bm.exp(-(x - x_old) / 100.0)
    d = bm.where(spike, d_next, d)
    return x, x_old, d


# Operands follow the branch signature (spike, t, x, x_old, d); passing
# (t, spike, ...) would silently swap the spike mask and the time.
# NOTE: this call still raises the TypeError reported below. func1 returns
# brainpy.math.Array objects (produced by bm.where), while do_nothing1 returns
# the raw arrays it receives, so the two branches' output pytrees differ.
# Both branches must return the same type, or the Variables must be updated in
# place inside the branches instead of being returned.
x.value, x_old.value, d.value = bm.cond(
    bm.any(spike), func1, do_nothing1,
    (spike.value, t, x.value, x_old.value, d.value),
)
But running this code generates the following error:
Traceback (most recent call last):
File ~\miniconda3\Lib\site-packages\spyder_kernels\py3compat.py:356 in compat_exec
exec(code, globals, locals)
File d:\xxx\untitled2.py:27
x.value, x_old.value, d.value = bm.cond(bm.any(spike), func1, do_nothing1, (t, spike.value, x.value, x_old.value, d.value))
File ~\miniconda3\Lib\site-packages\brainpy\_src\math\object_transform\controls.py:539 in cond
dyn_vars, rets = evaluate_dyn_vars(
File ~\miniconda3\Lib\site-packages\brainpy\_src\math\object_transform\_tools.py:97 in evaluate_dyn_vars
rets = jax.eval_shape(f2, *args, **kwargs)
File ~\miniconda3\Lib\site-packages\jax\_src\traceback_util.py:166 in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File ~\miniconda3\Lib\site-packages\jax\_src\api.py:2807 in eval_shape
out = pe.abstract_eval_fun(wrapped_fun.call_wrapped,
File ~\miniconda3\Lib\site-packages\jax\_src\interpreters\partial_eval.py:670 in abstract_eval_fun
_, avals_out, _ = trace_to_jaxpr_dynamic(
File ~\miniconda3\Lib\site-packages\jax\_src\profiler.py:314 in wrapper
return func(*args, **kwargs)
File ~\miniconda3\Lib\site-packages\jax\_src\interpreters\partial_eval.py:2155 in trace_to_jaxpr_dynamic
jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(
File ~\miniconda3\Lib\site-packages\jax\_src\interpreters\partial_eval.py:2177 in trace_to_subjaxpr_dynamic
ans = fun.call_wrapped(*in_tracers_)
File ~\miniconda3\Lib\site-packages\jax\_src\linear_util.py:188 in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
File ~\miniconda3\Lib\site-packages\jax\_src\linear_util.py:188 in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
File ~\miniconda3\Lib\site-packages\brainpy\_src\math\object_transform\controls.py:452 in call_fun
return jax.lax.cond(pred, _true_fun, _false_fun, dyn_vars.dict_data(), *operands)
File ~\miniconda3\Lib\site-packages\jax\_src\traceback_util.py:166 in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File ~\miniconda3\Lib\site-packages\jax\_src\lax\control_flow\conditionals.py:292 in cond
return _cond(*args, **kwargs)
File ~\miniconda3\Lib\site-packages\jax\_src\lax\control_flow\conditionals.py:252 in _cond
_check_tree_and_avals("true_fun and false_fun output",
File ~\miniconda3\Lib\site-packages\jax\_src\lax\control_flow\common.py:198 in _check_tree_and_avals
raise TypeError(
TypeError: true_fun and false_fun output must have same type structure, got PyTreeDef(({}, (CustomNode(Array[None], [*]), CustomNode(Array[None], [*]), CustomNode(Array[None], [*])))) and PyTreeDef(({}, (*, *, *))).
Yet if I translate the code to plain JAX, it runs without error:
import jax
import jax.numpy as jnp

t = 10.0
x = jnp.zeros(5, dtype=jnp.float32)       # time of the latest spike
x_old = jnp.zeros(5, dtype=jnp.float32)   # time of the previous spike
d = jnp.ones(5, dtype=jnp.float32)
spike = jax.random.normal(jax.random.PRNGKey(0), (5,)) > 0


@jax.jit
def do_nothing2(spike, t, x, x_old, d):
    """False branch: return the state unchanged."""
    return x, x_old, d


@jax.jit
def func2(spike, t, x, x_old, d):
    """True branch: update the state of the neurons that spiked at time ``t``."""
    # Capture the previous spike time before x is overwritten with t.
    x_old = jnp.where(spike, x, x_old)
    x = jnp.where(spike, t, x)
    # Same time constant (100.0) as the BrainPy version of this model.
    d_next = 1.0 - (1.0 - d * 0.9) * jnp.exp(-(x - x_old) / 100.0)
    d = jnp.where(spike, d_next, d)
    return x, x_old, d


# Operands must follow the branch signature (spike, t, x, x_old, d).
# Passing (t, spike, ...) does not raise: jnp.where would use the scalar t as
# the condition and the boolean spike array as data, silently producing
# wrong values.
x, x_old, d = jax.lax.cond(jnp.any(spike), func2, do_nothing2, spike, t, x, x_old, d)
What causes this error, and how can I work around it? I am using BrainPy 2.4.4.post1 and JAX 0.4.14.
Thanks!
Thanks for the report.
I recommend using the following code:
import brainpy.math as bm

t = 10.0
x = bm.Variable(bm.zeros(5, dtype=bm.float32))
x_old = bm.Variable(bm.zeros(5, dtype=bm.float32))
d = bm.Variable(bm.ones(5, dtype=bm.float32))
spike = bm.Variable(bm.random.randn(5) > 0)


def do_nothing1(*args):
    """False branch: leave every Variable untouched and return nothing."""
    return


def func1(t):
    """True branch: update, in place, the Variables of the neurons that spiked at time ``t``.

    Returns nothing, so both branches given to bm.cond produce the same
    (empty) output structure.
    """
    # Order matters: x_old must read x before x is overwritten with t.
    x_old.value = bm.where(spike, x, x_old)
    x.value = bm.where(spike, t, x)
    # d_next uses the just-updated x and x_old.
    d_next = 1.0 - (1.0 - d * 0.9) * bm.exp(-(x - x_old) / 100.0)
    d.value = bm.where(spike, d_next, d)


# Only t is passed as an operand; the Variables are read and written directly
# inside the branch functions.
bm.cond(bm.any(spike), func1, do_nothing1, t)
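If you are using `Variable`s, use this style: update them in place inside the branch functions instead of passing their values in and returning new ones. Neither branch then returns anything, so the two branches trivially have the same output structure.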
Thank you, your suggested solution works!