Issue: vmapped CartPole input shape does not match
DriesSmit opened this issue · 1 comment
DriesSmit commented
Hello there. I am trying to run a vmapped CartPole step function. The fields of my environment state have the following shapes:
env_state:
[executor/0] x: Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=2/0)>
[executor/0] x_dot: Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=2/0)>
[executor/0] theta: Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=2/0)>
[executor/0] theta_dot: Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=2/0)>
When I run jnp.array([env_state.x, env_state.x_dot, env_state.theta, env_state.theta_dot]) on the state before the environment step, I get:
Traced<ShapedArray(float32[4])>with<DynamicJaxprTrace(level=2/0)>
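For context, a minimal sketch of why float32[] fields stack to float32[4] here: inside jax.vmap each traced field is a per-example scalar, so the stack in get_obs sees four 0-d values, and the batch axis only reappears on the vmapped output. State and get_obs below are hypothetical stand-ins that mirror the get_obs line in the traceback; they are not gymnax's own classes.

```python
import jax
import jax.numpy as jnp
from typing import NamedTuple


class State(NamedTuple):
    # Hypothetical stand-in for the CartPole state: four float fields.
    x: jnp.ndarray
    x_dot: jnp.ndarray
    theta: jnp.ndarray
    theta_dot: jnp.ndarray


def get_obs(state: State) -> jnp.ndarray:
    # Same stacking as in the traceback: every field must share one shape.
    return jnp.array([state.x, state.x_dot, state.theta, state.theta_dot])


batch = 8
# Each field has shape (8,); vmap maps over that leading axis.
states = State(*(jnp.zeros(batch) for _ in range(4)))

# Per example, each field is float32[], so each observation is float32[4];
# vmap stacks them into a (batch, 4) array.
per_example = jax.vmap(get_obs)(states)
print(per_example.shape)  # (8, 4)
```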
However, when I run the step function, I get:
obs, env_state, rewards, done, _ = self.env.step(key_step, env_state, action, self.env_params)
[executor/0] File "/mava/lib/python3.8/site-packages/gymnax/environments/environment.py", line 38, in step
[executor/0] obs_st, state_st, reward, done, info = self.step_env(
[executor/0] File "/mava/lib/python3.8/site-packages/gymnax/environments/classic_control/cartpole.py", line 83, in step_env
[executor/0] lax.stop_gradient(self.get_obs(state)),
[executor/0] File "/mava/lib/python3.8/site-packages/gymnax/environments/classic_control/cartpole.py", line 108, in get_obs
[executor/0] return jnp.array([state.x, state.x_dot, state.theta, state.theta_dot])
[executor/0] File "/mava/lib/python3.8/site-packages/jax/_src/numpy/lax_numpy.py", line 1889, in array
[executor/0] out = stack([asarray(elt, dtype=dtype) for elt in object])
[executor/0] File "/mava/lib/python3.8/site-packages/jax/_src/numpy/lax_numpy.py", line 1634, in stack
[executor/0] raise ValueError("All input arrays must have the same shape.")
[executor/0] ValueError: All input arrays must have the same shape.
Do you have any idea what might be causing this issue? Are the shapes somehow changing inside the step function? Thanks.
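One way to check whether a field changes shape is to print the shape of every field before stacking. check_same_shape below is a hypothetical debugging helper, not part of gymnax; the example fields simulate two of the four fields having picked up an extra axis of size 2.

```python
import jax.numpy as jnp


def check_same_shape(fields):
    # Hypothetical helper: collect each field's shape and fail loudly,
    # listing all shapes, if they do not agree.
    shapes = [jnp.shape(f) for f in fields]
    if len(set(shapes)) != 1:
        raise ValueError(f"fields have mismatched shapes: {shapes}")
    return shapes[0]


# Simulated state fields: two scalars and two arrays of shape (2,).
fields = [jnp.float32(0.0), jnp.zeros(2), jnp.float32(0.0), jnp.zeros(2)]

try:
    check_same_shape(fields)
except ValueError as err:
    print(err)  # fields have mismatched shapes: [(), (2,), (), (2,)]

# jnp.array on the same fields also raises a ValueError (exact message may vary by JAX version).
try:
    jnp.array(fields)
except ValueError as err:
    print("jnp.array:", err)
```

Because shapes are static during tracing, the same check also works on tracers inside jax.vmap or jax.jit.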
DriesSmit commented
Apologies, the problem was not with the environment. It was caused by passing an action logit array of size 2 to the environment instead of a single action.
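A simplified sketch of how a size-2 logit array can produce exactly this error, assuming the step follows the classic gym-style CartPole update: a discrete action in {0, 1} is mapped to a force, and an Euler step updates the positions with the old velocities and the velocities with the new accelerations. toy_step, FORCE_MAG, TAU and the placeholder dynamics are illustrative assumptions, not gymnax's source. The shape-(2,) action broadcasts only into x_dot and theta_dot, so get_obs ends up stacking scalars with shape-(2,) arrays.

```python
import jax.numpy as jnp

# Hypothetical constants and simplified dynamics, for illustration only.
FORCE_MAG = 10.0
TAU = 0.02


def toy_step(x, x_dot, theta, theta_dot, action):
    # Discrete action 0/1 mapped to a force of -FORCE_MAG or +FORCE_MAG.
    force = FORCE_MAG * action - FORCE_MAG * (1 - action)
    xacc = force        # placeholder: acceleration proportional to force
    thetaacc = -force   # placeholder: opposite-signed angular acceleration
    # Euler step: positions use the old (scalar) velocities, while the
    # velocities absorb the accelerations, which take the action's shape.
    new_x = x + TAU * x_dot
    new_x_dot = x_dot + TAU * xacc
    new_theta = theta + TAU * theta_dot
    new_theta_dot = theta_dot + TAU * thetaacc
    return new_x, new_x_dot, new_theta, new_theta_dot


zero = jnp.float32(0.0)
logits = jnp.array([0.3, 1.2])  # an action logit array of size 2

bad = toy_step(zero, zero, zero, zero, logits)
print([v.shape for v in bad])  # [(), (2,), (), (2,)]: stacking these fails

# Fix: reduce the logits to one discrete action before stepping.
action = jnp.argmax(logits)  # greedy; jax.random.categorical(key, logits) samples instead
good = toy_step(zero, zero, zero, zero, action)
obs = jnp.array(good)
print(obs.shape)  # (4,)
```

Greedy argmax suits evaluation, while sampling with jax.random.categorical is the usual choice when the policy should keep exploring.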