/saacjax

Primary language: Python. License: MIT.

WARNING: Rljax is currently in beta and under active development. Any contributions are welcome :)

Rljax

Rljax is a collection of RL algorithms written in JAX.

Setup

You can install dependencies simply by executing the following.

pip install -r requirements.txt
pip install -e .
git clone https://github.com/google-research/realworldrl_suite.git
cd realworldrl_suite
pip install -e .

To use GPUs or TPUs, follow instructions here.
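
After installing, one quick way to confirm which backend JAX actually picked up is to ask it directly. The sketch below is a suggested sanity check, not part of Rljax: `describe_jax_backend` is a hypothetical helper name, and it relies only on `jax.default_backend()` and `jax.devices()`.

```python
import importlib.util


def describe_jax_backend() -> str:
    """Return a one-line summary of the backend JAX will use."""
    # Avoid an ImportError if the install step has not been run yet.
    if importlib.util.find_spec("jax") is None:
        return "jax is not installed"

    import jax

    # default_backend() names the platform; devices() lists what JAX can see.
    devices = ", ".join(str(d) for d in jax.devices())
    return f"{jax.default_backend()}: {devices}"


print(describe_jax_backend())
```

If the summary starts with `cpu` on a machine with an accelerator, the accelerator-enabled JAX build is probably not installed yet.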

Algorithm

Currently, the following algorithms are implemented.

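| Algorithm | Action | Vector State | Pixel State | PER[11] | D2RL[15] |
| --- | --- | --- | --- | --- | --- |
| PPO[1] | Continuous | ✔️ | - | - | - |
| DDPG[2] | Continuous | ✔️ | - | ✔️ | ✔️ |
| TD3[3] | Continuous | ✔️ | - | ✔️ | ✔️ |
| SAC[4,5] | Continuous | ✔️ | - | ✔️ | ✔️ |
| SAC+DisCor[12] | Continuous | ✔️ | - | - | ✔️ |
| TQC[16] | Continuous | ✔️ | - | ✔️ | ✔️ |
| SAC+AE[13] | Continuous | - | ✔️ | ✔️ | ✔️ |
| SLAC[14] | Continuous | - | ✔️ | - | ✔️ |
| DQN[6] | Discrete | ✔️ | ✔️ | ✔️ | - |
| QR-DQN[7] | Discrete | ✔️ | ✔️ | ✔️ | - |
| IQN[8] | Discrete | ✔️ | ✔️ | ✔️ | - |
| FQF[9] | Discrete | ✔️ | ✔️ | ✔️ | - |
| SAC-Discrete[10] | Discrete | ✔️ | ✔️ | ✔️ | - |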

Example

All algorithms can be trained in a few lines of code.

Getting started

Here is a quick example of how to train DQN on CartPole-v0.

import gym

from rljax.algorithm import DQN
from rljax.trainer import Trainer

NUM_AGENT_STEPS = 20000
SEED = 0

env = gym.make("CartPole-v0")
env_test = gym.make("CartPole-v0")

algo = DQN(
    num_agent_steps=NUM_AGENT_STEPS,
    state_space=env.observation_space,
    action_space=env.action_space,
    seed=SEED,
    batch_size=256,
    start_steps=1000,
    update_interval=1,
    update_interval_target=400,
    eps_decay_steps=0,
    loss_type="l2",
    lr=1e-3,
)

trainer = Trainer(
    env=env,
    env_test=env_test,
    algo=algo,
    log_dir="/tmp/rljax/dqn",
    num_agent_steps=NUM_AGENT_STEPS,
    eval_interval=1000,
    seed=SEED,
)
trainer.train()

MuJoCo (Gym)

I benchmarked my implementations on several environments from MuJoCo's -v3 task suite, following Spinning Up's benchmarks (code). For TQC, I set num_quantiles_to_drop to 0 for HalfCheetah-v3 and 2 for the other environments. Note that I benchmarked with 3M agent steps, not the 5M agent steps used in the TQC paper.
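
To give a feel for what dropping quantiles does in TQC[16], here is a minimal plain-Python sketch of the truncation step: the quantile atoms of all critics are pooled, sorted, and the largest ones are discarded before the target is formed. This is an illustration, not Rljax's code. `truncate_pooled_quantiles` and `num_to_drop` are hypothetical names, and the sketch takes the total number of pooled atoms to drop as an explicit argument; how num_quantiles_to_drop maps onto that count inside Rljax is not something this sketch asserts.

```python
from typing import List, Sequence


def truncate_pooled_quantiles(
    quantiles_per_critic: Sequence[Sequence[float]],
    num_to_drop: int,
) -> List[float]:
    """Pool every critic's quantile atoms, sort them, and drop the largest ones."""
    pooled = sorted(q for critic in quantiles_per_critic for q in critic)
    if not 0 <= num_to_drop < len(pooled):
        raise ValueError("num_to_drop must be in [0, number of pooled atoms)")
    # Keeping only the smallest atoms lowers the target and counters overestimation.
    return pooled[: len(pooled) - num_to_drop]


# Two hypothetical critics with three atoms each.
critics = [[1.0, 4.0, 9.0], [2.0, 3.0, 10.0]]
print(truncate_pooled_quantiles(critics, 0))  # nothing dropped
print(truncate_pooled_quantiles(critics, 2))  # the two largest atoms removed
```

With nothing dropped, the pooled distribution is used as-is, so the truncation has no effect on the target.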

DeepMind Control Suite

I benchmarked the SAC+AE and SLAC implementations on several environments from the DeepMind Control Suite (code). Note that the horizontal axis shows environment steps, obtained by multiplying agent_step by action_repeat. I set action_repeat to 4 for cheetah-run and 2 for walker-walk.
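
The relation between agent steps and environment steps can be sketched with a small action-repeat wrapper. This is a sketch under assumptions, not Rljax's wrapper: `ActionRepeat` and `CountingEnv` are hypothetical classes, and the environment is assumed to follow the older Gym convention where `step` returns `(state, reward, done, info)`.

```python
class ActionRepeat:
    """Apply each agent action `action_repeat` times and sum the rewards."""

    def __init__(self, env, action_repeat):
        self.env = env
        self.action_repeat = action_repeat

    def reset(self):
        return self.env.reset()

    def step(self, action):
        total_reward = 0.0
        for _ in range(self.action_repeat):
            state, reward, done, info = self.env.step(action)
            total_reward += reward
            if done:  # stop repeating once the episode ends
                break
        return state, total_reward, done, info


class CountingEnv:
    """Toy stand-in for a real environment that counts raw environment steps."""

    def __init__(self):
        self.env_steps = 0

    def reset(self):
        self.env_steps = 0
        return 0

    def step(self, action):
        self.env_steps += 1
        return self.env_steps, 1.0, False, {}


env = ActionRepeat(CountingEnv(), action_repeat=4)
env.reset()
num_agent_steps = 10
for _ in range(num_agent_steps):
    _, reward, _, _ = env.step(0)

# 10 agent steps with action_repeat=4 consume 40 environment steps.
print(env.env.env_steps)
```

Each call to the wrapper's `step` counts as one agent step, while the wrapped environment advances `action_repeat` times, which is why the two counts differ by that factor.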

Atari (Arcade Learning Environment)

I benchmarked the SAC-Discrete implementation on MsPacmanNoFrameskip-v4 from the Arcade Learning Environment (ALE) (code). Note that the horizontal axis shows environment steps, obtained by multiplying agent_step by 4.

Reference

[1] Schulman, John, et al. "Proximal policy optimization algorithms." arXiv preprint arXiv:1707.06347 (2017).

[2] Lillicrap, Timothy P., et al. "Continuous control with deep reinforcement learning." arXiv preprint arXiv:1509.02971 (2015).

[3] Fujimoto, Scott, Herke Van Hoof, and David Meger. "Addressing function approximation error in actor-critic methods." arXiv preprint arXiv:1802.09477 (2018).

[4] Haarnoja, Tuomas, et al. "Soft actor-critic: Off-policy maximum entropy deep reinforcement learning with a stochastic actor." arXiv preprint arXiv:1801.01290 (2018).

[5] Haarnoja, Tuomas, et al. "Soft actor-critic algorithms and applications." arXiv preprint arXiv:1812.05905 (2018).

[6] Mnih, Volodymyr, et al. "Human-level control through deep reinforcement learning." Nature 518.7540 (2015): 529-533.

[7] Dabney, Will, et al. "Distributional reinforcement learning with quantile regression." Thirty-Second AAAI Conference on Artificial Intelligence. 2018.

[8] Dabney, Will, et al. "Implicit quantile networks for distributional reinforcement learning." arXiv preprint arXiv:1806.06923 (2018).

[9] Yang, Derek, et al. "Fully Parameterized Quantile Function for Distributional Reinforcement Learning." Advances in Neural Information Processing Systems. 2019.

[10] Christodoulou, Petros. "Soft Actor-Critic for Discrete Action Settings." arXiv preprint arXiv:1910.07207 (2019).

[11] Schaul, Tom, et al. "Prioritized experience replay." arXiv preprint arXiv:1511.05952 (2015).

[12] Kumar, Aviral, Abhishek Gupta, and Sergey Levine. "Discor: Corrective feedback in reinforcement learning via distribution correction." arXiv preprint arXiv:2003.07305 (2020).

[13] Yarats, Denis, et al. "Improving sample efficiency in model-free reinforcement learning from images." arXiv preprint arXiv:1910.01741 (2019).

[14] Lee, Alex X., et al. "Stochastic latent actor-critic: Deep reinforcement learning with a latent variable model." arXiv preprint arXiv:1907.00953 (2019).

[15] Sinha, Samarth, et al. "D2RL: Deep Dense Architectures in Reinforcement Learning." arXiv preprint arXiv:2010.09163 (2020).

[16] Kuznetsov, Arsenii, et al. "Controlling Overestimation Bias with Truncated Mixture of Continuous Distributional Quantile Critics." arXiv preprint arXiv:2005.04269 (2020).