How to speed up rollout?
jkewang commented
I found a test function in the code you provided:
def test_rollout_result_matches_dynamics_and_reward(self):
    env = _env.MultiAgentEnvironment(
        dynamics_model=dynamics.DeltaGlobal(),
        config=_config.EnvironmentConfig(
            init_steps=2,
            max_num_objects=self.dataset_config.max_num_objects,
        ),
    )

    def _expert_action_fn(state, obs, rng):
        del obs, rng
        prev_sim_traj = datatypes.dynamic_slice(
            state.sim_trajectory, state.timestep, 1, axis=-1
        )
        logged_next_traj = datatypes.dynamic_slice(
            state.log_trajectory, state.timestep + 1, 1, axis=-1
        )
        combined_traj = jax.tree_util.tree_map(
            lambda x, y: jnp.concatenate([x, y], axis=-1),
            prev_sim_traj,
            logged_next_traj,
        )
        return env.dynamics.inverse(
            combined_traj, metadata=state.object_metadata, timestep=0
        )

    result = _env.rollout(
        self.state_t0,
        expert.create_expert_actor(env.dynamics),
        env,
        rng=jax.random.PRNGKey(0),
        rollout_num_steps=2,
    )
    ......
I measured the speed of rollout and found it very slow. I suspect this is because the rollout function is not compiled with jax.jit. However, when I use jax_rollout, I get the following error:
2023-12-18 19:58:21.803233: I external/xla/xla/stream_executor/cuda/cuda_dnn.cc:440] Loaded cuDNN version 8904
2023-12-18 19:58:21.885778: I external/xla/xla/stream_executor/cuda/cuda_asm_compiler.cc:52] Using nvlink for parallel linking
Traceback (most recent call last):
File "/home/conda/envs/waymo/lib/python3.10/site-packages/jax/_src/api_util.py", line 581, in shaped_abstractify
return _shaped_abstractify_handlers[type(x)](x)
KeyError: <class 'waymax.env.base_environment.BaseEnvironment'>
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "/home/Desktop/workspaces/code/waymax/jingke/single_agent/rollout.py", line 85, in <module>
rollout_output = jax_rollout(init_state, expert_actor, env, rng=jax.random.PRNGKey(0), rollout_num_steps=5)
File "/home/conda/envs/waymo/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 166, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/home/conda/envs/waymo/lib/python3.10/site-packages/jax/_src/pjit.py", line 253, in cache_miss
outs, out_flat, out_tree, args_flat, jaxpr = _python_pjit_helper(
File "/home/conda/envs/waymo/lib/python3.10/site-packages/jax/_src/pjit.py", line 161, in _python_pjit_helper
args_flat, _, params, in_tree, out_tree, _ = infer_params_fn(
File "/home/conda/envs/waymo/lib/python3.10/site-packages/jax/_src/api.py", line 324, in infer_params
return pjit.common_infer_params(pjit_info_args, *args, **kwargs)
File "/home/conda/envs/waymo/lib/python3.10/site-packages/jax/_src/pjit.py", line 477, in common_infer_params
avals.append(shaped_abstractify(a))
File "/home/conda/envs/waymo/lib/python3.10/site-packages/jax/_src/api_util.py", line 583, in shaped_abstractify
return _shaped_abstractify_slow(x)
File "/home/conda/envs/waymo/lib/python3.10/site-packages/jax/_src/api_util.py", line 572, in _shaped_abstractify_slow
raise TypeError(
jax._src.traceback_util.UnfilteredStackTrace: TypeError: Cannot interpret value of type <class 'waymax.env.base_environment.BaseEnvironment'> as an abstract array; it does not have a dtype attribute
The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.
--------------------
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/home/Desktop/workspaces/code/waymax/jingke/single_agent/rollout.py", line 85, in <module>
rollout_output = jax_rollout(init_state, expert_actor, env, rng=jax.random.PRNGKey(0), rollout_num_steps=5)
TypeError: Cannot interpret value of type <class 'waymax.env.base_environment.BaseEnvironment'> as an abstract array; it does not have a dtype attribute
2023-12-18 19:58:22.882393: I external/xla/xla/pjrt/tfrt_cpu_pjrt_client.cc:465] TfrtCpuClient destroyed.
Process finished with exit code 1
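The same error can be reproduced without Waymax. In the minimal sketch below, ToyEnv and toy_rollout are hypothetical stand-ins for BaseEnvironment and the rollout function. Passing the object as an ordinary argument to a jitted function makes JAX try to treat it as an abstract array, which raises a TypeError like the one above (the exact wording varies between JAX versions).

```python
import jax
import jax.numpy as jnp


class ToyEnv:
    """Hypothetical stand-in for a non-array object such as BaseEnvironment."""

    def __init__(self, scale: float):
        self.scale = scale

    def step(self, x):
        return x * self.scale


def toy_rollout(x, env):
    # env is an arbitrary Python object, not an array or a pytree of arrays.
    return env.step(x)


def jit_with_env_argument():
    """Pass the environment as a traced argument; jax.jit rejects it."""
    try:
        jax.jit(toy_rollout)(jnp.ones(3), ToyEnv(2.0))
    except TypeError as err:
        return str(err)
    return None


message = jit_with_env_argument()
print(message)
```

Calling toy_rollout without jax.jit works fine; the failure only appears when JAX has to abstract every argument into a shaped array for tracing.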
Is my understanding of rollout wrong? Is there a more efficient way to run a rollout?
justinjfu commented
All arguments to a jitted function need to be JAX arrays (or structures of arrays). The general fix would be to use functools.partial to fill in the non-JAX arguments first (they will be statically embedded into the computation graph), then use jax.jit on the remaining arguments.
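A minimal sketch of that pattern, using hypothetical stand-ins (ToyEnv, toy_rollout) rather than the real Waymax API: the environment object and the step count are bound with functools.partial, so only the array state and the PRNG key are traced. The step count is used as a loop length (here via jax.lax.scan over split keys), which is one reason it has to be a Python constant rather than a traced value.

```python
import functools

import jax
import jax.numpy as jnp


class ToyEnv:
    """Hypothetical stand-in for an environment object (not a JAX array)."""

    def __init__(self, scale: float):
        self.scale = scale

    def step(self, state, rng):
        noise = jax.random.normal(rng, state.shape)
        return state * self.scale + 0.01 * noise


def toy_rollout(state, env, rng, num_steps):
    """Hypothetical rollout: apply env.step num_steps times using lax.scan."""
    rngs = jax.random.split(rng, num_steps)  # needs a static Python int

    def body(carry, step_rng):
        next_state = env.step(carry, step_rng)
        return next_state, next_state  # (new carry, per-step output)

    final_state, trajectory = jax.lax.scan(body, state, rngs)
    return final_state, trajectory


# Bind the non-array arguments first; they become compile-time constants.
rollout_fn = functools.partial(toy_rollout, env=ToyEnv(0.5), num_steps=4)
jitted_rollout = jax.jit(rollout_fn)

# Only the array state and the PRNG key are traced by jax.jit.
final_state, trajectory = jitted_rollout(jnp.ones(3), rng=jax.random.PRNGKey(0))
print(trajectory.shape)  # (4, 3)
```

Because num_steps is baked into the compiled function, using a different step count means building a new partial, which is compiled separately.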
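In your case, you would need to do something like the following. Note that rng is passed by keyword: actor and env are already bound by the partial, so passing the key as a second positional argument would collide with actor.
import functools
import jax
rollout_fn = functools.partial(_env.rollout, actor=expert_actor, env=env, rollout_num_steps=2)
result = jax.jit(rollout_fn)(init_state, rng=jax.random.PRNGKey(0))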
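When checking whether the jitted version is actually faster, keep in mind that the first call includes tracing and XLA compilation, and that JAX dispatches work asynchronously. A fair timing therefore calls jax.block_until_ready and looks at calls after the first one. The sketch below uses a hypothetical toy_rollout in place of _env.rollout, and the helper timed_call is likewise illustrative; the printed numbers depend entirely on the machine.

```python
import functools
import time

import jax
import jax.numpy as jnp


def toy_rollout(state, rng, scale, num_steps):
    """Hypothetical rollout standing in for _env.rollout."""

    def body(carry, step_rng):
        nxt = carry * scale + 0.01 * jax.random.normal(step_rng, carry.shape)
        return nxt, nxt

    return jax.lax.scan(body, state, jax.random.split(rng, num_steps))


# Build the jitted function once and reuse it.
rollout_fn = jax.jit(functools.partial(toy_rollout, scale=0.5, num_steps=80))
state = jnp.ones(32)


def timed_call(key):
    start = time.perf_counter()
    out = rollout_fn(state, key)
    jax.block_until_ready(out)  # wait for the asynchronous computation
    return time.perf_counter() - start, out


compile_time, _ = timed_call(jax.random.PRNGKey(0))  # traces and compiles
run_time, (final_state, trajectory) = timed_call(jax.random.PRNGKey(1))
print(f"first call: {compile_time:.4f}s, later call: {run_time:.4f}s")
```

Creating rollout_fn once, outside any loop, means the compilation cost is paid on the first call only; later calls with the same input shapes and dtypes reuse the compiled computation.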