TensorFlow DQN

TensorFlow implementation of Deep Q-learning.

Results

Testing on Breakout across three seeds:

A cherry-picked episode:

Setup

Install pipenv, then set up a virtual environment and install the main dependencies with:

$ pipenv sync
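
The commands below are easiest to run inside that environment: either drop into it first with pipenv shell, or prefix each command with pipenv run. For example:

$ pipenv shell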

Usage

To train a policy, use train.py.

train.py uses Sacred for managing configuration. To override the defaults in config.py, append with followed by a number of option=value strings. For example:

$ python3 -m dqn.train with atari_config env_id=PongNoFrameskip-v4 dueling=False
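
To inspect the fully-resolved options without starting a run, Sacred also provides a print_config command (assuming train.py exposes the standard Sacred command line):

$ python3 -m dqn.train print_config with atari_config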

TensorBoard logs will be saved in a subdirectory inside the automatically-created runs directory.
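To view the logs, point TensorBoard at that directory, for example:

$ tensorboard --logdir runs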

Implementation gotchas

A couple of details to be aware of with DQN:

  • The hard hyperparameters to get right are a) the ratio of training to environment interaction, and b) how often to update the target network. I used the setup from Stable Baselines' implementation.
  • As mentioned in the blog post accompanying the release of OpenAI Baselines, don't forget the Huber loss (a sketch is shown just after this list).
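
As a reference point, here is a minimal sketch of a Huber loss on the TD errors (the function and argument names are illustrative, not necessarily what this repo uses):

import tensorflow as tf

def huber_loss(td_errors, delta=1.0):
    # Quadratic for |error| <= delta, linear beyond that, so a few huge
    # TD errors can't blow up the gradients the way a squared loss would.
    abs_err = tf.abs(td_errors)
    quadratic = tf.minimum(abs_err, delta)
    linear = abs_err - quadratic
    return 0.5 * tf.square(quadratic) + delta * linear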

Also, a general point: be very careful not to normalise observations twice. It turns out Baselines had a bug for several months because of this. We use TensorFlow assertion functions to make sure the observations have the right scale at the point they enter the network.
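
A minimal sketch of that kind of check, written in TF 2 eager style with illustrative names (the actual check in this repo may differ):

import tensorflow as tf

def assert_obs_normalised(obs):
    # Fail loudly if the frames entering the network are not already scaled
    # to [0, 1]. A batch that was normalised twice would instead have a
    # maximum of roughly 1/255.
    tf.debugging.assert_greater_equal(tf.reduce_min(obs), 0.0)
    tf.debugging.assert_less_equal(tf.reduce_max(obs), 1.0)
    return obs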

Tests

To run the smoke tests, unit tests, and an end-to-end test on CartPole, respectively:

$ python tests.py Smoke
$ python tests.py Unit
$ python tests.py EndToEnd