/tf_mvg

Multivariate Gaussian distributions for Tensorflow.

Primary language: Python. License: MIT.


This repository contains parts of the implementation code for the projects 'Structured Uncertainty Prediction Networks' (CVPR 2018) and 'Training VAEs Under Structured Residuals' (arXiv 2018).


Papers

Structured Uncertainty Prediction Networks
Garoe Dorta¹,², Sara Vicente², Lourdes Agapito³, Neill D.F. Campbell¹ and Ivor Simpson²
¹University of Bath, ²Anthropics Technology Ltd., ³University College London
IEEE Conference on Computer Vision and Pattern Recognition (CVPR), 2018

Training VAEs Under Structured Residuals
Garoe Dorta¹,², Sara Vicente², Lourdes Agapito³, Neill D.F. Campbell¹ and Ivor Simpson²
¹University of Bath, ²Anthropics Technology Ltd., ³University College London
arXiv e-prints, 2018


Dependencies


Detailed description

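This code provides a collection of Multivariate Gaussian distributions parametrized with log diagonals: the network predicts the logarithm of the diagonal entries rather than the entries themselves (for example, the log_diag_covariance argument of MultivariateNormalDiag). This parametrization makes log-probability computations more stable, because the log-determinant term becomes a sum of the predicted values and never requires taking the logarithm of a small positive number. The distributions are subclasses of the TensorFlow Probability distribution classes (tensorflow_probability.distributions) and can directly replace any Multivariate Gaussian distribution class.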

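In practice, this means changing the activation of the layer that predicts the covariance from softplus to no activation: the softplus is no longer needed to keep the variances positive, because the layer output is read as a log-variance. For example, in a dense-layer setting: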

import tensorflow as tf
import tensorflow_probability as tfp
import mvg_distributions as mvg_dist

keras = tf.keras

n = 64                           # Data dimensionality (placeholder value)
h = tf.random.normal([8, 128])   # Stand-in for a hidden layer of the network
loc = keras.layers.Dense(n)(h)   # The predicted means

# TensorFlow Probability approach: softplus keeps the predicted variances positive
diag_covariance = keras.layers.Dense(n, activation=tf.nn.softplus)(h)
softplus_mvg = tfp.distributions.MultivariateNormalDiag(
    loc=loc, scale_diag=tf.sqrt(diag_covariance))

# mvg_distributions approach: no activation, the output is the log of the diagonal covariance
log_diag_covariance = keras.layers.Dense(n, activation=None)(h)
log_mvg = mvg_dist.MultivariateNormalDiag(loc=loc, log_diag_covariance=log_diag_covariance)
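A minimal NumPy sketch of why the log-diagonal parametrization is more stable. This is not the library's code; the function names and the value -110 are hypothetical. Both routes compute the same diagonal Gaussian log-density and differ only in how the raw layer output becomes a variance. In float32, softplus(-110) underflows to zero, so the softplus route takes log(0) and ends in nan, while the log-variance route uses the raw value directly and stays finite.

```python
import numpy as np

def diag_mvn_log_prob(x, loc, log_diag_covariance):
    """Log-density of N(loc, diag(exp(log_diag_covariance))), summed over the last axis."""
    k = x.shape[-1]
    # log det(Sigma) is the sum of the log-variances: no log() of a tiny number is taken
    log_det = np.sum(log_diag_covariance, axis=-1)
    # Whiten the residual with exp(-log_var / 2) before squaring
    z = (x - loc) * np.exp(-0.5 * log_diag_covariance)
    maha = np.sum(z ** 2, axis=-1)
    return -0.5 * (k * np.log(2.0 * np.pi) + log_det + maha)

def softplus(a):
    # log(1 + exp(a)), evaluated without overflow for large a
    return np.logaddexp(0.0, a)

# Hypothetical raw layer outputs: very negative values ask for tiny variances
a = np.array([[-110.0, -110.0]], dtype=np.float32)
x = np.zeros((1, 2), dtype=np.float32)
loc = np.zeros((1, 2), dtype=np.float32)

# Log-diagonal route: the layer output is used directly as the log-variance
log_route = diag_mvn_log_prob(x, loc, a)

# Softplus route: the variance underflows to 0 in float32, so its log is -inf
with np.errstate(divide="ignore", invalid="ignore"):
    softplus_route = diag_mvn_log_prob(x, loc, np.log(softplus(a)))

print(log_route, softplus_route)  # a finite value vs. nan
```

Whitening the residual with exp(-log_var / 2) before squaring, rather than multiplying the squared residual by exp(-log_var), also keeps the intermediate values in range for longer.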

The provided distributions are:

  • MultivariateNormalDiag: for diagonal covariance matrices
  • MultivariateNormalChol: for full covariance matrices parametrized by their Cholesky factor
  • MultivariateNormalPrecCholFilters: for precision matrices with a sparse Cholesky factor
  • MultivariateNormalPrecCholFiltersDilation: for precision matrices with a sparse Cholesky factor and a dilated sparsity pattern
  • IsotropicMultivariateNormal: the N(0, I) distribution, useful for numerically stable KL divergences with MultivariateNormalDiag
  • CholeskyWishart: a Cholesky-Wishart distribution, i.e. a distribution over Cholesky factors
  • Gamma: a Gamma distribution that evaluates probabilities of log-values and returns log-values when sampling, useful for setting a prior on the log_diag_precision argument of a MultivariateNormalDiag distribution
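The Gamma entry above works in log space. A NumPy sketch of the underlying math, assuming a shape/rate parametrization Gamma(alpha, beta), which the README does not state: if X ~ Gamma(alpha, beta) and Y = log X, then log p(Y = y) = alpha log(beta) - lgamma(alpha) + alpha y - beta exp(y). For sampling, the sketch draws log X directly using the identity Gamma(alpha) = Gamma(alpha + 1) * U^(1/alpha), one known way to avoid underflow for small alpha; the library may do it differently. The function names are hypothetical.

```python
import math
import numpy as np

def log_gamma_log_prob(y, alpha, beta):
    """Log-density of Y = log(X) for X ~ Gamma(shape=alpha, rate=beta), scalar alpha and beta.

    Change of variables: p_Y(y) = p_X(exp(y)) * exp(y).
    """
    return alpha * np.log(beta) - math.lgamma(alpha) + alpha * y - beta * np.exp(y)

def log_gamma_sample(alpha, beta, size, rng):
    """Draw log(X) for X ~ Gamma(shape=alpha, rate=beta) without forming X itself.

    Uses Gamma(alpha) = Gamma(alpha + 1) * U**(1 / alpha) in log space, so a small
    alpha does not produce log(0) from an underflowed sample.
    """
    g = rng.gamma(shape=alpha + 1.0, scale=1.0, size=size)
    u = 1.0 - rng.uniform(size=size)  # in (0, 1], so log(u) is finite
    return np.log(g) + np.log(u) / alpha - np.log(beta)

rng = np.random.default_rng(0)
alpha, beta = 2.0, 3.0
y = log_gamma_sample(alpha, beta, 100_000, rng)

# Sanity checks: the log-space density integrates to one, and exp(y) has mean alpha / beta
grid = np.linspace(-15.0, 5.0, 20_001)
mass = np.sum(np.exp(log_gamma_log_prob(grid, alpha, beta))) * (grid[1] - grid[0])
print(mass, np.exp(y).mean(), alpha / beta)
```

The -log(beta) shift is where the rate enters; under a scale parametrization it would be +log(scale) instead.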

Examples

examples/autoencoder_mvg_chol_filters.py shows how to use MultivariateNormalPrecCholFilters in an autoencoder setting, as a demonstration of the work in Structured Uncertainty Prediction Networks.

examples/autoencoder_mvg_diag.py shows how to use MultivariateNormalDiag in the same setting as the previous example.
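The chol_filters autoencoder scores reconstructions with a Gaussian whose precision matrix is given through a Cholesky factor. A dense NumPy sketch of that log-density, assuming precision = L L^T with a lower-triangular L whose diagonal is predicted in log space. MultivariateNormalPrecCholFilters restricts L to a sparse, filter-based pattern, which this sketch does not reproduce; the function and argument names are hypothetical.

```python
import numpy as np

def prec_chol_log_prob(x, loc, log_diag, strict_lower):
    """Log-density of N(loc, inv(L @ L.T)) for a single vector x.

    Assumed parametrization: L is lower triangular with diagonal exp(log_diag);
    only the entries of strict_lower below the diagonal are used.
    """
    k = x.shape[-1]
    L = np.tril(strict_lower, k=-1) + np.diag(np.exp(log_diag))
    # 0.5 * log det(precision) = sum(log L_ii) = sum(log_diag)
    half_log_det_prec = np.sum(log_diag)
    # (x - loc)^T L L^T (x - loc) = ||L^T (x - loc)||^2
    r = L.T @ (x - loc)
    return -0.5 * k * np.log(2.0 * np.pi) + half_log_det_prec - 0.5 * r @ r

rng = np.random.default_rng(0)
k = 4
x, loc = rng.normal(size=k), rng.normal(size=k)
log_diag = 0.3 * rng.normal(size=k)
strict_lower = rng.normal(size=(k, k))

# Reference: invert the precision and evaluate the covariance form of the density
L = np.tril(strict_lower, k=-1) + np.diag(np.exp(log_diag))
cov = np.linalg.inv(L @ L.T)
_, logdet_cov = np.linalg.slogdet(cov)
d = x - loc
ref = -0.5 * (k * np.log(2.0 * np.pi) + logdet_cov + d @ np.linalg.solve(cov, d))

print(prec_chol_log_prob(x, loc, log_diag, strict_lower), ref)
```

The precision form never inverts a matrix: the product L^T (x - loc) only touches the nonzero entries of L, which is the kind of saving a sparse factor can exploit.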

kl_diag_isotropic.py shows how to use IsotropicMultivariateNormal and MultivariateNormalDiag to compute KL(N(mu, diag(sigma^2)) || N(0, I)), a term common in VAE networks.
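For a diagonal Gaussian predicted through log-variances, the KL divergence to N(0, I) has the closed form 0.5 * sum(sigma^2 + mu^2 - 1 - log sigma^2), so the log of a variance is never computed. A NumPy sketch of that formula (not the code of kl_diag_isotropic.py; the names are hypothetical), checked against the general KL formula for full covariance matrices:

```python
import numpy as np

def kl_diag_to_standard_normal(loc, log_diag_covariance):
    """KL(N(loc, diag(exp(log_diag_covariance))) || N(0, I)), summed over the last axis."""
    return 0.5 * np.sum(
        np.exp(log_diag_covariance) + loc ** 2 - 1.0 - log_diag_covariance, axis=-1
    )

def kl_gaussians(mu0, cov0, mu1, cov1):
    """Reference: closed-form KL(N(mu0, cov0) || N(mu1, cov1)) with dense matrices."""
    k = mu0.shape[0]
    cov1_inv = np.linalg.inv(cov1)
    d = mu1 - mu0
    _, logdet0 = np.linalg.slogdet(cov0)
    _, logdet1 = np.linalg.slogdet(cov1)
    return 0.5 * (np.trace(cov1_inv @ cov0) + d @ cov1_inv @ d - k + logdet1 - logdet0)

rng = np.random.default_rng(0)
mu = rng.normal(size=5)
log_var = rng.normal(size=5)
fast = kl_diag_to_standard_normal(mu, log_var)
ref = kl_gaussians(mu, np.diag(np.exp(log_var)), np.zeros(5), np.eye(5))
print(fast, ref)
```

Because the sum runs over the last axis, the same function also works on a batch of encoder outputs of shape (batch, latent_dim).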

kl_chol_diag.py and log_prob_chol_filters.py contain additional simple examples of KL divergences and log prob evaluations.
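A sketch in the spirit of kl_chol_diag.py, assuming it computes KL(N(mu_c, L L^T) || N(mu_d, diag(sigma^2))) with both diagonals predicted in log space. The direction of the divergence and the parametrization are assumptions, and the names are hypothetical. With a diagonal second argument, the trace and Mahalanobis terms reduce to elementwise sums, and both log-determinants come straight from the log-diagonals:

```python
import numpy as np

def kl_chol_to_diag(loc_c, log_diag_chol, strict_lower, loc_d, log_diag_covariance):
    """KL(N(loc_c, L L^T) || N(loc_d, diag(exp(log_diag_covariance)))).

    Assumed parametrization: L is lower triangular with diagonal exp(log_diag_chol);
    only the entries of strict_lower below the diagonal are used.
    """
    k = loc_c.shape[0]
    L = np.tril(strict_lower, k=-1) + np.diag(np.exp(log_diag_chol))
    inv_var = np.exp(-log_diag_covariance)
    # tr(diag(1/var) @ L @ L.T) = sum_i (sum_j L_ij^2) / var_i
    trace_term = np.sum(np.sum(L ** 2, axis=1) * inv_var)
    maha = np.sum((loc_d - loc_c) ** 2 * inv_var)
    # Both log-determinants come straight from the log-parametrized diagonals
    log_det_d = np.sum(log_diag_covariance)
    log_det_c = 2.0 * np.sum(log_diag_chol)
    return 0.5 * (trace_term + maha - k + log_det_d - log_det_c)

def kl_gaussians(mu0, cov0, mu1, cov1):
    """Reference: closed-form KL(N(mu0, cov0) || N(mu1, cov1)) with dense matrices."""
    k = mu0.shape[0]
    cov1_inv = np.linalg.inv(cov1)
    d = mu1 - mu0
    _, logdet0 = np.linalg.slogdet(cov0)
    _, logdet1 = np.linalg.slogdet(cov1)
    return 0.5 * (np.trace(cov1_inv @ cov0) + d @ cov1_inv @ d - k + logdet1 - logdet0)

rng = np.random.default_rng(0)
k = 4
loc_c, loc_d = rng.normal(size=k), rng.normal(size=k)
log_diag_chol, log_diag_cov = 0.3 * rng.normal(size=k), rng.normal(size=k)
strict_lower = rng.normal(size=(k, k))

L = np.tril(strict_lower, k=-1) + np.diag(np.exp(log_diag_chol))
fast = kl_chol_to_diag(loc_c, log_diag_chol, strict_lower, loc_d, log_diag_cov)
ref = kl_gaussians(loc_c, L @ L.T, loc_d, np.diag(np.exp(log_diag_cov)))
print(fast, ref)
```

Only the trace term touches the off-diagonal entries of L; everything else is elementwise over the diagonals.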


If this work is useful for your research, please cite our papers.