
Really Fast End-to-End Jax RL Implementations

Primary language: Python. License: Apache License 2.0 (Apache-2.0).

PureJaxRL (End-to-End RL Training in Pure Jax)

Code style: black

PureJaxRL is a high-performance, end-to-end JAX reinforcement learning (RL) implementation. When running many agents in parallel on GPUs, our implementation is over 1000x faster than standard PyTorch RL implementations. Unlike other JAX RL implementations, we implement the entire training pipeline in JAX, including the environment. This yields significant speedups through JIT compilation and by avoiding CPU-GPU data transfer. It also makes debugging easier, because the system is fully synchronous. More importantly, this code allows you to use JAX to jit, vmap, pmap, and scan entire RL training pipelines. With this, we can:

  • 🏃 Efficiently run tons of seeds in parallel on one GPU
  • 💻 Perform rapid hyperparameter tuning
  • 🦎 Discover new RL algorithms with meta-evolution
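As a rough illustration of what "the entire pipeline in JAX" makes possible, here is a minimal sketch, not PureJaxRL's actual code: a hypothetical two-armed bandit environment and a REINFORCE learner, both written in pure JAX. Because the environment step, the policy update, and the training loop (jax.lax.scan) are all JAX functions, a single train(rng) function can be wrapped in jax.vmap to train one agent per seed and in jax.jit to compile everything into one program. The names env_step and make_train, the reward values, and the hyperparameters are illustrative assumptions.

```python
import jax
import jax.numpy as jnp


# Hypothetical toy environment written in pure JAX: a two-armed bandit.
# Arm 1 pays 1.0 on average, arm 0 pays 0.0, plus Gaussian noise (std 0.1).
def env_step(key, action):
    mean_reward = jnp.where(action == 1, 1.0, 0.0)
    return mean_reward + 0.1 * jax.random.normal(key)


def make_train(num_updates=200, lr=0.1):
    """Build a pure train(rng) function (hypothetical helper, for illustration)."""

    def train(rng):
        logits = jnp.zeros(2)  # softmax policy parameters, one logit per arm

        def update(logits, key):
            act_key, env_key = jax.random.split(key)
            action = jax.random.categorical(act_key, logits)  # sample from the policy
            reward = env_step(env_key, action)  # the environment runs on-device too
            # REINFORCE: step along reward * grad log pi(action)
            grad = jax.grad(lambda p: jax.nn.log_softmax(p)[action])(logits)
            return logits + lr * reward * grad, reward

        # lax.scan turns the Python training loop into one compiled loop.
        keys = jax.random.split(rng, num_updates)
        return jax.lax.scan(update, logits, keys)  # (final_logits, per-step rewards)

    return train


# vmap runs one independent agent per seed; jit compiles the whole pipeline.
train_many = jax.jit(jax.vmap(make_train()))
seeds = jax.random.split(jax.random.PRNGKey(0), 64)
final_logits, rewards = train_many(seeds)
print(final_logits.shape, rewards.shape)  # (64, 2) (64, 200)
```

On a machine with several accelerators, the same train function could additionally be wrapped in jax.pmap to split the seeds across devices.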

For more details, visit the accompanying blog post: https://chrislu.page/blog/meta-disco/

A Colab notebook walks through the basic usage.

Check out RESOURCES.md for a list of GitHub repositories that are part of the JAX RL ecosystem!

Performance

Without vectorization, our implementation runs 10x faster than CleanRL's PyTorch baselines, as shown in the single-thread performance plot.

Figure: single-thread performance on CartPole and MinAtar-Breakout.

With vectorized training, we can train 2048 PPO agents in half the time it takes to train a single PyTorch PPO agent on a single GPU. The vectorized agent training allows for simultaneous training across multiple seeds, rapid hyperparameter tuning, and even evolutionary Meta-RL.

Figure: vectorized training performance on CartPole and MinAtar-Breakout.
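The same vectorization idea extends to hyperparameter sweeps. In the sketch below (again a hypothetical toy, a three-armed bandit with made-up arm means, not PureJaxRL code), the learning rate is passed to train as an ordinary traced argument, so two nested jax.vmap calls evaluate every (learning rate, seed) pair inside one compiled call. The function name train, the arm means, and the swept values are assumptions for illustration.

```python
import jax
import jax.numpy as jnp

# Hypothetical three-armed bandit; these arm means are made-up values.
ARM_MEANS = jnp.array([0.2, 0.5, 0.8])


def train(lr, rng, num_updates=300):
    """REINFORCE on the toy bandit; returns the mean reward over training.

    lr is an ordinary (traceable) argument, so it can be vmapped over.
    """

    def update(logits, key):
        act_key, env_key = jax.random.split(key)
        action = jax.random.categorical(act_key, logits)
        reward = ARM_MEANS[action] + 0.1 * jax.random.normal(env_key)
        grad = jax.grad(lambda p: jax.nn.log_softmax(p)[action])(logits)
        return logits + lr * reward * grad, reward

    keys = jax.random.split(rng, num_updates)
    _, rewards = jax.lax.scan(update, jnp.zeros(3), keys)
    return rewards.mean()


lrs = jnp.array([0.01, 0.1, 0.5])
seeds = jax.random.split(jax.random.PRNGKey(42), 32)

# Inner vmap: over seeds with lr held fixed. Outer vmap: over learning rates.
sweep = jax.jit(jax.vmap(jax.vmap(train, in_axes=(None, 0)), in_axes=(0, None)))
mean_rewards = sweep(lrs, seeds)  # shape (3, 32): one entry per (lr, seed) pair
print(mean_rewards.mean(axis=1))  # average over seeds for each learning rate
```

Passing lr as a function argument, rather than closing over it as a Python constant, is what makes it vmappable; values fixed at trace time (such as num_updates here) must stay the same across the whole sweep.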

Code Philosophy

PureJaxRL is inspired by CleanRL, providing high-quality single-file implementations with research-friendly features. Like CleanRL, this is not a modular library and is not meant to be imported. The repository focuses on simplicity and clarity in its implementations, making it an excellent resource for researchers and practitioners.

Installation

Install dependencies using the requirements.txt file:

pip install -r requirements.txt

To run JAX on your accelerators (GPUs or TPUs), follow the installation instructions in the JAX documentation.
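After installing, a quick way to confirm that JAX actually sees your accelerator is to query the default backend and the device list. This small check uses only the standard jax.default_backend() and jax.devices() calls; what it prints depends on your machine.

```python
import jax

# Which backend JAX picked by default, e.g. "cpu", "gpu" or "tpu".
print(jax.default_backend())

# All devices JAX can use. If you expected a GPU/TPU but only see CPU devices,
# JAX was most likely installed without accelerator support.
print(jax.devices())
```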

Example Usage

examples/walkthrough.ipynb walks through the basic usage and can be opened in Colab.

examples/brax_minatar.ipynb walks through using PureJaxRL with Brax and MinAtar and can be opened in Colab.

Related Work

Check out the list of RESOURCES to see libraries that are closely related to PureJaxRL!

The following repositories and projects were precursors to PureJaxRL:

Citation

If you use PureJaxRL in your work, please cite the following paper:

@article{lu2022discovered,
    title={Discovered policy optimisation},
    author={Lu, Chris and Kuba, Jakub and Letcher, Alistair and Metz, Luke and Schroeder de Witt, Christian and Foerster, Jakob},
    journal={Advances in Neural Information Processing Systems},
    volume={35},
    pages={16455--16468},
    year={2022}
}