/dreamerv2

An implementation of DreamerV2 written in JAX, Haiku and Optax

Primary language: Python. License: MIT.

Fork of DreamerV2 (JAX + Haiku)

🚨 This fork makes the agent compatible with Optax for optimization and includes some light refactoring. 🚨

Please see the original work at https://github.com/kenjyoung/dreamerv2 for further details.

Installation

First install JAX by following its official installation instructions, then run:

pip install -r requirements.txt

Training

To train the agent on MinAtar, run:

python dreamerv2.py