This repository contains code to implement an adversarial autoencoder using TensorFlow.
Medium posts:
- A Wizard's guide to Adversarial Autoencoders: Part 1. Autoencoders?
- A Wizard's guide to Adversarial Autoencoders: Part 3. Disentanglement of style and content.
- A Wizard's guide to Adversarial Autoencoders: Part 4. Classify MNIST using 1000 labels.
Install virtualenv, then create and activate a new virtual environment:
pip install virtualenv
virtualenv -p /usr/bin/python3 aa
source aa/bin/activate
Install the dependencies:
pip3 install -r requirements.txt
Note:
- I'd highly recommend using a GPU for training.
- tf.nn.sigmoid_cross_entropy_with_logits has a targets parameter, which was renamed to labels in TensorFlow versions after r0.12.
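Because of this rename, calls written for older versions that pass targets= will not work on newer TensorFlow; passing labels= and logits= as keyword arguments avoids the problem. The sketch below reproduces, in plain Python without TensorFlow, the numerically stable formula that the TensorFlow documentation gives for this loss, so the roles of labels and logits are explicit. The helper only mirrors the TensorFlow function's name for readability; it is not the TensorFlow implementation.

```python
import math


def sigmoid_cross_entropy_with_logits(labels, logits):
    """Element-wise sigmoid cross-entropy (plain-Python sketch, not TensorFlow).

    Uses the numerically stable form max(x, 0) - x * z + log(1 + exp(-|x|)),
    where x is a logit and z is the matching label in [0, 1].
    """
    return [max(x, 0.0) - x * z + math.log1p(math.exp(-abs(x)))
            for z, x in zip(labels, logits)]


# With TensorFlow > r0.12 the real call takes keyword arguments, e.g.
# tf.nn.sigmoid_cross_entropy_with_logits(labels=targets, logits=logits)
losses = sigmoid_cross_entropy_with_logits(labels=[1.0, 0.0], logits=[2.0, -1.0])
print(losses)
```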
The MNIST dataset is downloaded automatically and made available in the ./Data directory.
To train a basic autoencoder, run:
python3 autoencoder.py --train True
- This trains an autoencoder and saves the trained model once every epoch in the ./Results/Autoencoder directory.
To load the trained model and generate images by passing inputs to the decoder, run:
python3 autoencoder.py --train False
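The scripts switch between training and loading a saved model through a --train True/False flag. How the repository parses that flag is not shown here, so the following is a hypothetical argparse sketch. It highlights one pitfall: argparse with type=bool converts any non-empty string, including "False", to True, so an explicit string-to-bool converter is needed.

```python
import argparse


def str_to_bool(value):
    """Convert 'True'/'False'-style strings to a bool.

    argparse's type=bool would call bool("False"), which is True,
    because any non-empty string is truthy.
    """
    if value.lower() in ("true", "1", "yes"):
        return True
    if value.lower() in ("false", "0", "no"):
        return False
    raise argparse.ArgumentTypeError(f"expected True or False, got {value!r}")


# Hypothetical CLI mirroring the --train flag used by the scripts.
parser = argparse.ArgumentParser(description="Train a model or load a saved one")
parser.add_argument("--train", type=str_to_bool, default=True,
                    help="True to train, False to load the saved model")

args = parser.parse_args(["--train", "False"])
print(args.train)  # False
```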
Training the adversarial autoencoder:
python3 adversarial_autoencoder.py --train True
Load the model and explore the latent space:
python3 adversarial_autoencoder.py --train False
Example of adversarial autoencoder output when the encoder is constrained to have a stddev of 5.
Matching prior and posterior distributions.
Distribution of digits in the latent space.
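In an adversarial autoencoder, a discriminator is trained to tell samples drawn from the chosen prior apart from the latent codes produced by the encoder, while the encoder is trained to fool it; this is what pushes the posterior toward the prior. A minimal sketch of the discriminator's inputs, assuming a 2-D Gaussian prior with a standard deviation of 5 and the convention that prior samples are labeled 1 and encoder codes 0 (both are assumptions, not taken from the repository's code):

```python
import random
import statistics


def sample_prior(batch_size, z_dim=2, stddev=5.0, rng=random):
    """Draw 'real' latent codes from the Gaussian prior N(0, stddev^2 * I).

    Assumption: a 2-D latent space with stddev 5, matching the example above.
    """
    return [[rng.gauss(0.0, stddev) for _ in range(z_dim)]
            for _ in range(batch_size)]


def discriminator_targets(batch_size):
    """Targets for the discriminator: 1 for prior samples, 0 for encoder codes."""
    return [1.0] * batch_size, [0.0] * batch_size


rng = random.Random(0)
z_real = sample_prior(10000, rng=rng)
first_dim = [z[0] for z in z_real]
print(round(statistics.pstdev(first_dim), 1))  # should be close to 5.0
```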
Training the supervised adversarial autoencoder:
python3 supervised_adversarial_autoencoder.py --train True
Load the model and explore the latent space:
python3 supervised_adversarial_autoencoder.py --train False
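In the supervised adversarial autoencoder, the decoder receives the one-hot class label together with the latent code, so the latent code only needs to capture style (for example slant or stroke thickness) while the label carries the content (which digit). The sketch below builds such a decoder input; the function names and the 2-D style vector are hypothetical choices for illustration.

```python
def one_hot(label, num_classes=10):
    """One-hot encode a digit label (10 classes for MNIST)."""
    vec = [0.0] * num_classes
    vec[label] = 1.0
    return vec


def decoder_input(label, style):
    """Concatenate the one-hot label (content) with the style code z."""
    return one_hot(label) + list(style)


z_style = [0.3, -1.2]          # hypothetical 2-D style code
x = decoder_input(7, z_style)
print(len(x))  # 12

# Fix the style and sweep the label: ten inputs sharing one style.
grid = [decoder_input(digit, z_style) for digit in range(10)]
```

Holding z_style fixed while sweeping the label gives ten decoder inputs that should render every digit in the same style, which is one way to visualize the disentanglement.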
Example of disentanglement of style and content:
Training:
python3 supervised_adversarial_autoencoder.py --train True
Load the model and explore the latent space:
python3 supervised_adversarial_autoencoder.py --train False
Classification accuracy for 1000 labeled images:
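In the semi-supervised setup, the encoder also outputs a categorical (softmax) distribution over the 10 digit classes, and an image is classified by taking its most probable class. Accuracy is then the fraction of images whose predicted class matches the true label. A plain-Python sketch, assuming the class probabilities are already available as lists (how the repository computes them is not shown here):

```python
def argmax(values):
    """Index of the largest value."""
    return max(range(len(values)), key=values.__getitem__)


def classification_accuracy(class_probs, labels):
    """Fraction of samples whose most probable class equals the true label."""
    if len(class_probs) != len(labels):
        raise ValueError("class_probs and labels must have the same length")
    correct = sum(argmax(p) == y for p, y in zip(class_probs, labels))
    return correct / len(labels)


probs = [[0.1, 0.9], [0.8, 0.2], [0.3, 0.7]]
print(classification_accuracy(probs, [1, 0, 0]))  # 0.6666666666666666
```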
Note:
- Each run generates the required TensorBoard files under the ./Results/<model>/<time_stamp_and_parameters>/Tensorboard directory.
- Use tensorboard --logdir <tensorboard_dir> to look at loss variations and the distributions of the latent code.
- Windows raises an error when : is used in a folder name, and the folder created for each run includes a time stamp that contains :. I suggest removing the time stamp from the folder_name variable in the form_results() function. Or, just dual boot Linux!
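An alternative to removing the time stamp is to format it without colons. The sketch below is hypothetical: safe_folder_name and its arguments are illustrative stand-ins, not the actual form_results() code, but the same time.strftime format string could replace a colon-based one.

```python
import time


def safe_folder_name(model_name, params, when=None):
    """Build a per-run folder name whose time stamp contains no ':'.

    Hypothetical helper: model_name and params stand in for whatever
    form_results() actually puts into folder_name.
    """
    stamp = time.strftime("%Y-%m-%d_%H-%M-%S", when or time.localtime())
    return f"{stamp}_{model_name}_{params}"


name = safe_folder_name(
    "Adversarial_Autoencoder", "z_dim_2_lr_0.001",
    when=time.strptime("2017-08-20 14:05:09", "%Y-%m-%d %H:%M:%S"))
print(name)  # 2017-08-20_14-05-09_Adversarial_Autoencoder_z_dim_2_lr_0.001
```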
Please share this repo if you find it helpful.