/draw

TensorFlow Implementation of "DRAW: A Recurrent Neural Network For Image Generation"

Primary LanguagePythonApache License 2.0Apache-2.0

draw

TensorFlow implementation of DRAW: A Recurrent Neural Network For Image Generation on the MNIST generation task.

With Attention Without Attention

Although open-source implementations of this paper already exist (see links below), this implementation focuses on simplicity and ease of understanding. I tried to make the code resemble the raw equations as closely as posible.

For a gentle walkthrough through the paper and implementation, see the writeup here: http://blog.evjang.com/2016/06/understanding-and-implementing.html.

Usage

python draw.py --data_dir=/tmp/draw downloads the binarized MNIST dataset to /tmp/draw/mnist and trains the DRAW model with attention enabled for both reading and writing. After training, output data is written to /tmp/draw/draw_data.npy

You can visualize the results by running the script python plot_data.py <prefix> <output_data>

For example,

python myattn /tmp/draw/draw_data.npy

To run training without attention, do:

python draw.py --working_dir=/tmp/draw --read_attn=False --write_attn=False

Restoring from Pre-trained Model

Instead of training from scratch, you can load pre-trained weights by uncommenting the following line in draw.py and editing the path to your checkpoint file as needed. Save electricity!

saver.restore(sess, "/tmp/draw/drawmodel.ckpt")

This git repository contains the following pre-trained in the data/ folder:

Filename Description
draw_data_attn.npy Training outputs for DRAW with attention
drawmodel_attn.ckpt Saved weights for DRAW with attention
draw_data_noattn.npy Training outputs for DRAW without attention
drawmodel_noattn.ckpt Saved weights for DRAW without attention

These were trained for 10000 iterations with minibatch size=100 on a GTX 970 GPU.

Useful Resources