Code for the paper Neural Variational Gradient Descent.
The main framework used in this project is JAX (version `jax==0.2.7` and `jaxlib==0.1.57`).
See the installation instructions to install it with proper accelerator support.
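A quick way to check that JAX was installed with working accelerator support (using only standard JAX calls):

```python
import jax

print(jax.__version__)  # should match the pinned version above (0.2.7)
print(jax.devices())    # GPU/TPU devices should appear here if accelerator
                        # support is installed; otherwise only the CPU backend
```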
Optionally (required for some of the experiments), install `tensorflow_datasets==4.2.0` and TensorFlow 2.0.
To clone this project and install it, first install JAX according to the instructions linked above. Then run
```bash
git clone git@github.com:langosco/neural-variational-gradient-descent.git
cd neural-variational-gradient-descent
pip install -e .
```
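After the editable install, the modules described below should be importable. The exact import path is an assumption inferred from the `nvgd/src` folder layout and may differ:

```python
# Assumed import path, based on the nvgd/src layout; adjust if the package
# exposes its modules differently.
from nvgd.src import distributions, kernels, models
```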
The implementation is contained in the `nvgd/src` folder. The files are structured as follows:
- `distributions.py`: a set of classes that bundle together all attributes and methods associated with a probability distribution (e.g. mean, variance, sampling, computing the likelihood and log-likelihood).
- `flows.py`: functions to simulate the particle dynamics, using the models in `models.py`.
- `kernels.py`: a set of positive definite kernel functions.
- `metrics.py`: utilities for computing metrics that track convergence to the target (e.g. MMD distance or mean squared error); a toy MMD sketch follows this list.
- `models.py`: the heart of the project. Contains the different models, each of which computes a single iteration of its respective particle dynamics.
- `nets.py`: neural network architectures.
- `plot.py`: utility functions for plotting.
- `stein.py`: implementations of the Stein operator, the (kernelized) Stein discrepancy, and associated methods.
- `utils.py`: miscellaneous utility functions.
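To illustrate the kind of utilities that `kernels.py` and `metrics.py` provide (this is a self-contained sketch, not the repository's actual API), here is a minimal JAX implementation of an RBF kernel and a biased MMD estimate between two sets of samples:

```python
import jax.numpy as jnp
from jax import vmap

def rbf_kernel(x, y, bandwidth=1.0):
    # Gaussian RBF kernel k(x, y) = exp(-||x - y||^2 / (2 * bandwidth^2)).
    return jnp.exp(-jnp.sum((x - y) ** 2) / (2 * bandwidth ** 2))

def gram(kernel, xs, ys):
    # Kernel Gram matrix with entries kernel(xs[i], ys[j]), built via nested vmap.
    return vmap(lambda x: vmap(lambda y: kernel(x, y))(ys))(xs)

def mmd(xs, ys, kernel=rbf_kernel):
    # Biased estimate of the squared maximum mean discrepancy between the
    # samples xs and ys: E[k(x, x')] + E[k(y, y')] - 2 E[k(x, y)].
    return (gram(kernel, xs, xs).mean()
            + gram(kernel, ys, ys).mean()
            - 2 * gram(kernel, xs, ys).mean())
```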
Many standard MCMC methods, such as Hamiltonian Monte Carlo, don't work well in settings with large datasets and high-dimensional target posteriors with complicated dependencies. This is why simpler methods such as variational inference (VI) or stochastic gradient Langevin dynamics (SGLD) are usually applied to this type of problem (e.g. training a Bayesian neural network).
In 2016, Qiang Liu and Dilin Wang proposed Stein variational gradient descent (SVGD), a new kind of inference method that rapidly became popular. SVGD transports a set of particles to approximate the target distribution, using update steps that depend on a kernel function.
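For intuition, a single SVGD update with an RBF kernel can be sketched in a few lines of JAX. This is a toy illustration of the particle dynamics, not the implementation in `models.py`:

```python
import jax.numpy as jnp
from jax import grad, random, vmap

def rbf(x, y, bandwidth=1.0):
    # Gaussian RBF kernel.
    return jnp.exp(-jnp.sum((x - y) ** 2) / (2 * bandwidth ** 2))

def svgd_step(particles, logp, step_size=0.1):
    # One SVGD update: move each particle along the kernelized Stein direction
    #   phi(x_i) = (1/n) sum_j [ k(x_j, x_i) grad log p(x_j) + grad_{x_j} k(x_j, x_i) ].
    n = particles.shape[0]
    score = vmap(grad(logp))(particles)                                                  # (n, d)
    kxy = vmap(lambda xi: vmap(lambda xj: rbf(xi, xj))(particles))(particles)            # (n, n)
    grad_k = vmap(lambda xi: vmap(lambda xj: grad(rbf)(xj, xi))(particles))(particles)   # (n, n, d)
    phi = (kxy @ score + grad_k.sum(axis=1)) / n
    return particles + step_size * phi

# Toy example: 100 two-dimensional particles approximating a standard Gaussian.
logp = lambda x: -0.5 * jnp.sum(x ** 2)
particles = random.normal(random.PRNGKey(0), (100, 2)) + 3.0
for _ in range(500):
    particles = svgd_step(particles, logp)
```

Note that both the driving term and the repulsive term in this sketch depend on the kernel `rbf`, which is exactly the dependence described next.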
A drawback of SVGD is that it depends on the choice of a kernel function. If this kernel is not chosen well, the method may converge slowly or not at all. The goal of this project was to build an alternative to SVGD that does not depend on a choice of kernel. Here's a side-by-side visual comparison of our method ('neural gradient flow') and SVGD:
Here are the results (on the same task as in the gif) in a more interpretable form, plus a comparison with Langevin dynamics (smaller MMD is better). In this task and in others, our method outperforms or matches SVGD.