vmp-for-svae

Variational Message Passing for Structured VAE (Code for the ICLR 2018 paper by Wu Lin, Nicolas Hubacher and Mohammad Emtiyaz Khan)

Getting Started

Before running our code, create a conda environment using the file environment.yml. To do so, open a terminal and run: conda env create -f environment.yml

Then, activate the created environment: source activate san-cpu-env

If you don't want to use conda, make sure to install the libraries listed in environment.yml in their specified versions (most importantly, TensorFlow 1.3).
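If you install the libraries manually, a quick check that the installed TensorFlow matches the required 1.3 release can save some debugging. The sketch below is our own illustration and not part of the repository; the helper name is_supported_tf_version is hypothetical, and it only compares the major and minor version numbers, so any 1.3.x release is accepted.

```python
def is_supported_tf_version(version, required=(1, 3)):
    """Return True if a version string such as "1.3.0" has the required major.minor.

    Hypothetical helper for illustration; it is not part of the repository.
    """
    parts = version.split(".")
    try:
        major, minor = int(parts[0]), int(parts[1])
    except (IndexError, ValueError):
        # Strings that cannot be parsed (e.g. "1" or "abc") are treated as unsupported.
        return False
    return (major, minor) == tuple(required)
```

After installation, you could call it as is_supported_tf_version(tf.__version__) following import tensorflow as tf.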

Please note that, for simplicity, environment.yml only includes the CPU version of TensorFlow. If you want to use a GPU-enabled version of TensorFlow, follow the TensorFlow installation guide.

Running the Code

Run experiments.py to execute our algorithm. Several options can be set at the beginning of this script; for instance, training can be distributed across multiple GPUs.
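This README does not describe the multi-GPU code in detail beyond crediting the model replica approach in the acknowledgements. As a rough illustration of that general idea (synchronous data parallelism), the NumPy sketch below splits a batch into equal shards, computes one gradient per "replica" with shared parameters, and averages the results. It uses a toy least-squares loss rather than the SVAE objective, and the function names are hypothetical. For a loss that is a mean over examples and equal-sized shards, the averaged gradient equals the full-batch gradient.

```python
import numpy as np


def mse_gradient(w, X, y):
    """Gradient of 0.5 * mean((X @ w - y) ** 2) with respect to w."""
    residual = X @ w - y
    return X.T @ residual / X.shape[0]


def split_and_average_gradients(w, X, y, num_replicas):
    """Simulate synchronous data-parallel training with num_replicas model copies.

    Every replica shares the parameters w, receives an equal shard of the batch and
    computes its own gradient; the gradients are then averaged into one update.
    Hypothetical helper for illustration only.
    """
    if X.shape[0] % num_replicas != 0:
        raise ValueError("batch size must be divisible by the number of replicas")
    X_shards = np.split(X, num_replicas)
    y_shards = np.split(y, num_replicas)
    grads = [mse_gradient(w, Xs, ys) for Xs, ys in zip(X_shards, y_shards)]
    return np.mean(grads, axis=0)


rng = np.random.RandomState(0)
X = rng.normal(size=(8, 3))
y = rng.normal(size=8)
w = rng.normal(size=3)

full = mse_gradient(w, X, y)
averaged = split_and_average_gradients(w, X, y, num_replicas=4)
print(np.allclose(full, averaged))  # True: equal shards reproduce the full-batch gradient
```

In a real multi-GPU setup each shard's gradient would be computed on a different device before averaging; here all replicas run sequentially on the CPU.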

Next, the experimental setup can be defined there: dataset, step size, neural network architecture, etc. One or more experiment configurations can be listed in the variable schedule; they are executed consecutively.
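The exact format of schedule is defined in experiments.py and is not documented here. Purely as a hypothetical sketch of the idea, assuming each configuration is a dictionary, the snippet below builds a small grid of configurations and runs them one after another. The keys (dataset, stepsize, hidden_units), the dataset name, and the function run_experiment are placeholders, not the script's actual names.

```python
import itertools


def build_schedule(datasets, stepsizes, architectures):
    """Return one configuration dict per combination of the given options.

    The keys are hypothetical placeholders; check experiments.py for the
    names the script actually expects.
    """
    return [
        {"dataset": d, "stepsize": s, "hidden_units": a}
        for d, s, a in itertools.product(datasets, stepsizes, architectures)
    ]


def run_experiment(config):
    """Placeholder for training one model; here it only reports the configuration."""
    summary = "{} (stepsize={}, hidden_units={})".format(
        config["dataset"], config["stepsize"], config["hidden_units"]
    )
    print("running", summary)
    return summary


schedule = build_schedule(
    datasets=["dataset_a"],
    stepsizes=[1e-3, 1e-4],
    architectures=[[50], [100, 100]],
)

# The configurations are executed consecutively, in list order.
results = [run_experiment(config) for config in schedule]
```

Generating the list with itertools.product keeps a grid search compact, while an explicit hand-written list works just as well for a few configurations.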

The performance measured during these experiments is saved in a log directory (specified in the variable log_dir). Training progress can be monitored with TensorBoard: in a terminal, run tensorboard --logdir=<path/to/log_dir> and open the printed link in a browser.
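If you prefer to start TensorBoard from Python instead of a terminal, a minimal sketch follows. It only wraps the command given above; the helper names and the explicit port 6006 (TensorBoard's usual default) are our assumptions, and the tensorboard executable must be on the PATH of the activated environment.

```python
import subprocess


def tensorboard_command(log_dir, port=6006):
    """Build the TensorBoard command line for a given log directory."""
    return ["tensorboard", "--logdir={}".format(log_dir), "--port={}".format(port)]


def launch_tensorboard(log_dir, port=6006):
    """Start TensorBoard in the background and return the process handle.

    Open the URL TensorBoard prints in a browser; call .terminate() on the
    returned handle to stop the server.
    """
    return subprocess.Popen(tensorboard_command(log_dir, port))
```

For example, proc = launch_tensorboard("path/to/log_dir") starts the server, and proc.terminate() stops it again.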

Plots

The plots in Figure 2 of the paper (see screenshots below) were generated with the script visualisation/plots.py, and those in Figure 3 with the script visualisation/visualise_sampled_distr.py. These plots can only be generated once the experiments described above have produced their log files.
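Because both plotting scripts depend on existing logs, it can help to check log_dir before running them. The sketch below assumes the relevant logs are the TensorBoard event files (named events.out.tfevents.*) written during training; the plotting scripts may read other files as well, so treat this only as a rough sanity check. The function name is hypothetical.

```python
import glob
import os


def runs_with_event_files(log_dir):
    """Return the sorted relative paths of directories under log_dir that contain
    TensorBoard event files ("." if log_dir itself contains them)."""
    pattern = os.path.join(log_dir, "**", "events.out.tfevents.*")
    run_dirs = {
        os.path.relpath(os.path.dirname(path), log_dir)
        for path in glob.glob(pattern, recursive=True)
    }
    return sorted(run_dirs)
```

An empty result suggests the experiments have not been run yet (or wrote their logs elsewhere), so the plotting scripts would have nothing to read.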

[Screenshot: Figure 2 of the paper]

[Screenshot: Figure 3 of the paper]

Acknowledgements

  • Our code builds on the SVAE implementation by Johnson et al., which is written in NumPy and autograd. We have 'translated' parts of this code to TensorFlow.
  • To allow for multi-GPU training, we used the model replica approach explained here and implemented here by Norman Heckscher.
  • We tried to make our plots look nicer using this script by Bennett Kanuka.

Citing

If you use our code, please cite our ICLR 2018 paper using the following BibTeX entry:

@inproceedings{lin2018variational,
  title={Variational Message Passing with Structured Inference Networks},
  author={Wu Lin and Nicolas Hubacher and Mohammad Emtiyaz Khan},
  booktitle={International Conference on Learning Representations},
  year={2018},
  url={https://openreview.net/forum?id=HyH9lbZAW},
}