waymo-research/waymax

How to speed up rollout?


I found a test function in the code you provided:

    def test_rollout_result_matches_dynamics_and_reward(self):
        env = _env.MultiAgentEnvironment(
            dynamics_model=dynamics.DeltaGlobal(),
            config=_config.EnvironmentConfig(
                init_steps=2,
                max_num_objects=self.dataset_config.max_num_objects,
            ),
        )

        def _expert_action_fn(state, obs, rng):
            del obs, rng
            prev_sim_traj = datatypes.dynamic_slice(
                state.sim_trajectory, state.timestep, 1, axis=-1
            )
            logged_next_traj = datatypes.dynamic_slice(
                state.log_trajectory, state.timestep + 1, 1, axis=-1
            )
            combined_traj = jax.tree_map(
                lambda x, y: jnp.concatenate([x, y], axis=-1),
                prev_sim_traj,
                logged_next_traj,
            )
            return env.dynamics.inverse(
                combined_traj, metadata=state.object_metadata, timestep=0
            )

        result = _env.rollout(
            self.state_t0,
            expert.create_expert_actor(env.dynamics),
            env,
            rng=jax.random.PRNGKey(0),
            rollout_num_steps=2,
        )
        ......

I timed rollout and found it very slow. I suspect this is because rollout is not compiled with jax.jit, but when I call jax_rollout (rollout wrapped in jax.jit) instead, I get this error:

    2023-12-18 19:58:21.803233: I external/xla/xla/stream_executor/cuda/cuda_dnn.cc:440] Loaded cuDNN version 8904
    2023-12-18 19:58:21.885778: I external/xla/xla/stream_executor/cuda/cuda_asm_compiler.cc:52] Using nvlink for parallel linking
    Traceback (most recent call last):
      File "/home/conda/envs/waymo/lib/python3.10/site-packages/jax/_src/api_util.py", line 581, in shaped_abstractify
        return _shaped_abstractify_handlers[type(x)](x)
    KeyError: <class 'waymax.env.base_environment.BaseEnvironment'>

    During handling of the above exception, another exception occurred:

    Traceback (most recent call last):
      File "/home/Desktop/workspaces/code/waymax/jingke/single_agent/rollout.py", line 85, in <module>
        rollout_output = jax_rollout(init_state, expert_actor, env, rng=jax.random.PRNGKey(0), rollout_num_steps=5)
      File "/home/conda/envs/waymo/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 166, in reraise_with_filtered_traceback
        return fun(*args, **kwargs)
      File "/home/conda/envs/waymo/lib/python3.10/site-packages/jax/_src/pjit.py", line 253, in cache_miss
        outs, out_flat, out_tree, args_flat, jaxpr = _python_pjit_helper(
      File "/home/conda/envs/waymo/lib/python3.10/site-packages/jax/_src/pjit.py", line 161, in _python_pjit_helper
        args_flat, _, params, in_tree, out_tree, _ = infer_params_fn(
      File "/home/conda/envs/waymo/lib/python3.10/site-packages/jax/_src/api.py", line 324, in infer_params
        return pjit.common_infer_params(pjit_info_args, *args, **kwargs)
      File "/home/conda/envs/waymo/lib/python3.10/site-packages/jax/_src/pjit.py", line 477, in common_infer_params
        avals.append(shaped_abstractify(a))
      File "/home/conda/envs/waymo/lib/python3.10/site-packages/jax/_src/api_util.py", line 583, in shaped_abstractify
        return _shaped_abstractify_slow(x)
      File "/home/conda/envs/waymo/lib/python3.10/site-packages/jax/_src/api_util.py", line 572, in _shaped_abstractify_slow
        raise TypeError(
    jax._src.traceback_util.UnfilteredStackTrace: TypeError: Cannot interpret value of type <class 'waymax.env.base_environment.BaseEnvironment'> as an abstract array; it does not have a dtype attribute

    The stack trace below excludes JAX-internal frames.
    The preceding is the original exception that occurred, unmodified.

    --------------------

    The above exception was the direct cause of the following exception:

    Traceback (most recent call last):
      File "/home/Desktop/workspaces/code/waymax/jingke/single_agent/rollout.py", line 85, in <module>
        rollout_output = jax_rollout(init_state, expert_actor, env, rng=jax.random.PRNGKey(0), rollout_num_steps=5)
    TypeError: Cannot interpret value of type <class 'waymax.env.base_environment.BaseEnvironment'> as an abstract array; it does not have a dtype attribute
    2023-12-18 19:58:22.882393: I external/xla/xla/pjrt/tfrt_cpu_pjrt_client.cc:465] TfrtCpuClient destroyed.

    Process finished with exit code 1

Is my understanding of rollout wrong? Is there a more efficient way to run a rollout?

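Every argument to a jitted function must be a JAX array (or a structure of arrays) unless it is marked static. The environment and the actor are ordinary Python objects, so jax.jit cannot turn them into abstract arrays, which is exactly the TypeError in your log. The general fix is to bind the non-array arguments first with functools.partial (they are then embedded into the traced computation as constants), and apply jax.jit to the function of the remaining array arguments.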
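The failure and the fix can be reproduced without Waymax. The sketch below is a minimal illustration, assuming hypothetical stand-ins (ToyEnv, ToyActor and toy_rollout are not Waymax APIs) whose argument order only mirrors the rollout call in the question. Jitting the whole function raises a TypeError on the Python objects; binding them with functools.partial, or declaring them static via jax.jit's static_argnames, compiles fine.

```python
import functools

import jax
import jax.numpy as jnp


class ToyEnv:
    """Hypothetical stand-in for an environment: a plain Python object, not an array."""

    def __init__(self, dt: float):
        self.dt = dt

    def step(self, state, action):
        # Simple integrator: position += velocity * dt.
        return state + action * self.dt


class ToyActor:
    """Hypothetical stand-in for an actor: unit velocity plus small Gaussian noise."""

    def select_action(self, state, rng):
        return jnp.ones_like(state) + 0.01 * jax.random.normal(rng, state.shape)


def toy_rollout(state, actor, env, rng, rollout_num_steps=1):
    """Hypothetical rollout with the same argument order as the call in the question."""

    def body(_, carry):
        s, key = carry
        key, step_key = jax.random.split(key)
        action = actor.select_action(s, step_key)
        return env.step(s, action), key

    final_state, _ = jax.lax.fori_loop(0, rollout_num_steps, body, (state, rng))
    return final_state


env = ToyEnv(dt=0.1)
actor = ToyActor()
state0 = jnp.zeros(3)
rng = jax.random.PRNGKey(0)

# Fails: jit tries to turn every argument, including `actor` and `env`, into an abstract array.
jit_error = None
try:
    jax.jit(toy_rollout)(state0, actor, env, rng=rng, rollout_num_steps=5)
except TypeError as e:
    jit_error = e
    print("TypeError:", e)

# Fix 1: bind the non-array arguments with functools.partial, then jit the rest.
# rng is passed by keyword so it does not land in the already-bound `actor` slot.
rollout_fn = jax.jit(
    functools.partial(toy_rollout, actor=actor, env=env, rollout_num_steps=5)
)
final_state = rollout_fn(state0, rng=rng)

# Fix 2: mark the non-array arguments static (they must be hashable).
jit_rollout = jax.jit(toy_rollout, static_argnames=("actor", "env", "rollout_num_steps"))
final_state_2 = jit_rollout(state0, actor, env, rng, rollout_num_steps=5)

print(final_state, final_state_2)
```

Both fixes produce the same program. With static_argnames the objects must be hashable; plain class instances hash by identity, so passing a new but equivalent object triggers a fresh compilation. Whether Waymax's environment and actor objects make good static arguments is not verified here.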

In your case, you would need to do something like:

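    import functools
    rollout_fn = functools.partial(_env.rollout, actor=expert_actor, env=env, rollout_num_steps=2)
    # Pass rng by keyword: a second positional argument would collide with the bound actor.
    result = jax.jit(rollout_fn)(state, rng=rng_key)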
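Two details affect how much faster the jitted rollout looks. The first call traces and compiles; later calls with the same argument shapes and dtypes reuse the compiled program, so the jitted function should be built once, outside any loop, and the first call excluded from timing. JAX also dispatches work asynchronously, so jax.block_until_ready is needed before stopping a timer. The sketch below illustrates both points under stated assumptions: toy_rollout and step_size are hypothetical, not Waymax code, and a Python counter is used to count traces because Python side effects only run while JAX is tracing.

```python
import functools
import time

import jax
import jax.numpy as jnp

trace_count = 0  # incremented only when JAX (re)traces toy_rollout


def toy_rollout(state, rng, step_size, rollout_num_steps):
    """Hypothetical toy rollout: integrates a noisy unit velocity for N steps."""
    global trace_count
    trace_count += 1  # Python side effect: runs at trace time, not on cached calls

    def body(_, carry):
        s, key = carry
        key, sub = jax.random.split(key)
        velocity = 1.0 + 0.01 * jax.random.normal(sub, s.shape)
        return s + step_size * velocity, key

    final_state, _ = jax.lax.fori_loop(0, rollout_num_steps, body, (state, rng))
    return final_state


# Build the jitted function once, outside any loop over scenarios or episodes.
rollout_fn = jax.jit(
    functools.partial(toy_rollout, step_size=0.1, rollout_num_steps=5)
)

state0 = jnp.zeros(3)
keys = jax.random.split(jax.random.PRNGKey(0), 4)

# First call: trace + compile + run.
start = time.perf_counter()
jax.block_until_ready(rollout_fn(state0, keys[0]))
first_call_s = time.perf_counter() - start

# Later calls with the same shapes/dtypes reuse the compiled program.
start = time.perf_counter()
outputs = [rollout_fn(state0, k) for k in keys[1:]]
jax.block_until_ready(outputs)
cached_calls_s = time.perf_counter() - start

print(f"first call: {first_call_s:.4f}s, 3 cached calls: {cached_calls_s:.4f}s")
print("traces:", trace_count)
```

If jax.jit(functools.partial(...)) were instead re-created inside the loop, each new partial object would be a different function to JAX and could be traced and compiled again on every iteration.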