brainpy/examples

TypeError in mnist_lif_readout.py


Came from #9. Thanks for solving that problem!
There's another error in the same example:

Namespace(T=100, platform='cpu', batch=64, epochs=15, out_dir='./logs', lr=0.001, tau=2.0)
Traceback (most recent call last):
  File "/Users/pei/project/computational-neuroscience/mnist_lif_readout.py", line 103, in <module>
    l, correct_num = train(X, Y)
                     ^^^^^^^^^^^
  File "/Users/pei/.pyenv/versions/mambaforge-22.9.0-2/envs/brainpy-env/lib/python3.11/site-packages/brainpy/_src/math/object_transform/jit.py", line 208, in __call__
    rets = self._get_transform(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/pei/.pyenv/versions/mambaforge-22.9.0-2/envs/brainpy-env/lib/python3.11/site-packages/brainpy/_src/math/object_transform/jit.py", line 155, in _get_transform
    self._dyn_vars, rets = evaluate_dyn_vars(
                           ^^^^^^^^^^^^^^^^^^
  File "/Users/pei/.pyenv/versions/mambaforge-22.9.0-2/envs/brainpy-env/lib/python3.11/site-packages/brainpy/_src/math/object_transform/tools.py", line 101, in evaluate_dyn_vars
    rets = jax.eval_shape(f2, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/pei/project/computational-neuroscience/mnist_lif_readout.py", line 86, in train
    grads, l, n = grad_fun(xs, ys)
                  ^^^^^^^^^^^^^^^^
  File "/Users/pei/.pyenv/versions/mambaforge-22.9.0-2/envs/brainpy-env/lib/python3.11/site-packages/brainpy/_src/math/object_transform/autograd.py", line 209, in __call__
    rets = self._transform(
           ^^^^^^^^^^^^^^^^
  File "/Users/pei/.pyenv/versions/mambaforge-22.9.0-2/envs/brainpy-env/lib/python3.11/site-packages/brainpy/_src/math/object_transform/autograd.py", line 133, in _f_grad_with_aux_to_transform
    outputs = self.target(*args, **kwargs)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/pei/project/computational-neuroscience/mnist_lif_readout.py", line 69, in loss_fun
    out_fr = jnp.mean(outs, axis=0)
             ^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/pei/.pyenv/versions/mambaforge-22.9.0-2/envs/brainpy-env/lib/python3.11/site-packages/jax/_src/numpy/reductions.py", line 319, in mean
    return _mean(a, _ensure_optional_axes(axis), dtype, out, keepdims,
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/pei/.pyenv/versions/mambaforge-22.9.0-2/envs/brainpy-env/lib/python3.11/site-packages/jax/_src/numpy/reductions.py", line 351, in _mean
    sum(a, axis, dtype=computation_dtype, keepdims=keepdims, where=where),
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/pei/.pyenv/versions/mambaforge-22.9.0-2/envs/brainpy-env/lib/python3.11/site-packages/jax/_src/numpy/reductions.py", line 222, in sum
    return _reduce_sum(a, axis=_ensure_optional_axes(axis), dtype=dtype, out=out,
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/pei/.pyenv/versions/mambaforge-22.9.0-2/envs/brainpy-env/lib/python3.11/site-packages/jax/_src/numpy/reductions.py", line 212, in _reduce_sum
    return _reduction(a, "sum", np.sum, lax.add, 0, preproc=_cast_to_numeric,
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/pei/.pyenv/versions/mambaforge-22.9.0-2/envs/brainpy-env/lib/python3.11/site-packages/jax/_src/numpy/reductions.py", line 94, in _reduction
    a = a if isinstance(a, Array) else lax_internal.asarray(a)
                                       ^^^^^^^^^^^^^^^^^^^^^^^
TypeError: asarray: expected ArrayLike, got Traced<ShapedArray(float32[100,64,10])>with<DynamicJaxprTrace(level=5/0)> of type <class 'brainpy._src.math.ndarray.Array'>.

Thanks for the report. It can easily be corrected by changing the line

out_fr = jnp.mean(outs, axis=0)

to

out_fr = bm.mean(outs, axis=0)

Thanks for solving this! I'll make a PR to correct similar errors in this file.