This deep reinforcement learning repository contains the most prominent on-policy policy gradient algorithms, all implemented in JAX. Our implementations are based on Brax's PPO implementation: we reuse Brax's logic for policy networks and distributions, and Stable Baselines3's environment infrastructure for creating batched environments. Inspired by CleanRL, we keep each algorithm's full logic, including its hyperparameters, in a single file; for efficiency, only the code for creating networks and distributions is shared across algorithms.
We implemented the following algorithms in JAX:
- REINFORCE
- Advantage Actor-Critic (A2C)
- Trust Region Policy Optimization (TRPO)
- Proximal Policy Optimization (PPO)
- V-MPO*
*on-policy variant of Maximum a Posteriori Policy Optimization (MPO)
You can read more about these algorithms in our upcoming comprehensive overview of policy gradient algorithms.
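To give a flavour of the single-file implementations, below is a minimal JAX sketch of PPO's clipped surrogate loss; the function and argument names are illustrative and do not correspond to the repository's actual API. Roughly speaking, the other algorithms differ in the policy loss and update they apply on top of the shared network and environment code.

```python
import jax.numpy as jnp


def ppo_clip_loss(log_prob, old_log_prob, advantage, clip_eps=0.2):
    """Clipped surrogate objective of PPO (illustrative sketch, not the repo's API).

    All inputs are arrays sharing a leading batch dimension.
    """
    # Probability ratio pi_theta(a|s) / pi_theta_old(a|s), computed in log space.
    ratio = jnp.exp(log_prob - old_log_prob)
    unclipped = ratio * advantage
    clipped = jnp.clip(ratio, 1.0 - clip_eps, 1.0 + clip_eps) * advantage
    # PPO maximises the clipped surrogate, so its negation is returned as a loss.
    return -jnp.mean(jnp.minimum(unclipped, clipped))
```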
We report the performance of our implementations on common MuJoCo environments (v4), interfaced through Gymnasium.
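As a rough sketch of how such batched MuJoCo environments can be set up with Stable Baselines3's VecEnv utilities (the environment id, number of environments, and seed below are illustrative values, not the repository's configuration):

```python
import numpy as np
from stable_baselines3.common.env_util import make_vec_env

# Eight batched copies of a Gymnasium MuJoCo v4 environment (illustrative values).
vec_env = make_vec_env("HalfCheetah-v4", n_envs=8, seed=0)

obs = vec_env.reset()                                   # shape: (8, obs_dim)
actions = np.stack([vec_env.action_space.sample() for _ in range(8)])
obs, rewards, dones, infos = vec_env.step(actions)      # each with leading dim 8
vec_env.close()
```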
Prerequisites:
- Tested with Python 3.11.6
- See requirements.txt for further dependencies (note that this file is bloated; not all listed libraries are actually needed).
To run an algorithm locally, simply run the respective Python file:
python ppo.py
If you use this repository in your work or find it useful, please cite our paper:
@article{lehmann2024definitive,
  title={The Definitive Guide to Policy Gradients in Deep Reinforcement Learning: Theory, Algorithms and Implementations},
  author={Matthias Lehmann},
  year={2024},
  eprint={2401.13662},
  archivePrefix={arXiv},
  primaryClass={cs.LG}
}