araffin/sbx

Mujoco XLA - MJX Integration

matinmoezzi opened this issue · 1 comment

Since the biggest bottleneck in SB3's training performance is the environment, I am considering integrating SB3 with MuJoCo XLA (MJX), which is MuJoCo written in JAX. Would this integration improve performance? MJX was recently released with large performance improvements alongside Brax, which includes RL algorithms written in JAX. Is SBX fully written in JAX?

Hello,

> the biggest bottleneck in SB3's training performance is the environment

I would actually disagree with this statement.
The main reason SBX is much faster than the PyTorch-based SB3 is that the bottleneck was the gradient update, not the environment.

> Would this integration improve performance?

It might, but first you need to be sure where the bottleneck actually is, and that you have tuned SBX's hyperparameters, before investing in a faster env.
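One way to check where the time goes is to time the environment stepping and the gradient updates separately. Below is a minimal, self-contained sketch using stand-in functions (`step_env` and `gradient_update` are hypothetical placeholders, not SBX code; in practice you would wrap your real `env.step()` and `agent.train()` calls):

```python
import time

def step_env():
    # Hypothetical stand-in for env.step(); replace with your real env.
    time.sleep(0.0001)

def gradient_update():
    # Hypothetical stand-in for one gradient step; replace with your real update.
    time.sleep(0.001)

def profile(n_steps=100):
    """Accumulate wall-clock time spent in env steps vs. gradient updates."""
    env_time = 0.0
    update_time = 0.0
    for _ in range(n_steps):
        t0 = time.perf_counter()
        step_env()
        env_time += time.perf_counter() - t0

        t0 = time.perf_counter()
        gradient_update()
        update_time += time.perf_counter() - t0
    return env_time, update_time

env_time, update_time = profile()
print(f"env: {env_time:.3f}s, updates: {update_time:.3f}s")
```

If the update time dominates, a faster simulator like MJX will not help much; if the env time dominates, it might.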

> Is SBX fully written in JAX?

It is not: it still uses NumPy/PyTorch for the rollout/replay buffer.
Only the gradient updates are written in JAX.
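To illustrate the hybrid design, here is a minimal NumPy replay buffer of the kind that lives on the host and feeds batches to the (JAX-jitted) update. This is an illustrative sketch, not SBX's actual buffer implementation (SBX reuses SB3's buffer classes), and all names here are made up:

```python
import numpy as np

class ReplayBuffer:
    """Minimal circular NumPy replay buffer (illustrative only)."""

    def __init__(self, capacity, obs_dim):
        self.capacity = capacity
        self.pos = 0
        self.full = False
        self.obs = np.zeros((capacity, obs_dim), dtype=np.float32)
        self.actions = np.zeros((capacity, 1), dtype=np.float32)
        self.rewards = np.zeros((capacity,), dtype=np.float32)

    def add(self, obs, action, reward):
        # Overwrite the oldest entry once the buffer wraps around.
        self.obs[self.pos] = obs
        self.actions[self.pos] = action
        self.rewards[self.pos] = reward
        self.pos = (self.pos + 1) % self.capacity
        if self.pos == 0:
            self.full = True

    def sample(self, batch_size, rng):
        # Uniformly sample a batch; these NumPy arrays would then be
        # handed to the jitted JAX update function.
        upper = self.capacity if self.full else self.pos
        idx = rng.integers(0, upper, size=batch_size)
        return self.obs[idx], self.actions[idx], self.rewards[idx]

buf = ReplayBuffer(capacity=1000, obs_dim=4)
rng = np.random.default_rng(0)
for _ in range(50):
    buf.add(rng.normal(size=4), rng.normal(), rng.normal())
obs_b, act_b, rew_b = buf.sample(32, rng)
print(obs_b.shape, act_b.shape, rew_b.shape)  # → (32, 4) (32, 1) (32,)
```

Keeping the buffer in NumPy avoids device transfers on every `add`, while JAX only sees the sampled batches during the gradient update.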