TensorFlow implementation of Spatial VAE via Matrix-Variate Normal Distributions

Spatial VAE via Matrix-Variate Normal Distributions

This is the TensorFlow implementation of our recent work, "Spatial Variational Auto-Encoding via Matrix-Variate Normal Distributions". Please check the paper for details.

Introduction

In this work, we propose spatial VAEs that use latent variables as feature maps of larger size to explicitly capture spatial information. This is achieved by allowing the latent variables to be sampled from matrix-variate normal (MVN) distributions whose parameters are computed from the encoder network.
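To make the sampling step concrete, here is a minimal NumPy sketch of drawing a latent feature map from an MVN distribution. It is a generic illustration and not the code in this repository: the function name `sample_matrix_normal` is hypothetical, and the mean and the two covariance matrices are fixed arrays here, whereas in the model they would be computed by the encoder network. It relies on the standard identity X = M + A Z B^T, where the row covariance is A A^T, the column covariance is B B^T, and Z has i.i.d. standard normal entries.

```python
import numpy as np


def sample_matrix_normal(mean, row_cov, col_cov, rng=None):
    """Draw one sample X ~ MN(mean, row_cov, col_cov).

    mean:    (n, p) mean matrix
    row_cov: (n, n) symmetric positive-definite covariance among rows
    col_cov: (p, p) symmetric positive-definite covariance among columns
    """
    rng = np.random.default_rng() if rng is None else rng
    n, p = mean.shape
    a = np.linalg.cholesky(row_cov)   # row_cov = a @ a.T
    b = np.linalg.cholesky(col_cov)   # col_cov = b @ b.T
    z = rng.standard_normal((n, p))   # i.i.d. N(0, 1) entries
    return mean + a @ z @ b.T


# Example: one 4 x 4 latent "feature map" with identity covariances.
latent = sample_matrix_normal(np.zeros((4, 4)), np.eye(4), np.eye(4))
```

Because the noise `z` enters through a deterministic transform of the parameters, the same form can be used for the reparameterization trick during training.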

If you wish to cite our work, you can use the following BibTeX entry for now.

@article{wang2017svae,
  title={Spatial Variational Auto-Encoding via Matrix-Variate Normal Distributions},
  author={Zhengyang Wang and Hao Yuan and Shuiwang Ji},
  journal={arXiv preprint arXiv:1705.06821},
  year={2017}
}

Experimental results:

  1. CelebA dataset

image

  2. CIFAR-10 dataset

image

In both figures above, the first and second rows show training images and images generated by the original VAEs, respectively. The remaining three rows are the results of naïve spatial VAEs, spatial VAEs via MVN distributions, and spatial VAEs via low-rank MVN distributions, respectively.

For details of our work, please refer to the paper, "Spatial Variational Auto-Encoding via Matrix-Variate Normal Distributions" (arXiv:1705.06821).

Datasets:

Our experiments are based on the CIFAR-10 and cropped CelebA datasets. We provide data readers for both datasets. For the CelebA dataset, please convert it to an h5 file first, then call the data reader.
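The conversion to an h5 file could look roughly like the sketch below, which uses `h5py`. The dataset key `"data"`, the `uint8` storage type, and both function names are assumptions made for illustration; check the provided CelebA data reader for the key and layout it actually expects.

```python
import h5py
import numpy as np


def save_images_to_h5(images, h5_path, dataset_name="data"):
    """Write an (N, H, W, C) image array to an HDF5 file.

    dataset_name is a hypothetical key; use whatever the data reader expects.
    """
    images = np.asarray(images, dtype=np.uint8)
    with h5py.File(h5_path, "w") as f:
        f.create_dataset(dataset_name, data=images)


def load_images_from_h5(h5_path, dataset_name="data"):
    """Read the whole image array back into memory."""
    with h5py.File(h5_path, "r") as f:
        return f[dataset_name][:]
```

Loading and cropping the raw CelebA images into the `images` array is left out here, since the crop size depends on your setup.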

You can also use other datasets, such as the Street View House Numbers (SVHN) dataset or the LSUN dataset. You only need to write a simple data reader file for them.
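As a starting point for a custom data reader, here is a minimal sketch that serves shuffled mini-batches from an in-memory array. The class name `ArrayDataReader` and the `next_batch` method are hypothetical and do not necessarily match the interface of the readers shipped in this repository, so adapt the names to what the model code calls.

```python
import numpy as np


class ArrayDataReader:
    """Hypothetical reader: shuffled mini-batches from an (N, H, W, C) uint8 array."""

    def __init__(self, images, batch_size, seed=0):
        images = np.asarray(images)
        if batch_size > len(images):
            raise ValueError("batch_size is larger than the dataset")
        # Scale pixel values from [0, 255] to [0, 1].
        self.images = images.astype(np.float32) / 255.0
        self.batch_size = batch_size
        self.rng = np.random.default_rng(seed)
        self._order = self.rng.permutation(len(self.images))
        self._pos = 0

    def next_batch(self):
        # Reshuffle and start a new epoch when a full batch is no longer available.
        if self._pos + self.batch_size > len(self.images):
            self._order = self.rng.permutation(len(self.images))
            self._pos = 0
        idx = self._order[self._pos:self._pos + self.batch_size]
        self._pos += self.batch_size
        return self.images[idx]
```

A reader for SVHN or LSUN would only need to load the images into an array of this shape before handing them to the class.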

How to run it

  1. Clone or download this repository to your working directory.
  2. Get the datasets ready.
  3. Set the related arguments in main.py. There are two models: the vanilla VAE and our spatial VAE. Set "model_name" to 'vanilla' to use the traditional VAE, or to 'low_rank' to use our spatial VAE.
  4. Call python main.py or python main.py --action=train to train the model.
  5. If you wish to evaluate the model with the "Parzen window" estimate, set a checkpoint in the arguments for the model to reload, then call python main.py --action=test.
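For readers unfamiliar with the Parzen window evaluation mentioned in the last step, the sketch below shows the general technique in NumPy: a Gaussian kernel density estimate is fitted to generated samples and the mean log-likelihood of held-out test images is reported. This is an illustration under stated assumptions, not the evaluation code in main.py. The function name `parzen_log_likelihood` is hypothetical, and the bandwidth `sigma` is assumed to be chosen separately (for example on a validation set).

```python
import numpy as np


def parzen_log_likelihood(test_x, generated, sigma):
    """Mean log-likelihood of test_x under a Gaussian Parzen window.

    test_x:    (M, ...) held-out samples
    generated: (N, ...) samples drawn from the model
    sigma:     kernel bandwidth (assumed to be tuned elsewhere)
    """
    test_x = np.asarray(test_x, dtype=np.float64)
    generated = np.asarray(generated, dtype=np.float64)
    test_x = test_x.reshape(test_x.shape[0], -1)
    generated = generated.reshape(generated.shape[0], -1)
    n, d = generated.shape

    # Pairwise squared Euclidean distances, shape (M, N).
    sq_dist = (
        np.sum(test_x ** 2, axis=1)[:, None]
        - 2.0 * test_x @ generated.T
        + np.sum(generated ** 2, axis=1)[None, :]
    )
    exponent = -np.maximum(sq_dist, 0.0) / (2.0 * sigma ** 2)

    # Numerically stable log-sum-exp over the generated samples.
    max_exp = exponent.max(axis=1, keepdims=True)
    log_sum = max_exp[:, 0] + np.log(np.exp(exponent - max_exp).sum(axis=1))

    # Normalizer: number of kernels times the Gaussian constant.
    log_norm = np.log(n) + 0.5 * d * np.log(2.0 * np.pi * sigma ** 2)
    return float(np.mean(log_sum - log_norm))
```

The pairwise distance matrix grows as M x N, so for large sample sets the test images would need to be processed in chunks.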