On-Policy Policy Gradient Algorithms in JAX

Primary language: Python. License: MIT.

This Deep Reinforcement Learning repository contains the most prominent On-Policy Policy Gradient Algorithms, all implemented in JAX. Our implementations are based on Brax's implementation of PPO. We use Brax's logic for policy networks and distributions, and Stable Baselines3's environment infrastructure to create batched environments. Inspired by CleanRL, we keep each algorithm's logic, including its hyperparameters, in a single file. For efficiency, however, the code for creating networks and distributions lives in shared files.
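To give a feel for the kind of objective a single-file PPO implementation optimizes, here is a minimal sketch of the clipped surrogate policy loss in JAX. It is not taken from this repository or from Brax: the function name `ppo_clip_loss`, the `clip_eps` default of 0.2, and the toy inputs are assumptions made for illustration only.

```python
import jax
import jax.numpy as jnp


def ppo_clip_loss(log_probs_new, log_probs_old, advantages, clip_eps=0.2):
    """Clipped surrogate policy loss, returned as a quantity to minimize.

    Hypothetical helper for illustration; not the repository's actual code.
    """
    # Probability ratio pi_new(a|s) / pi_old(a|s), computed in log space.
    ratio = jnp.exp(log_probs_new - log_probs_old)
    unclipped = ratio * advantages
    clipped = jnp.clip(ratio, 1.0 - clip_eps, 1.0 + clip_eps) * advantages
    # Keep the smaller (pessimistic) objective per sample, then negate
    # so that gradient descent maximizes the surrogate objective.
    return -jnp.mean(jnp.minimum(unclipped, clipped))


# Toy batch of three transitions (made-up numbers).
log_probs_old = jnp.array([-1.0, -0.5, -2.0])
log_probs_new = jnp.array([-0.9, -0.7, -1.0])
advantages = jnp.array([1.0, -1.0, 2.0])

loss = ppo_clip_loss(log_probs_new, log_probs_old, advantages)
# Gradient with respect to the new log-probabilities (first argument).
grad = jax.grad(ppo_clip_loss)(log_probs_new, log_probs_old, advantages)
```

In the third toy transition the ratio lies above `1 + clip_eps` while the advantage is positive, so the clipped term is selected and that sample contributes no gradient. This is the mechanism that keeps the updated policy close to the policy that collected the data.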

Algorithms

We implemented the following algorithms in JAX:

You can read more about these algorithms in our upcoming comprehensive overview of Policy Gradient Algorithms.

*on-policy variant of Maximum a Posteriori Policy Optimization (MPO)

Benchmark Results

We report the performance of our implementations on common MuJoCo environments (v4), interfaced through Gymnasium.

Get started

Prerequisites:

  • Tested with Python 3.11.6
  • See requirements.txt for further dependencies (note that this file is bloated; not all listed libraries are actually needed).

To run an algorithm locally, run the respective Python file, for example:

python ppo.py

Citing PolicyGradientsJax

If you use this repository in your work or find it useful, please cite our paper:

@article{lehmann2024definitive,
      title={The Definitive Guide to Policy Gradients in Deep Reinforcement Learning: Theory, Algorithms and Implementations}, 
      author={Matthias Lehmann},
      year={2024},
      eprint={2401.13662},
      archivePrefix={arXiv},
      primaryClass={cs.LG}
}