rl_ssms

State Space Models for Reinforcement Learning in Tensorflow

Primary language: Jupyter Notebook. License: MIT.

Deep State Space Models for Reinforcement Learning in Tensorflow

Work in progress.

This repository implements and evaluates several types of Deep State Space Models for Reinforcement Learning. The main inspiration is the paper "Learning and Querying Fast Generative Models for Reinforcement Learning" by L. Buesing et al. at DeepMind.

The thesis "Deep Latent Variable Models for Sequential Data" by Marco Fraccaro is a valuable source of background information.

Short description of the main files in the repo:

data_handler_bouncing_balls.py:
Generates the bouncing ball dataset. Currently, the dynamics are fully deterministic; only the initial position and velocity are random.
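
To illustrate what "deterministic except for the initial position and velocity" means, here is a minimal NumPy sketch of a single bouncing ball. It is not the code in data_handler_bouncing_balls.py: the function name `simulate_bouncing_ball`, the 32x32 frame size, the radius, the fixed speed and the elastic wall reflection are all assumptions made for illustration.

```python
import numpy as np

def simulate_bouncing_ball(n_steps, size=32, radius=2.0, speed=1.0, rng=None):
    """Hypothetical generator: returns (positions, frames) for one ball.

    Only the initial position and the direction of the velocity are random;
    after that every step is deterministic (elastic reflection off the walls).
    """
    rng = np.random.default_rng() if rng is None else rng
    # Random initial position, kept at least `radius` away from every wall.
    pos = rng.uniform(radius, size - radius, size=2)
    # Random direction, fixed speed.
    angle = rng.uniform(0.0, 2.0 * np.pi)
    vel = speed * np.array([np.cos(angle), np.sin(angle)])

    positions = np.empty((n_steps, 2))
    frames = np.zeros((n_steps, size, size), dtype=np.float32)
    # Pixel-centre coordinates used to rasterise the ball as a filled disc.
    ys, xs = np.mgrid[0:size, 0:size] + 0.5
    for t in range(n_steps):
        positions[t] = pos
        mask = (xs - pos[0]) ** 2 + (ys - pos[1]) ** 2 <= radius ** 2
        frames[t][mask] = 1.0
        pos = pos + vel
        # Reflect off each wall: mirror the position and flip that velocity component.
        for d in range(2):
            if pos[d] < radius:
                pos[d] = 2 * radius - pos[d]
                vel[d] = -vel[d]
            elif pos[d] > size - radius:
                pos[d] = 2 * (size - radius) - pos[d]
                vel[d] = -vel[d]
    return positions, frames
```

Because all randomness is drawn before the loop, two calls seeded identically, for example with `rng=np.random.default_rng(0)`, produce exactly the same sequence.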

train_env_model.py:
Instantiates an sSSM (stochastic state-space model) as described in Buesing et al. and trains it on the bouncing ball dataset.
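
As we understand Buesing et al., an sSSM pairs a deterministic state s_t with a per-step stochastic latent z_t: z_t is drawn from a prior conditioned on s_{t-1}, the transition s_t = g(s_{t-1}, z_t) is deterministic, and the observation o_t is generated from s_t. The NumPy sketch below shows only that generative structure. It is not the TensorFlow model in train_env_model.py: the class name `ToySSSM`, the dimensions, the single linear layers and the Bernoulli pixel output are assumptions, actions are omitted because the bouncing ball data has none, and the weights are random rather than trained (training would maximise a variational lower bound).

```python
import numpy as np

class ToySSSM:
    """Hypothetical, untrained sSSM-style model (generative pass only).

        z_t ~ N(mu(s_{t-1}), diag(sigma(s_{t-1})^2))   # stochastic latent prior
        s_t = tanh(W_s s_{t-1} + W_z z_t)              # deterministic transition g
        o_t ~ Bernoulli(sigmoid(W_o s_t))              # per-pixel observation model
    """

    def __init__(self, state_dim=16, latent_dim=4, obs_dim=32 * 32, seed=0):
        rng = np.random.default_rng(seed)
        scale = 0.1  # small random weights; a real model would learn these
        self.W_mu = scale * rng.standard_normal((latent_dim, state_dim))
        self.W_logsig = scale * rng.standard_normal((latent_dim, state_dim))
        self.W_s = scale * rng.standard_normal((state_dim, state_dim))
        self.W_z = scale * rng.standard_normal((state_dim, latent_dim))
        self.W_o = scale * rng.standard_normal((obs_dim, state_dim))
        self.state_dim = state_dim

    def step(self, s_prev, rng):
        # Prior over z_t, conditioned on the previous deterministic state.
        mu = self.W_mu @ s_prev
        sigma = np.exp(self.W_logsig @ s_prev)
        # Reparameterised sample: z = mu + sigma * eps, eps ~ N(0, I).
        z = mu + sigma * rng.standard_normal(mu.shape)
        # Deterministic transition, then Bernoulli means for every pixel.
        s = np.tanh(self.W_s @ s_prev + self.W_z @ z)
        obs_mean = 1.0 / (1.0 + np.exp(-(self.W_o @ s)))
        return s, obs_mean

    def rollout(self, s0, n_steps, seed=None):
        """Sample n_steps into the future; returns final state and pixel means."""
        rng = np.random.default_rng(seed)
        s, means = s0, []
        for _ in range(n_steps):
            s, obs_mean = self.step(s, rng)
            means.append(obs_mean)
        return s, np.stack(means)
```

Keeping the transition deterministic and confining the noise to z_t is what lets such a model represent stochastic futures while still carrying a compact state forward. A dSSM would correspond to dropping z_t altogether.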

bouncing_ball_prediction.ipynb:
Uses the first 3 observations from sequences in the bouncing ball test set to predict the next 10 frames.
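
The evaluation protocol (condition on 3 frames, predict the next 10) can be sketched as follows. The helper names are hypothetical, and `repeat_last_frame` is only a trivial placeholder baseline standing in for the trained model's prediction. It is not code from the notebook.

```python
import numpy as np

def split_context_target(sequence, n_context=3, n_predict=10):
    """Split one sequence (time axis first) into conditioning frames and targets."""
    if len(sequence) < n_context + n_predict:
        raise ValueError("sequence too short for the requested split")
    context = sequence[:n_context]
    target = sequence[n_context:n_context + n_predict]
    return context, target

def repeat_last_frame(context, n_predict):
    """Placeholder predictor: copy the last observed frame forward in time."""
    return np.repeat(context[-1:], n_predict, axis=0)

def per_step_mse(prediction, target):
    """Mean squared error for each predicted time step."""
    axes = tuple(range(1, prediction.ndim))
    return ((prediction - target) ** 2).mean(axis=axes)

# Usage on a random stand-in for one test sequence of 20 frames of 32x32 pixels.
seq = np.random.default_rng(0).random((20, 32, 32))
ctx, tgt = split_context_target(seq)
errors = per_step_mse(repeat_last_frame(ctx, len(tgt)), tgt)
```

Reporting the error per step rather than as a single average shows how quickly predictions degrade as the horizon grows, which is the main point of interest for multi-step rollouts.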