ndarray or scalar arguments, got <class 'list'> at position 0.
atarilover123 opened this issue · 3 comments
atarilover123 commented
I'm getting this error on the "initialize model" cell.
UnfilteredStackTrace: TypeError: broadcast_to requires ndarray or scalar arguments, got <class 'list'> at position 0.
asinghka commented
I changed
jnp.broadcast_to([last_sample_z], z_vals[..., :1].shape)
to
jnp.broadcast_to(last_sample_z, z_vals[..., :1].shape)
in line 107 of model_utils.py, and that fixed it for me.
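Here is a minimal sketch of why that change helps. It assumes a recent JAX release, in which `jax.numpy` functions reject plain Python lists as array arguments, and it assumes `last_sample_z` is a plain Python scalar. The `z_vals` array below is a made-up stand-in, not the one computed in model_utils.py. Passing the scalar directly (the fix above) and wrapping the list with `jnp.array(...)` should both give the same broadcast result.

```python
import jax.numpy as jnp

# Hypothetical stand-ins: z_vals mimics a (rays, samples) array of depths,
# and last_sample_z is assumed to be a plain Python float.
z_vals = jnp.linspace(0.0, 1.0, 8).reshape(2, 4)  # shape (2, 4)
last_sample_z = 1e10

target_shape = z_vals[..., :1].shape  # (2, 1)

# Original call: recent JAX versions raise the TypeError from this issue
# because the first argument is a Python list, not an array or scalar.
try:
    jnp.broadcast_to([last_sample_z], target_shape)
except TypeError as err:
    print("list argument rejected:", err)

# Fix from the thread: pass the scalar itself.
fixed_scalar = jnp.broadcast_to(last_sample_z, target_shape)

# Alternative fix: convert the list to a JAX array before broadcasting.
fixed_array = jnp.broadcast_to(jnp.array([last_sample_z]), target_shape)

print(fixed_scalar.shape, bool(jnp.array_equal(fixed_scalar, fixed_array)))
```

The two fixes are interchangeable only while `last_sample_z` is a scalar; if it were already an array with a trailing axis, wrapping it in a list would add an extra dimension.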
CurtinComputing commented
Hi blackz5,
It does not work for me. Did you try it on Colab?
asinghka commented
No, I ran it locally and changed the file as described above.