[Enhancement] Support for large gradient_steps in SAC
LabChameleon opened this issue
### Description
Using the Jax implementation of SAC with larger values of `gradient_steps` (e.g. 1000) is very slow to compile. Consider
Lines 333 to 352 in b8dbac1
I think the problem lies in unrolling the loop over that many gradient steps: `jax.jit` traces the Python loop body once per iteration, so the compiled program grows with `gradient_steps`. Removing line 334, so that the function is not jitted, avoids the problem.
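To illustrate the unrolling behavior with a toy example (this is a hypothetical sketch, not the sbx code): a Python `for` loop inside a jitted function runs at trace time, so XLA receives one copy of the update per iteration and compile time grows with the step count.

```python
import jax
import jax.numpy as jnp


@jax.jit
def train_unrolled(params):
    # The Python loop executes during tracing, so the compiled
    # program contains 1000 copies of this update; compilation
    # time grows roughly linearly with the number of steps.
    for _ in range(1000):
        params = params - 0.01 * jax.grad(lambda p: jnp.sum(p ** 2))(params)
    return params
```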
### To Reproduce
```python
from sbx import SAC
import gymnasium as gym

env = gym.make('Pendulum-v1')
# With gradient_steps=1000, compiling the jitted training function
# takes very long before learning starts.
model = SAC('MlpPolicy', env, verbose=1, gradient_steps=1000)
model.learn(100000)
```
### Expected behavior
Compilation should be fast, even for large values of `gradient_steps`.
### Potential Fix
I adjusted the implementation by moving all computations in the loop body of `SAC._train` into a new jitted function `gradient_step`. Calling this function from a JAX `fori_loop` solves the issue and compiles almost instantly. If you agree with this approach, I would propose a PR with my solution.
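For reference, here is a minimal sketch of the idea under stated assumptions: the names `loss`, `gradient_step`, and `train` are hypothetical, and the toy regression update stands in for the actor/critic updates in `SAC._train`. The point is that `jax.lax.fori_loop` keeps the loop inside XLA, so the step body is traced once regardless of the number of steps.

```python
import jax
import jax.numpy as jnp


def loss(params, batch):
    # Toy regression objective standing in for SAC's losses.
    preds = batch["x"] @ params
    return jnp.mean((preds - batch["y"]) ** 2)


def gradient_step(params, batch):
    # One gradient step; in sbx this would update the actor,
    # critics, entropy coefficient, and target networks.
    grads = jax.grad(loss)(params, batch)
    return params - 1e-3 * grads


@jax.jit
def train(params, batch, n_steps):
    # fori_loop keeps the loop inside XLA: the body is traced once,
    # so compilation time no longer grows with n_steps.
    def body(i, carry):
        return gradient_step(carry, batch)

    return jax.lax.fori_loop(0, n_steps, body, params)


params = jnp.zeros(3)
batch = {
    "x": jax.random.normal(jax.random.PRNGKey(0), (32, 3)),
    "y": jnp.ones(32),
}
params = train(params, batch, 1000)  # compiles almost instantly
```

In the real implementation each step would use a different minibatch, e.g. by pre-sampling `n_steps` batches from the replay buffer and indexing into them with `i` inside the loop body.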
### System Info
- Describe how the library was installed (pip, docker, source, ...): pip
- sbx-rl version: 0.7.0
- Python version: 3.11
- Jax version: 0.4.14
- Gymnasium version: 0.29
### Checklist
- [x] I have checked that there is no similar issue in the repo (required)
- [x] I have read the documentation (required)
- [x] I have provided a minimal working example to reproduce the bug (required)
Hello,
this is actually a known issue...
I tried in the past to replace it with something similar to what DQN uses (https://github.com/araffin/sbx/blob/master/sbx/dqn/dqn.py#L162), but I didn't manage to get everything working as before (including the speed of the training loop once compiled).
However, if you managed to get both fast compilation time and fast runtime, I would be happy to receive a PR for it =)
Hi,
thanks for your reply! I was not aware that you already knew about the issue. I will take another in-depth look and see whether my implementation actually offers any improvement over your existing approach. If it does, I would be happy to make a PR :)