TypeError in mnist_lif_readout.py
Closed this issue · 2 comments
PikaPei commented
Came from #9. Thanks for solving that problem!
There's another error in the same example:
Namespace(T=100, platform='cpu', batch=64, epochs=15, out_dir='./logs', lr=0.001, tau=2.0)
Traceback (most recent call last):
File "/Users/pei/project/computational-neuroscience/mnist_lif_readout.py", line 103, in <module>
l, correct_num = train(X, Y)
^^^^^^^^^^^
File "/Users/pei/.pyenv/versions/mambaforge-22.9.0-2/envs/brainpy-env/lib/python3.11/site-packages/brainpy/_src/math/object_transform/jit.py", line 208, in __call__
rets = self._get_transform(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/pei/.pyenv/versions/mambaforge-22.9.0-2/envs/brainpy-env/lib/python3.11/site-packages/brainpy/_src/math/object_transform/jit.py", line 155, in _get_transform
self._dyn_vars, rets = evaluate_dyn_vars(
^^^^^^^^^^^^^^^^^^
File "/Users/pei/.pyenv/versions/mambaforge-22.9.0-2/envs/brainpy-env/lib/python3.11/site-packages/brainpy/_src/math/object_transform/tools.py", line 101, in evaluate_dyn_vars
rets = jax.eval_shape(f2, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/pei/project/computational-neuroscience/mnist_lif_readout.py", line 86, in train
grads, l, n = grad_fun(xs, ys)
^^^^^^^^^^^^^^^^
File "/Users/pei/.pyenv/versions/mambaforge-22.9.0-2/envs/brainpy-env/lib/python3.11/site-packages/brainpy/_src/math/object_transform/autograd.py", line 209, in __call__
rets = self._transform(
^^^^^^^^^^^^^^^^
File "/Users/pei/.pyenv/versions/mambaforge-22.9.0-2/envs/brainpy-env/lib/python3.11/site-packages/brainpy/_src/math/object_transform/autograd.py", line 133, in _f_grad_with_aux_to_transform
outputs = self.target(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/pei/project/computational-neuroscience/mnist_lif_readout.py", line 69, in loss_fun
out_fr = jnp.mean(outs, axis=0)
^^^^^^^^^^^^^^^^^^^^^^
File "/Users/pei/.pyenv/versions/mambaforge-22.9.0-2/envs/brainpy-env/lib/python3.11/site-packages/jax/_src/numpy/reductions.py", line 319, in mean
return _mean(a, _ensure_optional_axes(axis), dtype, out, keepdims,
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/pei/.pyenv/versions/mambaforge-22.9.0-2/envs/brainpy-env/lib/python3.11/site-packages/jax/_src/numpy/reductions.py", line 351, in _mean
sum(a, axis, dtype=computation_dtype, keepdims=keepdims, where=where),
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/pei/.pyenv/versions/mambaforge-22.9.0-2/envs/brainpy-env/lib/python3.11/site-packages/jax/_src/numpy/reductions.py", line 222, in sum
return _reduce_sum(a, axis=_ensure_optional_axes(axis), dtype=dtype, out=out,
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/pei/.pyenv/versions/mambaforge-22.9.0-2/envs/brainpy-env/lib/python3.11/site-packages/jax/_src/numpy/reductions.py", line 212, in _reduce_sum
return _reduction(a, "sum", np.sum, lax.add, 0, preproc=_cast_to_numeric,
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/pei/.pyenv/versions/mambaforge-22.9.0-2/envs/brainpy-env/lib/python3.11/site-packages/jax/_src/numpy/reductions.py", line 94, in _reduction
a = a if isinstance(a, Array) else lax_internal.asarray(a)
^^^^^^^^^^^^^^^^^^^^^^^
TypeError: asarray: expected ArrayLike, got Traced<ShapedArray(float32[100,64,10])>with<DynamicJaxprTrace(level=5/0)> of type <class 'brainpy._src.math.ndarray.Array'>.
chaoming0625 commented
Thanks for the report. It can easily be corrected by changing the line
out_fr = jnp.mean(outs, axis=0)
to
out_fr = bm.mean(outs, axis=0)
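The traceback shows that `outs` is a BrainPy `Array` of shape `float32[100,64,10]`, which `jnp.mean` rejects as not ArrayLike. `bm.mean` accepts it and performs the same reduction. The NumPy sketch below illustrates what that line computes. It does not use BrainPy. The random spike data and the argmax decoding are assumptions for illustration, and the axis order `[T, batch, num_classes]` is inferred from `T=100`, `batch=64` and the 10 MNIST classes.

```python
import numpy as np

# Hypothetical stand-in for `outs`: readout spikes with shape [T, batch, num_classes],
# matching float32[100,64,10] in the traceback. Random 0/1 spikes are an assumption.
T, batch, num_classes = 100, 64, 10
rng = np.random.default_rng(0)
outs = (rng.random((T, batch, num_classes)) < 0.2).astype(np.float32)

# Averaging over axis 0 (time) gives each readout neuron's firing rate per sample,
# the same reduction as `bm.mean(outs, axis=0)`.
out_fr = np.mean(outs, axis=0)           # shape (64, 10), values in [0, 1]

# Assumed decoding: the predicted digit is the most active readout neuron.
predictions = np.argmax(out_fr, axis=1)  # shape (64,)

print(out_fr.shape, predictions.shape)
```

Because the reduction runs over the time axis, `out_fr` holds one firing rate per sample and class.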
PikaPei commented
Thanks for solving this! I'll make a PR to correct similar errors in this file.