This repository contains the official implementation of our paper "Deterministic training of generative autoencoders using invertible layers".
In this work, we provide an exact likelihood alternative to the variational training of generative autoencoders. This is achieved while leaving complete freedom in the choice of encoder, decoder and prior architectures, making our approach a drop-in replacement for the training of existing VAEs and VAE-style models. We show that the approach results in strikingly higher performance than architecturally equivalent VAEs in terms of log-likelihood, sample quality and denoising performance.
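As a rough illustration of the difference between variational training and exact-likelihood training, the following PyTorch sketch maximizes the change-of-variables objective log p(x) = log p_Z(f(x)) + log |det Jf(x)| through a single toy coupling layer. This is only a generic, minimal example of exact-likelihood training with an invertible map, not the AEF architecture from the paper; see the paper and `notebook.ipynb` for the actual model.

```python
# Generic toy example of exact-likelihood training with an invertible layer
# (NOT the AEF model from the paper): maximize
#   log p(x) = log p_Z(f(x)) + log |det Jf(x)|
# directly, with no variational bound.
import torch
import torch.nn as nn

class AffineCoupling(nn.Module):
    """Invertible coupling layer: x2 is scaled and shifted conditioned on x1."""
    def __init__(self, dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim // 2, 64), nn.ReLU(),
            nn.Linear(64, dim),  # outputs log-scale and shift for x2
        )

    def forward(self, x):
        x1, x2 = x.chunk(2, dim=-1)
        log_s, t = self.net(x1).chunk(2, dim=-1)
        z2 = x2 * torch.exp(log_s) + t
        log_det = log_s.sum(dim=-1)  # log |det Jf(x)| of this layer
        return torch.cat([x1, z2], dim=-1), log_det

flow = AffineCoupling(dim=4)
optimizer = torch.optim.Adam(flow.parameters(), lr=1e-3)

x = torch.randn(256, 4)  # toy data batch
z, log_det = flow(x)
log_pz = torch.distributions.Normal(0.0, 1.0).log_prob(z).sum(dim=-1)
loss = -(log_pz + log_det).mean()  # exact negative log-likelihood
optimizer.zero_grad()
loss.backward()
optimizer.step()
```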
In `notebook.ipynb` we give a short tutorial on how to initialize and train an AEF model, and compare samples generated by an AEF to samples generated by a VAE with an equivalent architecture.
To run experiments we provide a command-line interface in `main_cli.py`, or alternatively `wandb_cli.py`, which uses wandb to save experiment details. To train an AEF with a center mask on the MNIST dataset with a latent dimensionality of 2, run:

```
./main_cli.py --model aef-center --dataset mnist --latent-dims 2
```
To reproduce the experiments on CelebA-HQ resized to 64x64 with a latent dimensionality of 128, run the following commands to train the AEF and the architecturally equivalent VAE, respectively:

```
./main_cli.py --model aef-linear --architecture big --posterior-flow maf --prior-flow maf --dataset celebahq64 --latent-dims 128 --iterations 1000000 --lr 1e-4 --batch-size 16 --early-stopping 100000 --data-dir [celebahq64-folder]
./main_cli.py --model vae --architecture big --posterior-flow iaf --prior-flow maf --dataset celebahq64 --latent-dims 128 --iterations 1000000 --lr 1e-4 --batch-size 16 --early-stopping 100000 --data-dir [celebahq64-folder]
```
For more details, please consult `main_cli.py`.
For the CelebA-HQ experiments we used the `data128x128.zip` file found here. The images can be resized to 64x64 with:

```
data/process_celebahq.py --data-dir "download_folder/data128x128" --output-folder "celebahq64" --dimension 64
```
| MNIST | FashionMNIST |
|---|---|
| KMNIST | CelebA-HQ |
This implementation uses parts of the code from the following GitHub repositories: nflows, rectangular-flows, and pytorch-fid, as described in our code.