This is a codebase that implements simple reinforcement learning algorithms in JAX, with support for several environments. The idea is to have solid single-file implementations of various RL algorithms for research use, covering both online and offline methods.
Online Algorithms Implemented:
- Proximal Policy Optimization (PPO): `algs_online/ppo.py`
- Soft Actor-Critic (SAC): `algs_online/sac.py`
- Twin Delayed DDPG (TD3): `algs_online/td3.py`

Offline Algorithms Implemented:
- Behavior Cloning (BC): `algs_offline/bc.py`
- Implicit Q-Learning (IQL): `algs_offline/iql.py`
Environments Supported:
- (Online) Gym MuJoCo Locomotion: HalfCheetah-v2, CartPole-v1, etc.
- (Online) DeepMind Control: cheetah_run, pendulum_swingup, etc.
- (Offline) D4RL MuJoCo Locomotion: halfcheetah-medium-expert-v2, etc.
- (Offline) D4RL AntMaze + Goal-Conditioned: antmaze-large-diverse-v2, gc-antmaze-large-diverse-v2
- (Offline) ExORL: exorl_cheetah_walk, etc.
- See `envs/env_helper.py` for the full list.
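As a quick sanity check that the underlying environment packages are available, you can instantiate a couple of these environments directly with Gym / D4RL. This is only a sketch: the repo itself constructs environments through `envs/env_helper.py`, whose API is not shown here, so the snippet below bypasses that helper.

```python
# Sanity-check sketch: build environments directly via Gym / D4RL,
# bypassing the repo's envs/env_helper.py wrapper (API not shown here).
import gym

# Online Gym MuJoCo locomotion task.
env = gym.make("HalfCheetah-v2")
obs = env.reset()
print("HalfCheetah observation shape:", env.observation_space.shape)

# Offline D4RL locomotion task (importing d4rl registers its environments with Gym).
import d4rl  # noqa: F401

offline_env = gym.make("halfcheetah-medium-expert-v2")
dataset = offline_env.get_dataset()  # dict with observations, actions, rewards, terminals, ...
print("D4RL transitions:", dataset["observations"].shape[0])
```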
For the cleanest installation, create a conda environment:

`conda env create -f deps/environment.yml`

For full reproducibility, you can also refer to the Singularity script in `deps/base_container.def`.
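Once the environment is set up, a quick way to confirm that JAX is installed correctly and sees your accelerator is to run a trivial jitted computation. This is a generic check, not specific to this repo:

```python
# Sanity check: confirm JAX imports and an accelerator (or the CPU fallback) is visible.
import jax
import jax.numpy as jnp

print("JAX backend:", jax.default_backend())  # "gpu", "tpu", or "cpu"
print("Devices:", jax.devices())

# Compile and run a trivial function to confirm the toolchain works end to end.
x = jnp.arange(8.0)
print(jax.jit(lambda v: (v ** 2).sum())(x))  # expected: 140.0
```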
We've provided a set of stable results comparing each algorithm to a reference implementation. For full training curves, see the wandb reports for online results and the wandb reports for offline results.
You can reproduce these results using the commands in `run_baselines.py`.
The basic starting point is to run an individual algorithm file, e.g.:

`python algs_online/ppo.py --env_name walker_walk --agent.gamma 0.99`
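To launch a small sweep rather than a single run, a simple launcher can shell out to the per-algorithm files. This is only a sketch: the `--seed` flag is an assumption about the scripts' CLI (only `--env_name` and `--agent.gamma` are shown above), so adjust the flags to match the actual argument parser.

```python
# Hypothetical sweep launcher: run ppo.py over several environments and seeds.
# Assumes a --seed flag exists (not shown above); only --env_name is taken from the docs.
import itertools
import subprocess

envs = ["walker_walk", "cheetah_run"]
seeds = [0, 1, 2]

for env_name, seed in itertools.product(envs, seeds):
    cmd = [
        "python", "algs_online/ppo.py",
        "--env_name", env_name,
        "--seed", str(seed),
    ]
    print("Launching:", " ".join(cmd))
    subprocess.run(cmd, check=True)  # runs sequentially; parallelize as needed
```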
Offline Results
Env | Best Performance (ours) | Best Original Performance (reference paper) |
---|---|---|
exorl_cheetah_run | 257.5 (IQL-DDPG) | ~250 (TD3) source (exorl) |
exorl_walker_run | 471.9 (IQL-DDPG) | ~200 (TD3) source (exorl) |
halfcheetah-medium-expert-v2 | 83.8 (IQL) | 90.7 (TD3+BC) source (iql) |
walker2d-medium-expert-v2 | 106.8 (BC) | 110.1 (TD3+BC) source (iql) |
hopper-medium-expert-v2 | 98.9 (IQL) | 98.7 (CQL) source (iql) |
gc-antmaze-large-diverse-v2 | 52.5 (IQL) | 50.7 (IQL) source (hiql) |
gc-maze2d-large-v1 | 97.5 (IQL) | N/A |
Online Results
Env | Best Performance (ours) | Best Original Performance (reference paper) |
---|---|---|
HalfCheetah-v2 | 11029 (SAC) | 12138.8 (SAC) source (tianshou) |
Walker2d-v2 | 5101.8 (SAC-Tianshoulike) | 5007 (SAC) source (tianshou) |
Hopper-v2 | 2714.4 (REDQ) | 3542.2 (SAC) source (tianshou) |
cheetah_run | 918.9 (REDQ) | 800 (SAC) source (pytorch-sac) |
walker_run | 835.7 (TD3) | 700 (SAC) source (pytorch-sac) |
hopper_hop | 474.9 (TD3) | 210 (SAC) source (pytorch-sac) |
quadruped_run | 920.8 (TD3) | 700 (SAC) source (pytorch-sac) |
humanoid_run | 211.8 (REDQ) | 90 (SAC) source (pytorch-sac) |
pendulum_swingup | 790.2 (SAC) | 920 (SAC) source (pytorch-sac) |
This code is based largely on the jaxrl_m repo, and also takes inspiration from jaxrl and cleanrl.