

Primary language: Python. License: MIT.

Adversarial Latent Autoencoders (ALAE) with TF2

Frédéric TOST

MNIST dataset, ConvNet implementation with TensorFlow 2

Generated images

   

ALAE

Content

This is a Python/TensorFlow 2.0 implementation of Adversarial Latent Autoencoders (ALAE). See the reference below:

  • Stanislav Pidhorskyi, Donald A. Adjeroh, and Gianfranco Doretto. Adversarial Latent Autoencoders. In Proceedings of the IEEE Computer Society Conference on Computer Vision and Pattern Recognition (CVPR), 2020 [to appear]. Preprint on arXiv: 2004.04467.

The MNIST dataset is used as a toy example. The generator G and the encoder E use Conv2D and Conv2DTranspose layers instead of the multi-layer perceptron (MLP) used in the paper. This gives better results, at the cost of longer training.
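A minimal sketch of what convolutional versions of E and G can look like for 28x28 MNIST digits follows. It is not the code from alae_tf2_models.py: the layer widths, kernel sizes, LATENT_DIM = 50 and the sigmoid output (which assumes pixels scaled to [0, 1]) are all assumptions chosen for illustration.

```python
import tensorflow as tf

LATENT_DIM = 50  # hypothetical size of the latent code w


def make_encoder(latent_dim=LATENT_DIM):
    """Encoder E (sketch): 28x28x1 image -> latent vector w."""
    return tf.keras.Sequential([
        tf.keras.Input(shape=(28, 28, 1)),
        # stride 2 with padding="same" halves the spatial size: 28 -> 14 -> 7
        tf.keras.layers.Conv2D(32, 3, strides=2, padding="same", activation="relu"),
        tf.keras.layers.Conv2D(64, 3, strides=2, padding="same", activation="relu"),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(latent_dim),
    ], name="E")


def make_generator(latent_dim=LATENT_DIM):
    """Generator G (sketch): latent vector w -> 28x28x1 image in [0, 1]."""
    return tf.keras.Sequential([
        tf.keras.Input(shape=(latent_dim,)),
        tf.keras.layers.Dense(7 * 7 * 64, activation="relu"),
        tf.keras.layers.Reshape((7, 7, 64)),
        # stride 2 with padding="same" doubles the spatial size: 7 -> 14 -> 28
        tf.keras.layers.Conv2DTranspose(32, 3, strides=2, padding="same", activation="relu"),
        tf.keras.layers.Conv2DTranspose(1, 3, strides=2, padding="same", activation="sigmoid"),
    ], name="G")
```

Two stride-2 layers on each side are what make 28 = 7 x 2 x 2 work without cropping or padding the MNIST images.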


Objective

The objective is to show how easily ALAE can be implemented with the MNIST dataset and convolutional networks. Finding the hyperparameters, such as the learning rate of each optimizer, is the most tedious task.
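One way to organise such a hyperparameter search with TensorBoard's HParams API is sketched below. It is an illustration, not the repository's code: the two learning-rate grids, the metric name loss_latent and the fake_train placeholder (standing in for a real training run) are hypothetical. Each combination is written to its own run directory so the dashboard can compare them side by side.

```python
import itertools
import os
import tempfile

import tensorflow as tf
from tensorboard.plugins.hparams import api as hp

# Hypothetical search space: one learning rate per optimizer.
HP_LR_D = hp.HParam("lr_discriminator", hp.Discrete([1e-4, 2e-4]))
HP_LR_G = hp.HParam("lr_generator", hp.Discrete([1e-4, 2e-4]))
HPARAMS = [HP_LR_D, HP_LR_G]
METRIC = "loss_latent"  # hypothetical metric name


def fake_train(params):
    """Placeholder for a real training run; returns a final metric value."""
    return params[HP_LR_D] + params[HP_LR_G]


def run_grid(logdir):
    # Declare the hyperparameters and metrics once for the whole experiment.
    writer = tf.summary.create_file_writer(logdir)
    with writer.as_default():
        hp.hparams_config(hparams=HPARAMS,
                          metrics=[hp.Metric(METRIC, display_name="Latent loss")])
    writer.close()

    run_dirs = []
    grid = itertools.product(*(h.domain.values for h in HPARAMS))
    for i, values in enumerate(grid):
        params = dict(zip(HPARAMS, values))
        run_dir = os.path.join(logdir, f"run-{i}")
        run_writer = tf.summary.create_file_writer(run_dir)
        with run_writer.as_default():
            hp.hparams(params)  # record the values used by this run
            tf.summary.scalar(METRIC, fake_train(params), step=1)
        run_writer.close()
        run_dirs.append(run_dir)
    return run_dirs


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as logdir:
        print(run_grid(logdir))
```

With a persistent log directory, running tensorboard --logdir on it and opening the HPARAMS tab lists one row per run.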

Additional features

  • The TensorFlow 2.0 HParams dashboard is used to keep a trace of each run and of the hyperparameters it used.

  • Time history of the 3 losses (figure "Loss time history", panels: Generator, Discriminator, Latent).
  • The latent loss is split into a reconstruction loss and a Kullback-Leibler (KL) loss. The KL loss is not used in the original paper, but it seems to speed up the convergence of the generator/discriminator.

    K_RECONST_KL = 1.0 # To use full reconstruction (alae_tf2.py)
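The three losses can be written down from the objectives in the ALAE paper. The framework-free sketch below computes them on plain Python lists of discriminator logits and latent vectors; it is not the code in alae_tf2_helper.py. Two parts are assumptions: how K_RECONST_KL combines the latent terms (taken here as k * reconstruction + (1 - k) * KL, so 1.0 keeps only the reconstruction term) and the KL term itself (a diagonal Gaussian fitted to each batch of latent vectors). The R1 penalty is passed in precomputed, and gamma = 10.0 is also an assumed default.

```python
import math


def softplus(t):
    # Numerically stable log(1 + exp(t)).
    return max(t, 0.0) + math.log1p(math.exp(-abs(t)))


def mean(xs):
    return sum(xs) / len(xs)


def discriminator_loss(real_logits, fake_logits, r1_penalty=0.0, gamma=10.0):
    # E, D loss: softplus(D(E(G(F(z))))) + softplus(-D(E(x))) + gamma/2 * R1
    return (mean([softplus(f) for f in fake_logits])
            + mean([softplus(-r) for r in real_logits])
            + 0.5 * gamma * r1_penalty)


def generator_loss(fake_logits):
    # F, G loss: softplus(-D(E(G(F(z)))))
    return mean([softplus(-f) for f in fake_logits])


def reconstruction_loss(w, w_rec):
    # Squared L2 distance between w = F(z) and E(G(w)), averaged over the batch.
    return mean([sum((a - b) ** 2 for a, b in zip(u, v)) for u, v in zip(w, w_rec)])


def gaussian_kl(w, w_rec, eps=1e-8):
    # Hypothetical KL term: per latent dimension, fit a Gaussian to each batch and
    # sum KL(N(mu_rec, var_rec) || N(mu, var)) over the dimensions.
    kl = 0.0
    for d in range(len(w[0])):
        a = [row[d] for row in w]
        b = [row[d] for row in w_rec]
        mu_a, mu_b = mean(a), mean(b)
        var_a = mean([(t - mu_a) ** 2 for t in a]) + eps
        var_b = mean([(t - mu_b) ** 2 for t in b]) + eps
        kl += 0.5 * (math.log(var_a / var_b) + (var_b + (mu_b - mu_a) ** 2) / var_a - 1.0)
    return kl


def latent_loss(w, w_rec, k_reconst_kl=1.0):
    # Assumed blend: k = 1.0 keeps only the reconstruction term.
    return (k_reconst_kl * reconstruction_loss(w, w_rec)
            + (1.0 - k_reconst_kl) * gaussian_kl(w, w_rec))
```

In a TensorFlow training loop the same expressions map onto tf.nn.softplus and tf.reduce_mean over batched tensors.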

To run the demo

To run the demo, you need TensorFlow 2.0.0 or more recent (e.g. 2.1.0 or 2.2.0) installed.
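A quick check of this requirement before launching the demo is sketched below using only the standard library (importlib.metadata needs Python 3.8+). It assumes TensorFlow was installed under the distribution name tensorflow; variants such as tensorflow-cpu would need a different name. The parser keeps only the leading major, minor and patch numbers, so suffixes like rc1 are ignored.

```python
from importlib.metadata import PackageNotFoundError, version

MIN_TF = (2, 0, 0)  # minimum version stated in this README


def parse_version(text):
    """Turn '2.1.0' or '2.2.0rc1' into a comparable (major, minor, patch) tuple."""
    parts = []
    for piece in text.split(".")[:3]:
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        parts.append(int(digits) if digits else 0)
    while len(parts) < 3:
        parts.append(0)
    return tuple(parts)


def tensorflow_ok(minimum=MIN_TF):
    """Return True if an installed 'tensorflow' distribution meets the minimum."""
    try:
        installed = version("tensorflow")
    except PackageNotFoundError:
        return False
    return parse_version(installed) >= minimum
```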

Run the demo

python alae_tf2.py

Repository structure

| Path | Description |
|---|---|
| alae_tf2.py | Configures the hyperparameters and runs the demo. |
| alae_tf2_helper.py | Trains the neural network. alae_helper is the class that defines the losses and the train step function. |
| alae_tf2_models.py | Models used in the demo: F (mapping network) and E (encoder), the generator G and the discriminator D (4 classes). |
| utils.py | Helper functions to process and plot images (2 functions). |
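alae_tf2_helper.py holds the train step. The following is a minimal sketch of one ALAE training step following the three updates of the paper (E and D, then F and G, then E and G on the latent loss); it is not the repository's code. Small dense networks stand in for the convolutional models, and LATENT_DIM, DATA_DIM, GAMMA and the learning rates are assumptions.

```python
import tensorflow as tf

LATENT_DIM = 8   # hypothetical size of z and w
DATA_DIM = 16    # hypothetical flattened input size (784 for MNIST)
GAMMA = 10.0     # hypothetical weight of the R1 gradient penalty


def mlp(in_dim, out_dim, name):
    # Small placeholder network; the repository's G and E are convolutional.
    return tf.keras.Sequential([
        tf.keras.Input(shape=(in_dim,)),
        tf.keras.layers.Dense(32, activation=tf.nn.leaky_relu),
        tf.keras.layers.Dense(out_dim),
    ], name=name)


F = mlp(LATENT_DIM, LATENT_DIM, "F")  # mapping network: z -> w
G = mlp(LATENT_DIM, DATA_DIM, "G")    # generator: w -> x
E = mlp(DATA_DIM, LATENT_DIM, "E")    # encoder: x -> w
D = mlp(LATENT_DIM, 1, "D")           # discriminator head on w: w -> logit

# One optimizer per update, since each step trains a different set of networks.
opt_d = tf.keras.optimizers.Adam(2e-4)  # E and D
opt_g = tf.keras.optimizers.Adam(2e-4)  # F and G
opt_l = tf.keras.optimizers.Adam(2e-4)  # E and G


def train_step(x):
    x = tf.convert_to_tensor(x, dtype=tf.float32)
    batch = tf.shape(x)[0]

    # Step 1: adversarial loss for E and D, with an R1 penalty on real samples.
    z = tf.random.normal([batch, LATENT_DIM])
    with tf.GradientTape() as tape:
        with tf.GradientTape() as r1_tape:
            r1_tape.watch(x)
            real_logits = D(E(x))
        grad_x = r1_tape.gradient(real_logits, x)
        r1 = tf.reduce_mean(tf.reduce_sum(tf.square(grad_x), axis=1))
        fake_logits = D(E(G(F(z))))
        loss_d = (tf.reduce_mean(tf.nn.softplus(fake_logits))
                  + tf.reduce_mean(tf.nn.softplus(-real_logits))
                  + 0.5 * GAMMA * r1)
    vars_d = E.trainable_variables + D.trainable_variables
    opt_d.apply_gradients(zip(tape.gradient(loss_d, vars_d), vars_d))

    # Step 2: adversarial loss for F and G.
    z = tf.random.normal([batch, LATENT_DIM])
    with tf.GradientTape() as tape:
        loss_g = tf.reduce_mean(tf.nn.softplus(-D(E(G(F(z))))))
    vars_g = F.trainable_variables + G.trainable_variables
    opt_g.apply_gradients(zip(tape.gradient(loss_g, vars_g), vars_g))

    # Step 3: latent reconstruction loss ||w - E(G(w))||^2 with w = F(z), for E and G.
    z = tf.random.normal([batch, LATENT_DIM])
    with tf.GradientTape() as tape:
        w = F(z)
        loss_l = tf.reduce_mean(tf.reduce_sum(tf.square(w - E(G(w))), axis=1))
    vars_l = E.trainable_variables + G.trainable_variables
    opt_l.apply_gradients(zip(tape.gradient(loss_l, vars_l), vars_l))

    return loss_d, loss_g, loss_l
```

The step is kept eager for readability; wrapping train_step in tf.function would usually make it faster.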

Authors