Playing Atari with Deep Reinforcement Learning
This repository contains code for training a neural network to play Atari games. The code is based on ideas from the DeepMind paper Playing Atari with Deep Reinforcement Learning and was written for the course CS 294: Deep Reinforcement Learning (Fall 2017) at UC Berkeley.
In addition to training a network locally, the repository makes it easy to train on a GPU EC2 instance in AWS.
Requirements
- Bash
- Python 3
- virtualenv
- terraform (for AWS deployment only)
General Instructions
All operations can be accessed through the go script inside the root directory. Running the script without parameters prints available operations.
./go
Usage: ./go clean | deploy | ip | tensorboard | ssh | run | sync | gpu-usage | tf
Training the Network
To start training, run
./go run src/train_dqn.py
The default game is Breakout; to change it, edit the file src/dqn/dqn_atari.py. The parameters of the neural network are saved periodically during training and can be found in the folder checkpoints/.
To observe how the network develops during training, run
./go tensorboard
Then visit the displayed URL, which points to a local TensorBoard server. It shows a graph of the network's playing strength against the number of training steps.
Loading a Pretrained Network
To load a pretrained model, run
./go run src/run_dqn.py {path to your checkpoint}
where the last parameter points to a checkpoint that was created during a previous training session.
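If several checkpoints have accumulated, a small helper can pick the most recent one. The following is a minimal sketch, not part of the repository: the function name newest_checkpoint is hypothetical, and it only assumes that checkpoints live in the folder checkpoints/ as described above. It makes no assumption about file names. Whether src/run_dqn.py expects a file, a directory, or a filename prefix is not specified here, so the result may need adjusting before it is passed on.

```python
from pathlib import Path
from typing import Optional


def newest_checkpoint(checkpoint_dir: str = "checkpoints") -> Optional[Path]:
    """Return the most recently modified entry in checkpoint_dir, or None.

    Hypothetical helper: it relies only on modification times, not on any
    particular checkpoint naming scheme.
    """
    root = Path(checkpoint_dir)
    if not root.is_dir():
        return None
    # Skip hidden entries such as .gitkeep.
    entries = [p for p in root.iterdir() if not p.name.startswith(".")]
    if not entries:
        return None
    return max(entries, key=lambda p: p.stat().st_mtime)


if __name__ == "__main__":
    path = newest_checkpoint()
    print(path if path is not None else "no checkpoint found")
```

The printed path can then be used as the last parameter of ./go run src/run_dqn.py.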
Training in AWS
To provision a ready-to-use GPU instance in AWS, first set the environment variables SSH_KEY, AWS_ACCESS_KEY, and AWS_SECRET_KEY
export SSH_KEY=~/.ssh/id_rsa
export AWS_ACCESS_KEY={your access key}
export AWS_SECRET_KEY={your secret key}
Then run
./go deploy
./go sync
The first command will use Terraform and Ansible to create a new EC2 instance based on the Amazon Deep Learning Base AMI. Note that you might need to request a limit increase for p2.xlarge instances first. The second command synchronizes the remote machine with the local one.
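Because ./go deploy depends on the three environment variables listed above, it can help to verify them before provisioning. The following is a minimal sketch under that assumption. The names REQUIRED_VARS and missing_vars are hypothetical and not part of the repository, and the check only tests that each variable is set and non-empty, not that the credentials are valid.

```python
import os
from typing import List, Mapping

# Variable names taken from the instructions above.
REQUIRED_VARS = ("SSH_KEY", "AWS_ACCESS_KEY", "AWS_SECRET_KEY")


def missing_vars(env: Mapping[str, str]) -> List[str]:
    """Return the required variable names that are unset or empty in env."""
    return [name for name in REQUIRED_VARS if not env.get(name)]


def main() -> None:
    missing = missing_vars(os.environ)
    if missing:
        print("Missing environment variables: " + ", ".join(missing))
    else:
        print("All required variables are set; ./go deploy can be run.")


if __name__ == "__main__":
    main()
```

Taking the environment as a parameter keeps the check easy to try out with a plain dictionary instead of the real os.environ.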
To log in to the freshly created EC2 instance, run
./go ssh
Now you are logged in to a tmux environment which provides all the operations of this repository (e.g. run ./go run src/train_dqn.py to start training).
Terminating All AWS Resources after Training
Simply run
./go tf destroy
and follow the dialog.