PGMax
PGMax implements general factor graphs for discrete probabilistic graphical models (PGMs), and hardware-accelerated differentiable loopy belief propagation (LBP) in JAX.
- General factor graphs: PGMax supports easy specification of general factor graphs with potentially complicated topology, factor definitions, and discrete variables with a varying number of states.
- LBP in JAX: PGMax generates pure JAX functions implementing LBP for a given factor graph. The generated pure JAX functions run on modern accelerators (GPU/TPU), work with JAX transformations (e.g. vmap for processing batches of models/samples, grad for differentiating through the LBP iterative process), and can be easily used as part of a larger end-to-end differentiable system.
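The sketch below illustrates why LBP written as a pure JAX function composes with vmap and grad. It does not use PGMax at all: it is a hand-written, minimal sum-product LBP in log space for a hypothetical toy model (4 binary variables on a cycle, with a shared symmetric Ising-style coupling). The names EDGES, lbp_marginals and coupling are illustrative assumptions, not PGMax APIs; PGMax builds the equivalent functions for you from a factor graph specification.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp

# Hypothetical toy model (not a PGMax API): 4 binary variables on a cycle,
# so the graph is loopy and LBP gives approximate marginals.
EDGES = jnp.array([[0, 1], [1, 2], [2, 3], [3, 0]])
NUM_VARS, NUM_STATES = 4, 2

# Directed message slots: first all i->j, then all j->i.
SRC = jnp.concatenate([EDGES[:, 0], EDGES[:, 1]])
DST = jnp.concatenate([EDGES[:, 1], EDGES[:, 0]])
NUM_EDGES = EDGES.shape[0]
# REV[k] is the slot of the message flowing opposite to slot k.
REV = jnp.concatenate([jnp.arange(NUM_EDGES) + NUM_EDGES, jnp.arange(NUM_EDGES)])


def lbp_marginals(unary, coupling, num_iters=20):
    """Sum-product LBP in log space.

    unary: (NUM_VARS, NUM_STATES) log potentials.
    coupling: scalar; every edge uses the symmetric log potential
      coupling * [[1, -1], [-1, 1]], so message direction does not matter.
    Returns approximate marginals of shape (NUM_VARS, NUM_STATES).
    """
    pairwise = coupling * jnp.array([[1.0, -1.0], [-1.0, 1.0]])

    def sum_incoming(msgs):
        # Scatter-add every message into its destination variable.
        return jnp.zeros((NUM_VARS, NUM_STATES)).at[DST].add(msgs)

    def step(_, msgs):
        # Cavity belief at the sender, excluding the message sent back by the receiver.
        cavity = unary[SRC] + sum_incoming(msgs)[SRC] - msgs[REV]
        # m_{s->d}(x_d) = logsumexp over x_s of [cavity(x_s) + pairwise(x_s, x_d)]
        new_msgs = logsumexp(cavity[:, :, None] + pairwise[None, :, :], axis=1)
        # Normalize each message in log space to keep values bounded.
        return new_msgs - logsumexp(new_msgs, axis=1, keepdims=True)

    init = jnp.zeros((2 * NUM_EDGES, NUM_STATES))
    msgs = jax.lax.fori_loop(0, num_iters, step, init)
    return jax.nn.softmax(unary + sum_incoming(msgs), axis=1)


unary = jnp.array([[0.5, -0.5], [0.0, 0.0], [0.2, 0.1], [-1.0, 1.0]])

# vmap: run LBP on a batch of 3 models that differ only in their unaries.
batch_unary = jnp.stack([unary, 2.0 * unary, -unary])
batch_marginals = jax.vmap(lbp_marginals, in_axes=(0, None))(batch_unary, 0.5)

# grad: differentiate one marginal w.r.t. the coupling, through all LBP iterations.
d_marginal = jax.grad(lambda c: lbp_marginals(unary, c)[0, 0])(0.5)
```

Because the iteration count is a static Python int, jax.lax.fori_loop supports reverse-mode differentiation here, which is what lets grad flow through the whole message-passing loop.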
See our companion paper for more details.
PGMax is under active development. APIs may change without notice, and expect rough edges!
Installation
Install from PyPI
pip install pgmax
Install latest version from GitHub
pip install git+https://github.com/deepmind/PGMax.git
Developer
While you can install PGMax in your standard Python environment, we strongly recommend using a Python virtual environment to manage your dependencies. This helps avoid version conflicts and generally makes the installation process easier.
git clone https://github.com/deepmind/PGMax.git
cd PGMax
python3 -m venv pgmax_env
source pgmax_env/bin/activate
pip install --upgrade pip setuptools
pip install -r requirements.txt
python3 setup.py develop
Install on GPU
By default, the above commands install JAX for CPU. If you have access to a GPU, follow the official JAX installation instructions to install JAX for GPU.
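After installing, one quick way to check which backend JAX picked up is to ask JAX directly. This sketch uses only standard JAX calls (jax.devices and jax.default_backend) and is not part of PGMax; on a CPU-only install it should report the CPU backend.

```python
import jax

# Lists the devices JAX can see; a working GPU install should show GPU devices here.
print(jax.devices())

# The platform JAX uses by default, e.g. "cpu", "gpu" or "tpu".
backend = jax.default_backend()
print(backend)
```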
Getting Started
Here are a few self-contained Colab notebooks to help you get started on using PGMax:
- Tutorial on basic PGMax usage
- LBP inference on Ising model
- Implementing max-product LBP for Recursive Cortical Networks
- End-to-end differentiable LBP for gradient-based PGM training
- 2D binary deconvolution
- Alternative inference with Smooth Dual LP-MAP
Citing PGMax
Please consider citing our companion paper
@article{zhou2022pgmax,
author = {Zhou, Guangyao and Dedieu, Antoine and Kumar, Nishanth and L{\'a}zaro-Gredilla, Miguel and Kushagra, Shrinu and George, Dileep},
title = {{PGMax: Factor Graphs for Discrete Probabilistic Graphical Models and Loopy Belief Propagation in JAX}},
journal = {arXiv preprint arXiv:2202.04110},
year = {2022}
}
and using the DeepMind JAX Ecosystem citation if you use PGMax in your work.
Note
This is not an officially supported Google product.