Variational Message Passing for Structured VAE (Code for the ICLR 2018 paper by Wu Lin, Nicolas Hubacher and Mohammad Emtiyaz Khan)
Before running our code, create a conda environment using the file environment.yml. To do so, open a terminal and run:
conda env create -f environment.yml
Then, activate the created environment:
source activate san-cpu-env
If you don't want to use conda, just make sure to use the libraries listed in environment.yml in their specified versions (most importantly, use TensorFlow version 1.3).
Please note that, for simplicity, environment.yml only contains TensorFlow with CPU support. Follow the TensorFlow installation guide if you want to use a GPU-enabled version of TensorFlow.
Execute experiments.py to run our algorithm. Several options can be set at the beginning of this script; for instance, it is possible to use multiple GPUs for training.
Then, the experimental setup can be defined: dataset, stepsize, neural network architecture, etc. One or multiple experiment configurations can be listed in the variable schedule and are executed consecutively.
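To illustrate what "executed consecutively" means for a list of configurations, here is a minimal sketch. It is not taken from experiments.py: the configuration keys (dataset, stepsize, hidden_units), the dataset name "example_dataset" and the run_experiment helper are hypothetical placeholders.

```python
from typing import Any, Dict, List

# Hypothetical configurations; the real keys used in experiments.py may differ.
schedule: List[Dict[str, Any]] = [
    {"dataset": "example_dataset", "stepsize": 0.01, "hidden_units": [50]},
    {"dataset": "example_dataset", "stepsize": 0.001, "hidden_units": [50, 50]},
]


def run_experiment(config: Dict[str, Any]) -> str:
    """Hypothetical stand-in for one training run.

    A real run would build and train the model; this placeholder only
    derives a run name that could serve as a log subdirectory.
    """
    return f"{config['dataset']}_lr{config['stepsize']}_h{len(config['hidden_units'])}"


def run_schedule(configs: List[Dict[str, Any]]) -> List[str]:
    """Run the configurations one after another, in list order."""
    run_names = []
    for config in configs:
        run_names.append(run_experiment(config))
    return run_names


if __name__ == "__main__":
    for name in run_schedule(schedule):
        print(name)
```

Running the configurations in list order keeps each run independent, so adding a new experiment only means appending another entry to the list.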
The performance measured during these experiments is saved in a log directory (specified in the variable log_dir). The training progress can be monitored using TensorBoard. In a terminal, run:
tensorboard --logdir=<path/to/log_dir>
and open the returned link in a browser.
The plots in Figure 2 of the paper have been generated with the script visualisation/plots.py, and the plots in Figure 3 with the script visualisation/visualise_sampled_distr.py. These plots can only be generated once the log files mentioned above exist.
- Our code builds on the SVAE implementation by Johnson et al., which is written in numpy and autograd. We have 'translated' parts of this code to TensorFlow.
- To allow for multi-GPU training, we used the model replica approach explained here and implemented here by Norman Heckscher.
- We tried to make our plots look nicer using this script by Bennett Kanuka.
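The model replica approach to multi-GPU training is commonly described as follows: each GPU holds a copy of the model, computes gradients on its own slice of the mini-batch, and the per-replica gradients are averaged before a single parameter update. The sketch below assumes the approach referenced above follows this usual pattern. It is a framework-free illustration in plain Python, not the repository's TensorFlow code, and the helpers split_batch, mse_grad and average_gradients are hypothetical.

```python
from typing import List, Sequence, Tuple


def split_batch(batch: Sequence, num_replicas: int) -> List[list]:
    """Split a mini-batch into num_replicas contiguous, near-equal slices."""
    size, rem = divmod(len(batch), num_replicas)
    slices, start = [], 0
    for k in range(num_replicas):
        end = start + size + (1 if k < rem else 0)
        slices.append(list(batch[start:end]))
        start = end
    return slices


def mse_grad(w: float, data: Sequence[Tuple[float, float]]) -> List[float]:
    """Gradient w.r.t. w of mean 0.5 * (w * x - y)^2 over data (one variable)."""
    return [sum((w * x - y) * x for x, y in data) / len(data)]


def average_gradients(replica_grads: Sequence[Sequence[float]]) -> List[float]:
    """Average gradients variable-by-variable across replicas.

    replica_grads[k][i] is the gradient of variable i computed on replica k.
    """
    if not replica_grads:
        raise ValueError("need at least one replica")
    num_vars = len(replica_grads[0])
    if any(len(g) != num_vars for g in replica_grads):
        raise ValueError("all replicas must report the same variables")
    n = len(replica_grads)
    return [sum(g[i] for g in replica_grads) / n for i in range(num_vars)]


if __name__ == "__main__":
    data = [(1.0, 2.0), (2.0, 3.0), (3.0, 5.0), (4.0, 7.0)]
    w = 1.0
    per_replica = [mse_grad(w, part) for part in split_batch(data, 2)]
    print(average_gradients(per_replica), mse_grad(w, data))
```

With equally sized slices, the averaged per-replica gradient equals the gradient on the full mini-batch, which is why the replicas can share one update.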
If you use our code, please cite our ICLR paper. This is the BibTeX entry:
@inproceedings{
lin2018variational,
title={Variational Message Passing with Structured Inference Networks},
author={Wu Lin and Nicolas Hubacher and Mohammad Emtiyaz Khan},
booktitle={International Conference on Learning Representations},
year={2018},
url={https://openreview.net/forum?id=HyH9lbZAW},
}