We provide JAX implementations of three simulation-based inference methods, including pseudo-likelihood inference. This repository contains the following algorithms:
- Pseudo-Likelihood Inference (PLI, Ours) [1]
- Sequential Monte Carlo Approximate Bayesian Computation (SMC-ABC) [2]
- Automatic Posterior Transformation (APT) [3]
Run

```bash
pip install -e .
```

to install the package. We tested the code with Python 3.9.13.
We highly recommend enabling GPU support for JAX by following the official JAX installation instructions. For example, to install JAX with CUDA 12 or CUDA 11 support, run either

```bash
pip install --upgrade "jaxlib[cuda12_pip]"==0.4.1 -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
```

or

```bash
pip install --upgrade "jaxlib[cuda11_pip]"==0.4.1 -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
```

before or after installing pli.
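To check that the GPU backend is actually picked up, the following quick test relies only on the public JAX API and is not specific to this repository:

```python
import jax

# On a working CUDA installation this lists one or more GPU devices;
# otherwise JAX silently falls back to the CPU backend.
print(jax.devices())
print(jax.default_backend())  # "gpu" if the CUDA wheels were picked up
```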
To start an experiment, simply run

```bash
python scripts/run_sbi_experiment.py task=gaussian-linear method=pli method/model=gaussian
```

This runs a Gaussian model with PLI on the simple Gaussian-location task. We use Hydra to keep track of the experimental settings.
For a comprehensive list of all available configuration parameters, simply run the script with the --help argument:

```bash
python scripts/run_sbi_experiment.py --help
```
```
== Configuration groups ==
Compose your configuration from those groups (group=option)

method: abc, pli, snpe
method/model: conditioned-nsf, gaussian, maf, nsf, pmc
method/pseudo_likelihood: exponential-kernel, uniform-kernel
method/statistical_distance: mmd, wasserstein
task: furuta, gaussian-linear, gaussian-mixture, sir, slcp
task/posterior_sampling: sbibm
```
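To inspect the fully composed configuration without launching a run, Hydra's compose API can be used. This is only a minimal sketch; the config directory (`configs`) and config name (`config`) are assumptions about the repository layout and may need to be adapted:

```python
from hydra import compose, initialize
from omegaconf import OmegaConf

# Compose the experiment configuration the same way the run script would,
# then print it. "configs" and "config" are assumed names.
with initialize(config_path="configs"):
    cfg = compose(
        config_name="config",
        overrides=["task=gaussian-linear", "method=pli", "method/model=gaussian"],
    )
print(OmegaConf.to_yaml(cfg))
```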
In particular, change the following settings to repeat the experiments from the paper:

```bash
python scripts/run_sbi_experiment.py \
    task=<TASK> \
    task.n_train_data=<N> \
    method=<METHOD> \
    method/model=<MODEL>
```

The statistical distance can be chosen with `method/statistical_distance=<SD>`. The remaining hyperparameters are already set in the default configuration.
As an example, to run PLI with MMD on the SLCP task with N=100 reference observations, run

```bash
python scripts/run_sbi_experiment.py task=slcp \
    task.n_train_data=100 \
    method=pli \
    method/model=nsf \
    method/statistical_distance=mmd
```
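For intuition about the `mmd` option, here is a small, self-contained JAX sketch of a (biased) squared-MMD estimate with a Gaussian kernel between two sample sets. It is purely illustrative; the estimator, kernel, and bandwidth used in this repository may differ.

```python
import jax
import jax.numpy as jnp

def gaussian_kernel(x, y, bandwidth=1.0):
    # Pairwise Gaussian kernel values between the rows of x and y.
    sq_dists = jnp.sum((x[:, None, :] - y[None, :, :]) ** 2, axis=-1)
    return jnp.exp(-sq_dists / (2.0 * bandwidth**2))

def mmd_squared(x, y, bandwidth=1.0):
    # Biased estimate of the squared maximum mean discrepancy between the
    # empirical distributions of the two sample sets.
    k_xx = gaussian_kernel(x, x, bandwidth).mean()
    k_yy = gaussian_kernel(y, y, bandwidth).mean()
    k_xy = gaussian_kernel(x, y, bandwidth).mean()
    return k_xx + k_yy - 2.0 * k_xy

# Example: two batches of 2-D samples with shifted means.
key1, key2 = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(key1, (256, 2))
y = jax.random.normal(key2, (256, 2)) + 1.0
print(mmd_squared(x, y))
```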
In addition, we provide the script that was used to produce the reference posterior samples for tasks whose posterior is not analytically available:

```bash
python scripts/create_posterior_samples.py task=<TASK> task.n_train_data=<N>
```

Although we provide posterior samples in `data`, this script can be used to generate data for different seeds.
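Independently of this script, the sbibm package (referenced by the `task/posterior_sampling=sbibm` option above) ships precomputed reference posterior samples for several of the benchmark tasks. A minimal sketch, assuming sbibm is installed and noting that sbibm uses underscores in its task names:

```python
import sbibm

# Load the SLCP benchmark task and fetch its precomputed reference posterior
# samples for the first observation.
task = sbibm.get_task("slcp")
reference_samples = task.get_reference_posterior_samples(num_observation=1)
print(reference_samples.shape)  # e.g. (10000, 5) for SLCP
```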
The experiments are logged with wandb. Before running an initial experiment, run

```bash
wandb login
```

and validate your profile with the authorization key provided on the website. You can switch off wandb logging in `config.yaml` by toggling the saving flag.
If you find this work useful, please consider citing the paper:

```bibtex
@inproceedings{gruner2023pli,
    author    = "Gruner, T. and Belousov, B. and Muratore, F. and Palenicek, D. and Peters, J.",
    year      = "2023",
    title     = "Pseudo-Likelihood Inference",
    booktitle = "Advances in Neural Information Processing Systems (NIPS / NeurIPS)",
}
```
Our density estimators are based on distrax. We provide a fork of distrax that includes support for conditional density estimators.
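For orientation, here is a minimal sketch of an unconditional normalizing flow built with the stock distrax API. It only illustrates the kind of density estimator involved; the conditional estimators in this repository additionally feed a context variable into learned coupling layers via our fork.

```python
import jax
import jax.numpy as jnp
import distrax

def make_flow(dim: int = 2):
    # Standard normal base distribution transformed by a fixed affine bijector.
    # A learned flow (e.g. NSF/MAF) would stack trainable coupling layers instead.
    base = distrax.MultivariateNormalDiag(jnp.zeros(dim), jnp.ones(dim))
    bijector = distrax.Block(distrax.ScalarAffine(shift=1.0, log_scale=0.5), ndims=1)
    return distrax.Transformed(base, bijector)

flow = make_flow()
samples = flow.sample(seed=jax.random.PRNGKey(0), sample_shape=(5,))
log_probs = flow.log_prob(samples)
print(samples.shape, log_probs.shape)  # (5, 2) (5,)
```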