
POPGym Library in JAX



POPJym: Partially Observable Process Gym in JAX

POPJym is POPGym in JAX. The original POPGym paper can be found here, and the Structured State Space Models for In-Context Reinforcement Learning paper can be found here. The original code comes from this project and was cleaned up and formatted by Edan Toledo, who also added the cool logo (thanks for the help!).

Quickstart Install

pip install popjym

To run JAX on GPU or TPU accelerators, install the matching JAX build as described in the JAX installation documentation.

For example, to install a CUDA 12 build:

pip install "jax[cuda12_pip]==0.4.7" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
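After installing, it can help to confirm that JAX actually sees your accelerator before running any environments. The snippet below is a small sanity check that uses only the standard `jax.devices()` and `jax.default_backend()` calls; what it prints depends on your machine.

```python
import jax

# List every device JAX can use; after a successful CUDA install this
# should include GPU devices rather than only a CPU device.
devices = jax.devices()
print(devices)

# The backend JAX places computations on by default: "cpu", "gpu" or "tpu".
backend = jax.default_backend()
print(backend)
```

If `backend` is still `"cpu"` on a GPU machine, the accelerator build was most likely not picked up.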

Quickstart Usage

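import jax
import popjym

key = jax.random.PRNGKey(0)
key_reset, key_act, key_step = jax.random.split(key, 3)

env_name = "RepeatPreviousEasy"  # example id (assumed); check popjym's registry for valid names
env, env_params = popjym.make(env_name)

# reset returns the first observation and the environment state
obs, state = env.reset(key_reset, env_params)

# Sample a random action (assumes a gymnax-style action_space) and take one step
action = env.action_space(env_params).sample(key_act)
obs, state, reward, done, info = env.step(key_step, state, action, env_params)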
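Because the environments are written in JAX, a whole rollout can be compiled and batched across many environments. The sketch below is one way to do that with `jax.lax.scan` (the time loop) and `jax.vmap` (parallel environments), assuming the gymnax-style call pattern from the quickstart: `reset(key, params) -> (obs, state)` and `step(key, state, action, params) -> (obs, state, reward, done, info)`. To keep it self-contained and runnable without choosing a specific POPJym environment, it uses a hypothetical stand-in, `ToyEnv`, with that same signature; `make_rollout` and `random_policy` are also illustrative helpers, not part of popjym.

```python
import jax
import jax.numpy as jnp


class ToyEnv:
    """Hypothetical stand-in following the quickstart's reset/step pattern.
    The observation is a step counter; the reward is 1.0 when the action
    equals the parity of the current counter. Episodes never terminate."""

    def reset(self, key, params):
        state = jnp.zeros((), dtype=jnp.int32)
        return state.astype(jnp.float32), state

    def step(self, key, state, action, params):
        reward = jnp.where(action == state % 2, 1.0, 0.0)
        state = state + 1
        done = jnp.array(False)
        return state.astype(jnp.float32), state, reward, done, {}


def make_rollout(env, env_params, policy, num_steps):
    """Build a jitted function mapping a batch of keys to a batch of
    reward trajectories, one independent environment per key."""

    def rollout(key):
        key, key_reset = jax.random.split(key)
        obs, state = env.reset(key_reset, env_params)

        def one_step(carry, _):
            obs, state, key = carry
            key, key_act, key_step = jax.random.split(key, 3)
            action = policy(key_act, obs)
            obs, state, reward, done, info = env.step(key_step, state, action, env_params)
            return (obs, state, key), reward

        # scan keeps the whole time loop inside one compiled XLA program
        _, rewards = jax.lax.scan(one_step, (obs, state, key), None, length=num_steps)
        return rewards  # shape: (num_steps,)

    # vmap runs one rollout per key in parallel; jit compiles the result
    return jax.jit(jax.vmap(rollout))


def random_policy(key, obs):
    # Uniform over two discrete actions (an assumption that fits ToyEnv)
    return jax.random.randint(key, (), 0, 2)


env = ToyEnv()
run = make_rollout(env, None, random_policy, num_steps=16)
keys = jax.random.split(jax.random.PRNGKey(0), 8)  # 8 parallel environments
rewards = run(keys)
print(rewards.shape)  # (8, 16)
```

To try this with a real POPJym environment, replace `ToyEnv` with `env, env_params = popjym.make(env_name)` and pass a policy that returns a valid action for that environment, for example one sampled from `env.action_space(env_params)`, assuming the environment exposes a gymnax-style `action_space`. This sketch does not handle `done`; whether `step` resets automatically at episode end depends on the environment implementation.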


To contribute, please follow the project's coding style by installing the pre-commit hooks:

pip install pre-commit
pre-commit install


If you use POPJym in your work, please cite (a) the original POPGym paper and (b) the Structured State Space Models for In-Context Reinforcement Learning paper:

@inproceedings{morad2023popgym,
  title={{POPG}ym: Benchmarking Partially Observable Reinforcement Learning},
  author={Steven Morad and Ryan Kortvelesy and Matteo Bettini and Stephan Liwicki and Amanda Prorok},
  booktitle={The Eleventh International Conference on Learning Representations},
  year={2023}
}

@article{lu2023structured,
  title={Structured State Space Models for In-Context Reinforcement Learning},
  author={Lu, Chris and Schroecker, Yannick and Gu, Albert and Parisotto, Emilio and Foerster, Jakob and Singh, Satinder and Behbahani, Feryal},
  journal={arXiv preprint arXiv:2303.03982},
  year={2023}
}