SaxBI is a JAX implementation of likelihood-free, simulation-based inference (SBI) methods. Currently, two algorithms are implemented: Sequential Neural Likelihood Estimation (SNLE) and Sequential Neural Ratio Estimation (SNRE). The package offers a simple, functional API for approximate posterior inference.
- Flax-based autoregressive normalizing flows with affine, piecewise affine, and piecewise rational quadratic splines
- Flax-based classifiers with or without residual skip connections
- Hamiltonian Monte Carlo sampling with NUTS kernels implemented in NumPyro
- And more!
- Probably some bugs too... Let me know what you find 😅
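To make the "affine" option concrete, here is a minimal pure-JAX sketch of a single affine autoregressive transform in the MAF style. It is not SaxBI's Flax code: the linear conditioner and the names `init_params`, `conditioner`, `forward`, `inverse` and `log_prob` are hypothetical. It shows the two properties such flows rely on: because each shift and log-scale depends only on earlier dimensions, the Jacobian is triangular and its log-determinant is just `-sum(alpha)`, while the inverse must be computed one dimension at a time.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm


def init_params(rng, dim):
    # Hypothetical toy parameters for a linear conditioner.
    k1, k2 = jax.random.split(rng)
    return {
        "W_mu": 0.5 * jax.random.normal(k1, (dim, dim)),
        "W_alpha": 0.5 * jax.random.normal(k2, (dim, dim)),
        "b_mu": jnp.zeros(dim),
        "b_alpha": jnp.zeros(dim),
    }


def conditioner(params, x):
    # Strictly lower-triangular mask: row i only sees x[:i] (autoregressive property).
    dim = x.shape[-1]
    mask = jnp.tril(jnp.ones((dim, dim)), k=-1)
    mu = (params["W_mu"] * mask) @ x + params["b_mu"]
    alpha = (params["W_alpha"] * mask) @ x + params["b_alpha"]
    return mu, alpha


def forward(params, x):
    # Data -> noise direction, used for density evaluation. Computed in one pass.
    mu, alpha = conditioner(params, x)
    z = (x - mu) * jnp.exp(-alpha)
    # Triangular Jacobian with diagonal exp(-alpha_i).
    log_det = -jnp.sum(alpha)
    return z, log_det


def inverse(params, z):
    # Noise -> data direction: x[i] depends on x[:i], so fill it in sequentially.
    x = jnp.zeros_like(z)
    for i in range(z.shape[-1]):
        mu, alpha = conditioner(params, x)
        x = x.at[i].set(z[i] * jnp.exp(alpha[i]) + mu[i])
    return x


def log_prob(params, x):
    # Change of variables with a standard normal base distribution.
    z, log_det = forward(params, x)
    return norm.logpdf(z).sum() + log_det
```

The piecewise affine and rational quadratic spline variants generally swap the per-dimension affine map for a more flexible monotone one while keeping the same autoregressive conditioning.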
`saxbi` requires Python 3.9 or higher. It can be easily installed from the repository's home directory with

```bash
python setup.py install
```
The main workhorse of this package is the `pipeline` function, which takes five required arguments: `rng`, `X_true`, `get_simulator`, `log_prior`, and `sample_prior`. We recommend writing a `simulator.py` file from which the latter four can be imported. `pipeline` then returns the Flax model, its trained parameters, and samples from the final iteration of the posterior.
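```python
import jax

from saxbi import pipeline
from simulator import X_true, get_simulator, log_prior, sample_prior

rng = jax.random.PRNGKey(16)
# Returns the Flax model, its trained parameters, and final-iteration posterior samples.
model, params, Theta_post = pipeline(rng, X_true, get_simulator, log_prior, sample_prior)
```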
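The `pipeline` call imports four objects from a user-written `simulator.py`. Below is a hypothetical toy version (a 2D linear-Gaussian simulator under a uniform prior). The call signatures it uses, such as `sample_prior(rng, num_samples)` and a simulator taking `(rng, theta)`, are assumptions for illustration only and may not match what `pipeline` actually expects; the `examples/` directory is the authoritative reference.

```python
# simulator.py: a hypothetical sketch; the signatures are assumptions, not SaxBI's API.
import jax
import jax.numpy as jnp

# Hypothetical uniform prior on [-3, 3]^2.
LOW, HIGH = -3.0, 3.0


def log_prior(theta):
    # Log density of the uniform box prior; -inf outside the box.
    inside = jnp.all((theta >= LOW) & (theta <= HIGH), axis=-1)
    log_volume = theta.shape[-1] * jnp.log(HIGH - LOW)
    return jnp.where(inside, -log_volume, -jnp.inf)


def sample_prior(rng, num_samples, dim=2):
    # Draws num_samples parameter vectors from the uniform prior.
    return jax.random.uniform(rng, (num_samples, dim), minval=LOW, maxval=HIGH)


def get_simulator(noise_scale=0.1):
    # Returns a simulator mapping (rng, theta) -> x = theta + Gaussian noise.
    def simulator(rng, theta):
        noise = noise_scale * jax.random.normal(rng, theta.shape)
        return theta + noise

    return simulator


# Toy "observed" data, generated here from a known parameter purely for illustration.
X_true = get_simulator()(jax.random.PRNGKey(0), jnp.array([[1.0, -0.5]]))
```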
The `examples/` directory holds a few canonical examples from the literature that show the syntax in greater detail.
- SNLE, from Papamakarios G, Sterratt DC, and Murray I. Sequential Neural Likelihood.
- SNRE, from Hermans J, Begy V, and Louppe G. Likelihood-free Inference with Amortized Approximate Likelihood Ratios.
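For intuition about what SNRE trains, here is a hedged sketch of the classifier objective from Hermans et al.: a classifier learns to separate jointly simulated pairs (θ, x) from pairs whose θ has been shuffled across the batch, and its logit then approximates the log likelihood-to-evidence ratio log p(x | θ) − log p(x). The names `toy_logit`, `ratio_loss` and `log_posterior_unnorm` are hypothetical, a linear logit stands in for SaxBI's Flax classifier, and the sequential rounds are omitted.

```python
import jax
import jax.numpy as jnp


def toy_logit(params, theta, x):
    # Hypothetical stand-in for a trained classifier (e.g. a Flax MLP):
    # returns one logit per (theta, x) pair.
    return params["w"] * jnp.sum(theta * x, axis=-1) + params["b"]


def ratio_loss(logit_fn, params, rng, theta, x):
    # Label 1: pairs (theta_i, x_i) drawn jointly from the simulator.
    logit_joint = logit_fn(params, theta, x)
    # Label 0: shuffling theta across the batch breaks its dependence on x,
    # giving samples from the product of marginals p(theta) p(x).
    theta_marg = jax.random.permutation(rng, theta)
    logit_marg = logit_fn(params, theta_marg, x)
    # Binary cross-entropy via log_sigmoid for numerical stability:
    # -log sigmoid(l) for label 1, -log(1 - sigmoid(l)) = -log sigmoid(-l) for label 0.
    return -jnp.mean(jax.nn.log_sigmoid(logit_joint) + jax.nn.log_sigmoid(-logit_marg))


def log_posterior_unnorm(logit_fn, params, log_prior, theta, x_obs):
    # For the Bayes-optimal classifier the logit equals log p(x | theta) - log p(x),
    # so adding the log prior gives the log posterior up to a theta-independent constant.
    return log_prior(theta) + logit_fn(params, theta, x_obs)
```

An unnormalized log posterior of this form is the kind of target an MCMC sampler such as NUTS can then draw from.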
- Add diagnostics (e.g., MMD, ROC AUC)
- Add support for Mining Gold (i.e., using simulator derivatives to improve likelihood estimators)
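MMD appears above as a planned diagnostic, not an existing feature. As a rough idea of what such a check computes, this sketch implements the biased (V-statistic) estimate of squared MMD with a Gaussian kernel, for example to compare posterior samples against reference samples. The names `gaussian_kernel` and `mmd2` and the fixed `bandwidth=1.0` default are assumptions; a median-distance heuristic is a common alternative for choosing the bandwidth.

```python
import jax
import jax.numpy as jnp


def gaussian_kernel(a, b, bandwidth):
    # Pairwise squared distances between rows of a (n, d) and b (m, d) -> (n, m).
    sq_dist = jnp.sum((a[:, None, :] - b[None, :, :]) ** 2, axis=-1)
    return jnp.exp(-sq_dist / (2.0 * bandwidth**2))


def mmd2(x, y, bandwidth=1.0):
    # Biased estimate of squared MMD: E k(x, x') + E k(y, y') - 2 E k(x, y),
    # with the expectations replaced by means over all sample pairs.
    kxx = gaussian_kernel(x, x, bandwidth).mean()
    kyy = gaussian_kernel(y, y, bandwidth).mean()
    kxy = gaussian_kernel(x, y, bandwidth).mean()
    return kxx + kyy - 2.0 * kxy
```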