WARNING: Rljax is currently in beta and is being actively improved. Any contributions are welcome :)
Rljax is a collection of RL algorithms written in JAX.
You can install the dependencies simply by executing the following.

```bash
pip install -r requirements.txt
pip install -e .
```

If you want to use environments from realworldrl_suite, additionally install it as follows.

```bash
git clone https://github.com/google-research/realworldrl_suite.git
cd realworldrl_suite
pip install -e .
```
To use GPUs or TPUs, follow the JAX installation instructions for your accelerator.
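Once JAX is set up, you can check which backend it picked up with a quick sanity check like the one below (plain JAX calls, not part of Rljax).

```python
import jax

# Prints "gpu" or "tpu" if an accelerated jaxlib is installed, otherwise "cpu".
print(jax.default_backend())
print(jax.devices())
```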
Currently, the following algorithms have been implemented.
Algorithm | Action | Vector State | Pixel State | PER[11] | D2RL[15] |
---|---|---|---|---|---|
PPO[1] | Continuous | ✔️ | - | - | - |
DDPG[2] | Continuous | ✔️ | - | ✔️ | ✔️ |
TD3[3] | Continuous | ✔️ | - | ✔️ | ✔️ |
SAC[4,5] | Continuous | ✔️ | - | ✔️ | ✔️ |
SAC+DisCor[12] | Continuous | ✔️ | - | - | ✔️ |
TQC[16] | Continuous | ✔️ | - | ✔️ | ✔️ |
SAC+AE[13] | Continuous | - | ✔️ | ✔️ | ✔️ |
SLAC[14] | Continuous | - | ✔️ | - | ✔️ |
DQN[6] | Discrete | ✔️ | ✔️ | ✔️ | - |
QR-DQN[7] | Discrete | ✔️ | ✔️ | ✔️ | - |
IQN[8] | Discrete | ✔️ | ✔️ | ✔️ | - |
FQF[9] | Discrete | ✔️ | ✔️ | ✔️ | - |
SAC-Discrete[10] | Discrete | ✔️ | ✔️ | ✔️ | - |
All algorithms can be trained in a few lines of code.
Getting started
Here is a quick example of how to train DQN on CartPole-v0.
```python
import gym

from rljax.algorithm import DQN
from rljax.trainer import Trainer

NUM_AGENT_STEPS = 20000
SEED = 0

# Separate environments for training and evaluation.
env = gym.make("CartPole-v0")
env_test = gym.make("CartPole-v0")

algo = DQN(
    num_agent_steps=NUM_AGENT_STEPS,
    state_space=env.observation_space,
    action_space=env.action_space,
    seed=SEED,
    batch_size=256,
    start_steps=1000,
    update_interval=1,
    update_interval_target=400,
    eps_decay_steps=0,
    loss_type="l2",
    lr=1e-3,
)

# The trainer evaluates the agent every eval_interval steps and logs to log_dir.
trainer = Trainer(
    env=env,
    env_test=env_test,
    algo=algo,
    log_dir="/tmp/rljax/dqn",
    num_agent_steps=NUM_AGENT_STEPS,
    eval_interval=1000,
    seed=SEED,
)
trainer.train()
```
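Continuous-control algorithms follow the same pattern. Below is a minimal sketch for SAC on a Gym MuJoCo task, assuming SAC accepts the same core constructor arguments (num_agent_steps, state_space, action_space, seed) as DQN above; the step count and eval_interval are illustrative, and all other hyperparameters are left at their defaults.

```python
import gym

from rljax.algorithm import SAC
from rljax.trainer import Trainer

NUM_AGENT_STEPS = 3 * 10 ** 6  # illustrative; matches the 3M-step benchmarks below
SEED = 0

env = gym.make("HalfCheetah-v3")
env_test = gym.make("HalfCheetah-v3")

# Core arguments only; algorithm-specific hyperparameters keep their defaults.
algo = SAC(
    num_agent_steps=NUM_AGENT_STEPS,
    state_space=env.observation_space,
    action_space=env.action_space,
    seed=SEED,
)

trainer = Trainer(
    env=env,
    env_test=env_test,
    algo=algo,
    log_dir="/tmp/rljax/sac",
    num_agent_steps=NUM_AGENT_STEPS,
    eval_interval=10000,
    seed=SEED,
)
trainer.train()
```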
MuJoCo (Gym)
I benchmarked my implementations in some environments from MuJoCo's -v3 task suite, following Spinning Up's benchmarks (code). For TQC, I set num_quantiles_to_drop to 0 for HalfCheetah-v3 and 2 for the other environments. Note that I benchmarked with 3M agent steps, not the 5M agent steps used in TQC's paper.
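For reference, the TQC setting above corresponds to passing num_quantiles_to_drop at construction time, roughly as sketched below (assuming it is exposed as a constructor argument alongside the same core arguments as the other algorithms).

```python
import gym

from rljax.algorithm import TQC

env = gym.make("HalfCheetah-v3")

# num_quantiles_to_drop=0 for HalfCheetah-v3, 2 for the other environments.
algo = TQC(
    num_agent_steps=3 * 10 ** 6,
    state_space=env.observation_space,
    action_space=env.action_space,
    seed=0,
    num_quantiles_to_drop=0,
)
```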
DeepMind Control Suite
I benchmarked the SAC+AE and SLAC implementations in some environments from the DeepMind Control Suite (code). Note that the horizontal axis represents environment steps, which are obtained by multiplying agent_step by action_repeat. I set action_repeat to 4 for cheetah-run and 2 for walker-walk.
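The conversion between the two step counts is just a multiplication, as in the small helper below (a hypothetical utility for reading the plots, not part of Rljax).

```python
def to_env_step(agent_step: int, action_repeat: int) -> int:
    # Each agent step repeats the selected action `action_repeat` times in the environment.
    return agent_step * action_repeat

# e.g. 500,000 agent steps with action_repeat=4 correspond to 2,000,000 environment steps.
print(to_env_step(500_000, 4))
```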
Atari (Arcade Learning Environment)
I benchmarked the SAC-Discrete implementation in MsPacmanNoFrameskip-v4 from the Arcade Learning Environment (ALE) (code). Note that the horizontal axis represents environment steps, which are obtained by multiplying agent_step by 4.
[1] Schulman, John, et al. "Proximal policy optimization algorithms." arXiv preprint arXiv:1707.06347 (2017).
[2] Lillicrap, Timothy P., et al. "Continuous control with deep reinforcement learning." arXiv preprint arXiv:1509.02971 (2015).
[3] Fujimoto, Scott, Herke Van Hoof, and David Meger. "Addressing function approximation error in actor-critic methods." arXiv preprint arXiv:1802.09477 (2018).
[4] Haarnoja, Tuomas, et al. "Soft actor-critic: Off-policy maximum entropy deep reinforcement learning with a stochastic actor." arXiv preprint arXiv:1801.01290 (2018).
[5] Haarnoja, Tuomas, et al. "Soft actor-critic algorithms and applications." arXiv preprint arXiv:1812.05905 (2018).
[6] Mnih, Volodymyr, et al. "Human-level control through deep reinforcement learning." Nature 518.7540 (2015): 529-533.
[7] Dabney, Will, et al. "Distributional reinforcement learning with quantile regression." Thirty-Second AAAI Conference on Artificial Intelligence. 2018.
[8] Dabney, Will, et al. "Implicit quantile networks for distributional reinforcement learning." arXiv preprint arXiv:1806.06923 (2018).
[9] Yang, Derek, et al. "Fully Parameterized Quantile Function for Distributional Reinforcement Learning." Advances in Neural Information Processing Systems. 2019.
[10] Christodoulou, Petros. "Soft Actor-Critic for Discrete Action Settings." arXiv preprint arXiv:1910.07207 (2019).
[11] Schaul, Tom, et al. "Prioritized experience replay." arXiv preprint arXiv:1511.05952 (2015).
[12] Kumar, Aviral, Abhishek Gupta, and Sergey Levine. "Discor: Corrective feedback in reinforcement learning via distribution correction." arXiv preprint arXiv:2003.07305 (2020).
[13] Yarats, Denis, et al. "Improving sample efficiency in model-free reinforcement learning from images." arXiv preprint arXiv:1910.01741 (2019).
[14] Lee, Alex X., et al. "Stochastic latent actor-critic: Deep reinforcement learning with a latent variable model." arXiv preprint arXiv:1907.00953 (2019).
[15] Sinha, Samarth, et al. "D2RL: Deep Dense Architectures in Reinforcement Learning." arXiv preprint arXiv:2010.09163 (2020).
[16] Kuznetsov, Arsenii, et al. "Controlling Overestimation Bias with Truncated Mixture of Continuous Distributional Quantile Critics." arXiv preprint arXiv:2005.04269 (2020).