Adversarial-Autoencoder

A convolutional adversarial autoencoder implementation in PyTorch, using the WGAN with gradient penalty (WGAN-GP) framework.

There is a lot to tweak here in balancing the adversarial loss against the reconstruction loss, but the current setup works, and I'll update it as I go along. A rough sketch of the balancing is shown below.
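
For reference, here is a minimal sketch of how the two objectives might be combined in one encoder/decoder update. The module names (`encoder`, `decoder`, `critic`) and the `lambda_adv` weight are illustrative assumptions, not this repository's actual API, and the critic is taken to score latent codes against the Gaussian prior as in the standard adversarial autoencoder setup:

```python
import torch.nn.functional as F

def encoder_decoder_loss(encoder, decoder, critic, x, lambda_adv=1.0):
    """Combined loss for one encoder/decoder update.

    All names and the lambda_adv weight are illustrative assumptions,
    not this repository's actual training code.
    """
    z = encoder(x)                       # latent code for a real batch x
    x_recon = decoder(z)                 # reconstruction of the input
    recon_loss = F.mse_loss(x_recon, x)  # pixel-wise reconstruction term
    adv_loss = -critic(z).mean()         # WGAN term: make codes look like prior samples
    return recon_loss + lambda_adv * adv_loss
```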

The MNIST GAN seems to converge at around 30k steps, while CIFAR10 arguably never produces anything truly realistic (compared to ACGAN). Nonetheless, it starts to look OK at around 50k steps.

The autoencoder components produce good reconstructions much faster than the GAN: roughly 10k steps on MNIST. The autoencoder currently does badly on CIFAR10 (under investigation).

Note

There is a lot here that I want to add and fix with regard to image generation and large-scale training.

But I can't do anything until PyTorch fixes these issues with gradient penalty here.
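
For context, the penalty in question is the standard WGAN-GP term, which requires second-order gradients (backpropagating through a gradient norm). A minimal sketch of that term, mirroring the common formulation rather than this repo's exact code; `critic`, the tensor shapes, and `lambda_gp` are assumptions:

```python
import torch

def gradient_penalty(critic, real, fake, lambda_gp=10.0):
    """Standard WGAN-GP penalty. `critic`, shapes, and lambda_gp are
    illustrative assumptions, not this repository's exact code."""
    # Random interpolation between real and fake samples.
    alpha = torch.rand(real.size(0), *([1] * (real.dim() - 1)), device=real.device)
    interp = (alpha * real + (1 - alpha) * fake).requires_grad_(True)

    score = critic(interp)

    # Gradient of the critic's score w.r.t. the interpolated inputs;
    # create_graph=True so the penalty itself can be backpropagated.
    grads = torch.autograd.grad(
        outputs=score,
        inputs=interp,
        grad_outputs=torch.ones_like(score),
        create_graph=True,
        retain_graph=True,
    )[0]

    # Penalize deviation of the per-sample gradient norm from 1.
    grads = grads.view(grads.size(0), -1)
    return lambda_gp * ((grads.norm(2, dim=1) - 1) ** 2).mean()
```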

MNIST Gaussian Samples (GAN) - 33k steps

output image

MNIST Reconstructions (AE) - 10k steps

output image

CelebA 64x64 Gaussian Samples (GAN) - 50k steps

output image

CIFAR10 Gaussian Samples (GAN) - 200k steps

output image

CIFAR10 Reconstructions (AE) - 80k steps

output image

Clearly this needs fixing.

Requirements

  • PyTorch 0.2.0
  • Python 3 (2.7 works with some simple modifications)
  • matplotlib / numpy / scipy

Usage

To start training right away, just run:

start.sh

To train on MNIST:

python3 train.py --dataset mnist --batch_size 50 --dim 32 -o 784

To train on CIFAR10:

python3 train.py --dataset cifar10 --batch_size 64 --dim 32 -o 3072
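
(In both commands, -o appears to set the flattened output dimensionality: 784 = 28×28 for MNIST and 3072 = 3×32×32 for CIFAR10.)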

Acknowledgements

For the WGAN-GP components I mostly used caogang's nice implementation.

TODO

Provide pretrained model files in a Google Drive folder (push loader).