/Tensorflow_WhatWhereAutoencoder

Stacked What-Where Auto-encoders implementation with TensorFlow


What-Where autoencoder. Tensorflow

This project contains a TensorFlow implementation of Stacked What-Where Auto-encoders. The implementation uses the transposed convolutions provided by TensorFlow together with custom upsampling and unpooling code.
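In a what-where autoencoder, each max-pooling layer produces two outputs: the pooled maxima ("what"), which go to the next encoder layer, and the positions of those maxima ("where"), which go to the matching decoder layer for unpooling. The following is a minimal, framework-free sketch of that pooling step, not the repository's code. It assumes a single-channel 2-D map, non-overlapping k×k windows (stride = k), and row-major flat indices; the function name `max_pool_with_argmax` is chosen for illustration.

```python
from typing import List, Tuple

Matrix = List[List[float]]


def max_pool_with_argmax(x: Matrix, k: int) -> Tuple[Matrix, List[List[int]]]:
    """Non-overlapping k x k max pooling (stride = k) on a single-channel 2-D map.

    Returns the pooled maxima ("what") and, for every window, the row-major
    flat index of its maximum in the input ("where").
    """
    h, w = len(x), len(x[0])
    if h % k or w % k:
        raise ValueError("input height and width must be multiples of k")
    what, where = [], []
    for i in range(0, h, k):
        what_row, where_row = [], []
        for j in range(0, w, k):
            best_val, best_idx = x[i][j], i * w + j
            # Scan the window; strict '>' keeps the first position of the maximum.
            for di in range(k):
                for dj in range(k):
                    v = x[i + di][j + dj]
                    if v > best_val:
                        best_val, best_idx = v, (i + di) * w + (j + dj)
            what_row.append(best_val)
            where_row.append(best_idx)
        what.append(what_row)
        where.append(where_row)
    return what, where


x = [
    [1.0, 3.0, 0.0, 2.0],
    [4.0, 2.0, 1.0, 5.0],
    [0.0, 1.0, 7.0, 0.0],
    [6.0, 0.0, 2.0, 3.0],
]
what, where = max_pool_with_argmax(x, 2)
print(what)   # [[4.0, 5.0], [6.0, 7.0]]
print(where)  # [[4, 7], [12, 10]]
```

Take the first 2×2 window as an example. It holds 1, 3, 4 and 2, so its maximum 4.0 is recorded together with its flat index 1 * 4 + 0 = 4.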

Note: at the time of writing, the unpooling code does not work with TensorFlow 1.13 and 2.0.0-alpha (the latest versions available on pip) because of a bug in scatter_nd. The bug has been fixed in the nightly builds, and the fix should be included in the next TensorFlow release. Until then, please use one of the following TensorFlow versions:

$ pip install tf-nightly-2.0-preview 
$ pip install tf-nightly-gpu-2.0-preview
$ pip install tf-nightly # 1.14 preview
$ pip install tensorflow-gpu==1.12.0
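The scatter_nd bug breaks unpooling because unpooling is essentially a scatter. Each pooled value is written into a zero tensor at its stored "where" index. The pure-Python sketch below shows this on a flat buffer. It follows the documented behaviour of `tf.scatter_nd`, where updates that share an index are summed, but it is an illustration only, not the repository's code. The helper names `scatter_1d` and `unpool_flat` are hypothetical.

```python
from typing import List, Sequence


def scatter_1d(indices: Sequence[int], updates: Sequence[float], size: int) -> List[float]:
    """Write updates into a zero buffer of length size at the given indices.

    As with tf.scatter_nd, updates that target the same index are summed.
    """
    out = [0.0] * size
    for idx, val in zip(indices, updates):
        out[idx] += val
    return out


def unpool_flat(values: Sequence[float], where: Sequence[int], h: int, w: int) -> List[List[float]]:
    """Unpool by scattering pooled values to their flat argmax positions,
    then reshaping the flat buffer to h x w in row-major order."""
    flat = scatter_1d(where, values, h * w)
    return [flat[r * w:(r + 1) * w] for r in range(h)]


# Pooled maxima ("what") and their row-major positions ("where")
# from a 2x2 max pooling of a 4x4 map.
values = [4.0, 5.0, 6.0, 7.0]
where = [4, 7, 12, 10]
for row in unpool_flat(values, where, 4, 4):
    print(row)
```

Every position that was not a maximum stays zero. That sparsity is what separates unpooling from plain upsampling.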

Features

Outputs

Run TensorBoard to visualize the results:

tensorboard --logdir=./tmp/

Output example

The picture above shows an original MNIST image (left), its reconstruction by the what-where autoencoder (center), and its reconstruction by a convolutional autoencoder with naive upsampling (right); both autoencoders use stride=7. The picture repeats the experiment from the original paper.
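With stride=7 on a 28×28 MNIST image, pooling leaves a 4×4 map (28 / 7 = 4). The decoder must therefore rebuild each 7×7 block from a single value. The sketch below assumes that "naive upsampling" means nearest-neighbour copying with a 7×7 window equal to the stride. Under that assumption, every pixel of a block receives the same pooled value, so the position of the maximum inside the block is lost. This is an illustration, not the repository's code, and `naive_upsample` is a hypothetical name.

```python
from typing import List

Matrix = List[List[float]]


def naive_upsample(x: Matrix, k: int) -> Matrix:
    """Nearest-neighbour upsampling: copy every value into a k x k block."""
    return [[v for v in row for _ in range(k)] for row in x for _ in range(k)]


side, stride = 28, 7
pooled_side = side // stride  # 28 // 7 = 4, so only 16 values survive pooling
pooled = [[float(r * pooled_side + c) for c in range(pooled_side)]
          for r in range(pooled_side)]

up = naive_upsample(pooled, stride)
print(len(up), len(up[0]))  # 28 28

# The 7x7 block at block-row 1, block-col 2 repeats pooled[1][2] == 6.0 everywhere.
block = [up[i][j] for i in range(7, 14) for j in range(14, 21)]
print(len(block), set(block))  # 49 {6.0}
```

Unpooling with the stored "where" positions would instead write each value to a single pixel of its block. At most 16 of the 784 pixels would then be non-zero, and those pixels keep their original positions.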

Model graph

Dependencies

  • Python 3.5
  • Tensorflow 1.0 with GPU support
  • Numpy
pip3 install tensorflow-gpu numpy

Running the model

Run the training scripts:

python WWAE.py --batch_size=128 --max_epochs=2 --pool_size=7
python WWAE_keras.py --batch_size=128 --max_epochs=2 --pool_size=7
python WWAE_tf2.0.py
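The commands above pass `--batch_size`, `--max_epochs` and `--pool_size`. The scripts' actual flag handling and default values are not shown here. The sketch below is therefore a hypothetical `argparse` parser that accepts the same flags. Its defaults, and the reading of `pool_size` as both window size and stride, are assumptions. It also adds an illustrative check that the pool size divides the 28-pixel MNIST side, since 7 gives the 4×4 map used in the experiment. The code avoids f-strings so that it runs on the Python 3.5 listed under Dependencies.

```python
import argparse
from typing import List, Optional


def parse_args(argv: Optional[List[str]] = None) -> argparse.Namespace:
    # Hypothetical parser mirroring the flags in the commands above;
    # the defaults are illustrative, not taken from WWAE.py.
    parser = argparse.ArgumentParser(description="Train a what-where autoencoder on MNIST")
    parser.add_argument("--batch_size", type=int, default=128, help="images per training step")
    parser.add_argument("--max_epochs", type=int, default=1, help="passes over the training set")
    parser.add_argument("--pool_size", type=int, default=2,
                        help="pooling window size, assumed equal to the stride")
    return parser.parse_args(argv)


args = parse_args(["--batch_size=128", "--max_epochs=2", "--pool_size=7"])
print(args.batch_size, args.max_epochs, args.pool_size)  # 128 2 7

# Illustrative sanity check: a 28x28 MNIST image should divide evenly by the pool size.
side = 28
if side % args.pool_size:
    raise ValueError("pool_size={} does not divide {}".format(args.pool_size, side))
print(side // args.pool_size)  # 4
```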