benchmark.py silently errors
MichaelTMatthews commented
Hi, first off thanks for making such a cool environment in JAX, I think this will be great for the community!
I wanted to investigate the speed of your environment, so I tried running `benchmark.py` with no changes. It reported very fast speeds with a hardcoded policy (~2e9 steps per second), but on further investigation I found that the `env.step` function was silently erroring inside the scan. I realise this is a highly volatile repo at the moment, so I was wondering whether this script is meant to be working.
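A throughput benchmark can be made to fail loudly in this situation. One approach is to call the step function once outside the scan before timing anything, then time the scanned rollout with `block_until_ready()` so the measurement covers device execution and not just asynchronous dispatch. Below is a minimal sketch that assumes a hypothetical `toy_step` standing in for the environment's step. It is not gigastep's API, and the shapes are only illustrative:

```python
import time

import jax
import jax.numpy as jnp

# Hypothetical stand-in for a batched env step: (states, actions) -> new states.
# NOT gigastep's API; it only mimics a "batch of states / batch of actions" call.
def toy_step(states, actions):
    return states + 0.01 * actions.sum(axis=-1, keepdims=True)

def rollout(states, actions, num_steps):
    # Run num_steps steps inside lax.scan, as a benchmark loop would.
    def body(carry, _):
        return toy_step(carry, actions), None
    final, _ = jax.lax.scan(body, states, None, length=num_steps)
    return final

def benchmark(states, actions, num_steps=1000):
    # 1) Sanity check: call the step once eagerly so a broken step fails
    #    with a plain error before anything is timed.
    out = toy_step(states, actions)
    assert out.shape == states.shape, (out.shape, states.shape)

    # 2) Compile and run once (warm-up) so compilation is excluded from timing.
    fn = jax.jit(rollout, static_argnums=2)
    fn(states, actions, num_steps).block_until_ready()

    # 3) Timed run; block_until_ready() waits for the device to finish,
    #    otherwise only the asynchronous dispatch would be measured.
    start = time.perf_counter()
    fn(states, actions, num_steps).block_until_ready()
    elapsed = time.perf_counter() - start
    batch = states.shape[0]
    return batch * num_steps / elapsed  # environment steps per second

sps = benchmark(jnp.zeros((2, 18)), jnp.ones((2, 3)))
print(f"{sps:.3g} steps/s")
```

With mismatched inputs such as `jnp.zeros((18,))` for the states, the eager check in step 1 raises before any timing happens, so the benchmark cannot report a throughput number.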
Steps to reproduce
- Run `benchmark.py` with a debugger, setting a breakpoint at line 52
- Try manually running the command on line 52 (`env.v_step(states, actions, key)`) in the debugger
Error
```
File ".../gigastep/gigastep/gigastep_env.py", line 356, in step
    agent_states = v_step(agent_states, actions, self._per_agent_thrust)
...
ValueError: vmap got inconsistent sizes for array axes to be mapped:
  * most axes (2 of them) had size 2, e.g. axis 0 of argument action of type float32[2,3];
  * one axis had size 18: axis 0 of argument state of type float32[18]
```
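The same class of error can be reproduced in isolation. By default `jax.vmap` maps axis 0 of every argument, so an unbatched state of shape `(18,)` passed next to actions of shape `(2, 3)` gives mismatched mapped sizes (18 vs 2). Below is a minimal sketch with a made-up `agent_step`. Only the shapes come from the traceback, and whether gigastep's `agent_states` is actually missing its agent axis is an assumption:

```python
import jax
import jax.numpy as jnp

# Hypothetical per-agent step: one agent's state vector and one agent's action.
# Shapes (18,) and (3,) mirror the traceback; the body itself is made up.
def agent_step(state, action):
    return state + 0.01 * action.sum()

num_agents = 2
actions = jnp.ones((num_agents, 3))  # float32[2,3], batched over agents

# Passing a single, unbatched state reproduces the error: vmap maps axis 0 of
# every argument, sees size 18 for `state` but size 2 for `action`, and fails.
single_state = jnp.zeros((18,))
error_message = ""
try:
    jax.vmap(agent_step)(single_state, actions)
except ValueError as err:
    error_message = str(err)
    print(error_message)  # an "inconsistent sizes for array axes" error

# Fix A: give the state the same leading agent axis as the actions.
states = jnp.zeros((num_agents, 18))
out_a = jax.vmap(agent_step)(states, actions)  # shape (2, 18)

# Fix B: if one state really is shared by all agents, do not map over it.
out_b = jax.vmap(agent_step, in_axes=(None, 0))(single_state, actions)  # (2, 18)
```

Fix A fits if each agent is supposed to have its own row of state. Fix B only applies if a single state is genuinely shared across agents.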