
JAXRL with support for Gymnasium and the latest version of JAX

Primary language: Python. License: MIT.

jaxrl3

JAXRL3 is a fork of the amazing JAXRL2 that adds support for Gymnasium and the latest version of JAX. The DeepMind Control Suite (DMC) is supported via Shimmy. The focus is on online RL, so DrQ and SAC are kept while the other algorithms have been removed for now.

Installation

Run

pip install --upgrade pip

pip install -e .

# either CPU
pip install --upgrade "jax[cpu]"
# or GPU
# CUDA 12 installation
# Note: wheels are only available on Linux.
pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
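To confirm that the installation worked and that JAX picked up the intended backend (CPU or GPU), a quick sanity check along these lines can be run. This is a generic JAX snippet, not part of jaxrl3; `squared_norm` is a hypothetical helper used only to exercise `jax.jit`.

```python
import jax
import jax.numpy as jnp

print("JAX version:", jax.__version__)
print("default backend:", jax.default_backend())  # "cpu", "gpu" or "tpu"
print("devices:", jax.devices())


# Hypothetical helper: a tiny jitted function to check that compilation and execution work.
@jax.jit
def squared_norm(x):
    return jnp.sum(x ** 2)


result = squared_norm(jnp.arange(4.0))
print("squared norm of [0, 1, 2, 3]:", float(result))  # 0 + 1 + 4 + 9 = 14.0
```

If the GPU wheels were installed correctly, the default backend should report `gpu` rather than `cpu`.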

See the JAX installation instructions for other versions of CUDA.

Examples

Here.

Acknowledgements

Based on work by Ilya Kostrikov.

@misc{jaxrl,
  author = {Kostrikov, Ilya},
  doi = {10.5281/zenodo.5535154},
  month = {10},
  title = {{JAXRL: Implementations of Reinforcement Learning algorithms in JAX}},
  url = {https://github.com/ikostrikov/jaxrl2},
  year = {2022},
  note = {v2}
}