/jaxrl_m

Skeleton for scalable and flexible Jax RL implementations

Primary LanguagePythonMIT LicenseMIT

A JAX Backbone for RL projects

DOI

This project serves as a "central backbone" for an RL codebase, designed to accelerate prototyping and diagnosis of new algorithms (although it auxiliarily does contain reference implementations of SAC, CQL, IQL, BC). It is inspired greatly by Ilya Kostrikov's JaxRL codebase.

The primary goal of the codebase is to make ease of coding up a new algorithm: towards this goal, the primary philosophy is that

algorithms should be single-file implementations

This means that (almost) all components of the algorithm (from update rule to network choices to hyperparameter choices) are all contained in one file (e.g. see BC example or SAC example). This makes it easy to read and understand the algorithm, and also makes it easy to modify the algorithm to test out new ideas. The code is also designed to scale as easily as possible to multi-GPU / TPU setups, with simple abstractions for distributed training.

Installation

Requires jax, flax, optax, distrax, and optionally wandb for logging. Clone this repository and install it (e.g. pip install -e .) or add to python path.

Usage

The fastest way to understand how to use this skeleton is to see the reference SAC implementation:

Agent: sac.py

Launcher: run_mujoco_sac.py

Structure

The code contains the following files:

  • jaxrl_m.common: Contains the TrainState abstraction (a fork of Flax's TrainState class with some additional syntactic features for ease of use), and some other useful general utilities (target_update, shard_batch)
  • jaxrl_m.dataset: Contains the Dataset class (which can store and sample from buffers containing arbitrarily nested dictionaries) and an equivalent ReplayBuffer class
  • jaxrl_m.networks: Contains implementations of common RL networks (MLP, Critic, ValueCritic, Policy)
  • jaxrl_m.evaluation: Contains code for running evaluation episodes of agents (e.g. with the evaluate(policy, env) function)
  • jaxrl_m.wandb: Contains code for easily setting up Weights & Biases for experiments
  • jaxrl_m.typing: Useful type aliases
  • jaxrl_m.vision: vision.models contains common vision models (e.g. ResNet, ResNetV2, Impala), vision.data_augmentations contains common augmentations (e.g. random crop, random color jitter, gaussian blur)

Examples

Example implementations:

  1. SAC
  2. IQL
  3. Continuous BC
  4. Discrete BC

Example Launchers:

  1. Mujoco SAC
  2. D4RL IQL

Citation

If you use this codebase in an academic work, please cite

@software{jaxrl_minimal,
  author       = {Dibya Ghosh},
  title        = {dibyaghosh/jaxrl\_m},
  month        = April,
  year         = 2023,
  publisher    = {Zenodo},
  version      = {v0.1},
  doi          = {10.5281/zenodo.7958265},
  url          = {https://github.com/dibyaghosh/jaxrl_m}
}