mlech26l/gigastep

benchmark.py silently errors


Hi, first off, thanks for making such a cool environment in JAX. I think this will be great for the community!

I wanted to measure the speed of your environment, so I ran the benchmark.py file with no changes. It reported very fast speeds with a hardcoded policy (~2e9 sps), but on further investigation I found that the env.step function was silently erroring inside the scan. I realise this repo is changing rapidly at the moment, so I was wondering whether this script is meant to be working?

Steps to reproduce

  • Run benchmark.py with a debugger, setting a breakpoint at line 52
  • Try manually running the command on line 52 (env.v_step(states, actions, key)) in the debugger

Error

File ".../gigastep/gigastep/gigastep_env.py", line 356, in step
    agent_states = v_step(agent_states, actions, self._per_agent_thrust)
...
ValueError: vmap got inconsistent sizes for array axes to be mapped:
  * most axes (2 of them) had size 2, e.g. axis 0 of argument action of type float32[2,3];
  * one axis had size 18: axis 0 of argument state of type float32[18]
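
For reference, the same kind of error can be reproduced in isolation. The snippet below is only a minimal sketch, not the gigastep code: `step_one`, the 18-element state layout, and the per-agent `thrust` array are hypothetical stand-ins chosen to match the shapes in the traceback (a `float32[18]` state, `float32[2,3]` actions, and one more argument with a leading axis of size 2). It suggests that the state passed into the vmapped step may be missing its leading per-agent axis, though I can't confirm that is what happens in the repo.

```python
import jax
import jax.numpy as jnp


def step_one(state, action, thrust):
    # Hypothetical per-agent update: add the scaled 3-dim action
    # to the first three entries of an 18-dim state vector.
    return state.at[:3].add(thrust * action)


n_agents = 2
actions = jnp.ones((n_agents, 3))   # float32[2,3], as in the traceback
thrust = jnp.ones((n_agents,))      # assumed per-agent scalar

# Mismatch: the state has no leading agent axis, so vmap sees
# axis 0 of size 18 for `state` but size 2 for the other arguments.
bad_state = jnp.zeros((18,))
try:
    jax.vmap(step_one)(bad_state, actions, thrust)
    raised = False
except ValueError as err:
    raised = True
    print(type(err).__name__)

# Consistent: every mapped argument has a leading axis of size n_agents.
good_states = jnp.zeros((n_agents, 18))
new_states = jax.vmap(step_one)(good_states, actions, thrust)
print(new_states.shape)
```

If a single state were instead meant to be shared across agents, passing `in_axes=(None, 0, 0)` to `jax.vmap` would broadcast it rather than map over it; which of the two is intended here is something only the maintainers can say.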