
Approximate inference for Markov Gaussian processes using iterated Kalman smoothing, in JAX

Primary language: Jupyter Notebook. License: Apache License 2.0 (Apache-2.0).

Note: kalman-jax is now obsolete. A significantly improved version of this code is available at https://github.com/AaltoML/BayesNewton/

kalman-jax

Approximate inference for Markov (i.e., temporal) Gaussian processes using iterated Kalman filtering and smoothing. Developed and maintained by William Wilkinson. The Bernoulli likelihood was implemented by Paul Chang. We are based in Arno Solin's machine learning group at Aalto University, Finland.

This project aims to implement an XLA JIT-compilable framework for inference in (non-conjugate) Markov Gaussian processes, with automatic differentiation via JAX.

The methodology is outlined in the following paper:

  • W. Wilkinson, P. Chang, M. Riis Andersen, A. Solin. State Space Expectation Propagation: Efficient Inference Schemes for Temporal Gaussian Processes. International Conference on Machine Learning (ICML), 2020 [arXiv]

More details about the variational inference method are given in the following paper:

  • P. Chang, W. Wilkinson, M. E. Khan, A. Solin. Fast Variational Learning in State-Space Gaussian Process Models. International Workshop on Machine Learning for Signal Processing (MLSP), 2020 [arXiv]

If you use this code in your research, please cite the paper as follows:

@inproceedings{wilkinson2020,
  title={State Space Expectation Propagation: Efficient Inference Schemes for Temporal {G}aussian Processes},
  author={Wilkinson, William J. and Chang, Paul E. and Andersen, Michael Riis and Solin, Arno},
  booktitle={International Conference on Machine Learning},
  year={2020}
}

[Figure: spatio-temporal GP classification]

Getting started

  • Install the latest version of jax and jaxlib (see requirements.txt for the full list of dependencies).
  • Explore the demo notebooks, which cover many different tasks and modelling scenarios.

Info

We combine two recent advances in the field of probabilistic machine learning:

  • Development of state space methods for linear-time approximate inference in Gaussian processes
  • The ability to JIT compile and autodiff through loops efficiently with JAX
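The second point is what lets a sequential filter sit inside a gradient-based optimiser. A minimal sketch, assuming a linear-Gaussian model with scalar observations; `kalman_nll` and `log_noise_var` are hypothetical names, not part of the kalman-jax API. The filtering loop is written with `jax.lax.scan`, so the negative log marginal likelihood can be compiled with `jax.jit` and differentiated with `jax.grad`.

```python
import jax
import jax.numpy as jnp


def kalman_nll(log_noise_var, A, Q, H, m0, P0, ys):
    """Negative log marginal likelihood of x_k = A x_{k-1} + q_k, y_k = H x_k + e_k.

    Hypothetical helper (not the kalman-jax API): q_k ~ N(0, Q) and the scalar
    observation noise is e_k ~ N(0, exp(log_noise_var)).
    """
    R = jnp.exp(log_noise_var)

    def step(carry, y):
        m, P, nll = carry
        # Predict the next state
        m = A @ m
        P = A @ P @ A.T + Q
        # Update with the scalar observation y
        S = (H @ P @ H.T)[0, 0] + R   # innovation variance
        v = y - (H @ m)[0]            # innovation
        K = (P @ H.T)[:, 0] / S       # Kalman gain
        m = m + K * v
        P = P - jnp.outer(K, K) * S
        # Accumulate -log N(y | predicted mean, S)
        nll = nll + 0.5 * (jnp.log(2.0 * jnp.pi * S) + v**2 / S)
        return (m, P, nll), None

    (_, _, nll), _ = jax.lax.scan(step, (m0, P0, jnp.zeros(())), ys)
    return nll


# Usage: Matern-1/2 (Ornstein-Uhlenbeck) prior, unit lengthscale and variance
dt = 0.1
A = jnp.array([[jnp.exp(-dt)]])
Q = jnp.array([[1.0 - jnp.exp(-2.0 * dt)]])
H = jnp.array([[1.0]])
m0 = jnp.zeros(1)
P0 = jnp.eye(1)
ys = jnp.sin(jnp.linspace(0.0, 3.0, 30))

nll_fn = jax.jit(kalman_nll)
grad_fn = jax.jit(jax.grad(kalman_nll))  # derivative w.r.t. log_noise_var
nll = nll_fn(jnp.log(0.1), A, Q, H, m0, P0, ys)
dnll = grad_fn(jnp.log(0.1), A, Q, H, m0, P0, ys)
```

The cost is linear in the number of time steps, whereas the equivalent dense GP computation needs a cubic-cost solve with the full kernel matrix.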

Code structure

Every approximate inference algorithm calls the same underlying Kalman filter and smoother methods; the algorithms differ in how the approximate likelihood terms are computed.
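A minimal sketch of that shared structure, assuming scalar observations and a time-invariant discrete model (A, Q, H). The filter and Rauch-Tung-Striebel smoother only see Gaussian "sites" N(site_mean_k | f_k, site_var_k): for a Gaussian likelihood these are simply the data and the noise variance, while an approximate method would supply its own site parameters. The name `filter_smoother` and its signature are hypothetical, not the kalman-jax interface.

```python
import jax
import jax.numpy as jnp


def filter_smoother(A, Q, H, m0, P0, site_mean, site_var):
    """Kalman filter + Rauch-Tung-Striebel smoother driven by Gaussian sites.

    Hypothetical helper (not the kalman-jax API). Site k acts as a scalar
    pseudo-observation site_mean[k] = H x_k + e_k with e_k ~ N(0, site_var[k]).
    """
    def filter_step(carry, site):
        m, P = carry
        y, r = site
        m_pred = A @ m                     # predict
        P_pred = A @ P @ A.T + Q
        S = (H @ P_pred @ H.T)[0, 0] + r   # update with the site
        K = (P_pred @ H.T)[:, 0] / S
        m_f = m_pred + K * (y - (H @ m_pred)[0])
        P_f = P_pred - jnp.outer(K, K) * S
        return (m_f, P_f), (m_f, P_f)

    _, (ms, Ps) = jax.lax.scan(filter_step, (m0, P0), (site_mean, site_var))

    def smoother_step(carry, filtered):
        m_next, P_next = carry             # smoothed moments at step k + 1
        m_f, P_f = filtered                # filtered moments at step k
        m_pred = A @ m_f
        P_pred = A @ P_f @ A.T + Q
        G = jnp.linalg.solve(P_pred, A @ P_f).T  # gain P_f A^T P_pred^{-1}
        m_s = m_f + G @ (m_next - m_pred)
        P_s = P_f + G @ (P_next - P_pred) @ G.T
        return (m_s, P_s), (m_s, P_s)

    _, (ms_s, Ps_s) = jax.lax.scan(
        smoother_step, (ms[-1], Ps[-1]), (ms[:-1], Ps[:-1]), reverse=True
    )
    # The final filtered estimate is already the smoothed one
    return jnp.concatenate([ms_s, ms[-1:]]), jnp.concatenate([Ps_s, Ps[-1:]])


# Usage: Matern-1/2 prior with a different (heteroscedastic) site variance per step
dt, n = 0.1, 25
A = jnp.array([[jnp.exp(-dt)]])
Q = jnp.array([[1.0 - jnp.exp(-2.0 * dt)]])
H = jnp.array([[1.0]])
site_mean = jnp.cos(jnp.linspace(0.0, 2.0, n))
site_var = jnp.linspace(0.05, 0.5, n)
post_mean, post_cov = jax.jit(filter_smoother)(
    A, Q, H, jnp.zeros(1), jnp.eye(1), site_mean, site_var
)
```

Swapping the likelihood approximation then only means changing how `site_mean` and `site_var` are produced between smoothing passes.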

Approximate inference algorithms

  • PEP - power expectation propagation
  • EKF - extended Kalman filtering
  • UKF - unscented Kalman filtering
  • GHKF - Gauss-Hermite Kalman filtering
  • SLF - statistical linearisation filter
  • EKS - extended Kalman smoothing
  • UKS - unscented Kalman smoothing
  • GHKS - Gauss-Hermite Kalman smoothing
  • EEP - extended EP
  • SLEP - statistically linearised EP
  • UEP - unscented EP
  • GHEP - Gauss-Hermite EP
  • PL - posterior linearisation
  • VI - variational inference (with natural gradients)
  • STVI - spatio-temporal variational inference
  • STEP - spatio-temporal expectation propagation
  • STKS - spatio-temporal iterated smoothers (E, U, GH)
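Several methods in this list (the unscented and Gauss-Hermite variants in particular) rely on Gaussian expectations of non-linear functions of the latent f, computed by cubature rather than analytically. A minimal Gauss-Hermite sketch, assuming a scalar latent; `gauss_hermite_expectation` and `poisson_lik` are hypothetical names, and the exact integrands each kalman-jax method uses are not reproduced here.

```python
import jax.numpy as jnp
from jax.scipy.special import gammaln
from numpy.polynomial.hermite_e import hermegauss


def gauss_hermite_expectation(g, mean, var, num_points=20):
    """Approximate E[g(f)] for f ~ N(mean, var) with Gauss-Hermite quadrature.

    hermegauss returns nodes and weights for the weight function exp(-x^2 / 2);
    those weights sum to sqrt(2 pi), so dividing by it targets N(0, 1).
    """
    x, w = hermegauss(num_points)
    x = jnp.asarray(x)
    w = jnp.asarray(w) / jnp.sqrt(2.0 * jnp.pi)
    f = mean + jnp.sqrt(var) * x  # nodes mapped to N(mean, var)
    return jnp.sum(w * g(f))


# Sanity check: E[exp(f)] has the closed form exp(mean + var / 2)
m, v = 0.3, 0.5
approx = gauss_hermite_expectation(jnp.exp, m, v)
exact = jnp.exp(m + v / 2)


def poisson_lik(f, y=2.0):
    """Hypothetical Poisson likelihood p(y | f) with exp link and count y."""
    return jnp.exp(y * f - jnp.exp(f) - gammaln(y + 1.0))


# Normaliser and mean of the tilted distribution p(y | f) N(f | m, v) / Z
Z = gauss_hermite_expectation(poisson_lik, m, v)
tilted_mean = gauss_hermite_expectation(lambda f: f * poisson_lik(f), m, v) / Z
```

With 20 nodes the rule is exact for polynomials up to degree 39, which is why smooth integrands such as these are approximated very accurately.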

Likelihoods

  • Gaussian
  • Poisson (log-Gaussian Cox process)
  • Logit (Bernoulli classification)
  • Probit (Bernoulli classification)
  • Heteroscedastic noise
  • Product (audio amplitude demodulation)

Priors

  • Matérn class
  • RBF
  • Cosine
  • Periodic
  • Quasi-periodic
  • Subband
  • Sum
  • Product
  • Independent (multiple stacked components)
  • Latent force models (linear)

License

This software is provided under the Apache License 2.0. See the accompanying LICENSE file for details.