waymo-research/waymax

Questions about using distributed simulation

jkewang opened this issue

As a beginner, I have had a lot of trouble modifying the parallel simulation process. When I set distributed=True, I do receive a SimulationState whose leaves have a leading shape of (gpus, batch, ...). However, because SimulationState is a deeply nested pytree, it is difficult to use pmap and vmap to split the state along these dimensions, which makes any custom function hard to parallelize.
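To make the layout concrete: a pytree is a nested container whose leaves are arrays, so "shape (gpus, batch, ...)" means every leaf carries those two leading axes. The sketch below uses a hypothetical, much smaller `ToyState` NamedTuple (not Waymax's actual SimulationState) and `jax.tree_util.tree_leaves` to check the leading axes of every leaf at once.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class ToyState(NamedTuple):
  """Hypothetical stand-in for SimulationState; NamedTuples are pytrees."""
  xy: jax.Array        # (gpus, batch, num_objects, 2)
  valid: jax.Array     # (gpus, batch, num_objects)
  timestep: jax.Array  # (gpus, batch)


def make_toy_state(gpus: int, batch: int, num_objects: int = 4) -> ToyState:
  return ToyState(
      xy=jnp.zeros((gpus, batch, num_objects, 2)),
      valid=jnp.ones((gpus, batch, num_objects), dtype=bool),
      timestep=jnp.zeros((gpus, batch), dtype=jnp.int32),
  )


state = make_toy_state(gpus=2, batch=3)
# Gather the two leading dims of every leaf in a single pass over the tree.
leading = {x.shape[:2] for x in jax.tree_util.tree_leaves(state)}
print(leading)  # -> {(2, 3)}
```

Because every leaf shares the same leading axes, any per-leaf operation (indexing, reshaping) can be written once and applied to the whole tree.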

For example, here are some of my attempts:

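import jax

jax_my_func = ...  # per-example function (elided)
vmap_jax_myfunc = jax.vmap(jax_my_func, in_axes=(None, 1))
# Outer pmap over the device axis; inner vmap within each device.
pmap_jax_myfunc = jax.pmap(vmap_jax_myfunc, in_axes=(0, 0), axis_name='devices')
states, actions = ...  # leaves shaped (gpus, batch, ...)
reward = pmap_jax_myfunc(states, actions)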

Then the pmap call fails, complaining that the type of states is not supported by pmap.
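For comparison, pmap and vmap do accept arbitrary pytrees whose leaves are arrays: an integer `in_axes` entry applies to every leaf of the corresponding argument. The sketch below nests them over a hypothetical `ToyState` and a hypothetical `toy_reward` function (neither is Waymax code). Unlike the attempt above, it uses `in_axes=(0, 0)` for the inner vmap, so that after pmap strips the device axis both the state and the actions are mapped over their batch axis. If the same pattern fails on the real state, one thing worth checking (an assumption, not verified against Waymax) is whether every leaf of the state is actually an array rather than some other Python object.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class ToyState(NamedTuple):
  xy: jax.Array     # (num_objects, 2) per example
  valid: jax.Array  # (num_objects,) per example


def toy_reward(state: ToyState, action: jax.Array) -> jax.Array:
  """Hypothetical per-example reward: mean norm of valid displaced positions."""
  moved = jnp.linalg.norm(state.xy + action, axis=-1)
  total = jnp.sum(jnp.where(state.valid, moved, 0.0))
  return total / jnp.maximum(state.valid.sum(), 1)


# vmap maps axis 0 of every leaf of `state` and of `action` (the batch axis);
# pmap then maps the leading device axis.
batched = jax.vmap(toy_reward, in_axes=(0, 0))
parallel = jax.pmap(batched, in_axes=(0, 0), axis_name='devices')

ndev, batch, num_objects = jax.local_device_count(), 3, 4
states = ToyState(
    xy=jnp.ones((ndev, batch, num_objects, 2)),
    valid=jnp.ones((ndev, batch, num_objects), dtype=bool),
)
actions = jnp.zeros((ndev, batch, num_objects, 2))
reward = parallel(states, actions)
print(reward.shape)  # shape: (local_device_count, 3)
```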

I also tried:

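import flax.jax_utils

# Copy the full states and actions onto every local device (adds a leading device axis).
replicate_states = flax.jax_utils.replicate(states)
replicate_actions = flax.jax_utils.replicate(actions)
reward = pmap_jax_myfunc(replicate_states, replicate_actions)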

This runs under both pmap and vmap, but since every device receives a full copy of the state, I then have to manually select each device's slice inside the inner function, which is also hard to do for such a complex pytree.
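Selecting each device's slice does not necessarily have to be written field by field: inside pmap, `jax.lax.axis_index('devices')` gives the current device's position along the mapped axis, and one `jax.tree_util.tree_map` can index every leaf with it. The sketch below is an illustration, not an official Waymax recipe. It uses a hypothetical dict-based toy state and emulates the replicated layout with `jnp.broadcast_to` instead of calling flax. Note that replicating and then slicing still copies the full state to every device.

```python
import jax
import jax.numpy as jnp


def take_device_slice(tree, axis_name: str):
  """Index axis 0 of every leaf with this device's position on `axis_name`."""
  i = jax.lax.axis_index(axis_name)
  return jax.tree_util.tree_map(lambda x: x[i], tree)


def per_device_fn(replicated_state):
  # Leaves arrive as (gpus, batch, ...): the full copy held by this device.
  local = take_device_slice(replicated_state, 'devices')  # leaves: (batch, ...)
  # Hypothetical per-device computation: sum positions per batch element.
  return jnp.sum(local['xy'], axis=(1, 2))


ndev, batch = jax.local_device_count(), 3
# Toy (gpus, batch, num_objects, 2) state whose values equal the device index.
state = {
    'xy': jnp.arange(ndev, dtype=jnp.float32)[:, None, None, None]
    * jnp.ones((ndev, batch, 4, 2))
}
# Emulate replication: add a leading axis of size ndev to every leaf.
replicated = jax.tree_util.tree_map(
    lambda x: jnp.broadcast_to(x, (ndev,) + x.shape), state)

out = jax.pmap(per_device_fn, axis_name='devices')(replicated)
print(out.shape)  # shape: (local_device_count, 3)
```

Each device d ends up summing only its own slice (eight entries equal to d per batch element), so `out[d]` is `8 * d`.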

Are there any official examples or tips for distributed simulation? Thank you so much!