SaxBI

JAX implementation of Sequential Neural Likelihood Estimation (SNLE) and Sequential Neural Ratio Estimation (SNRE) simulation-based inference algorithms



SaxBI is a JAX implementation of likelihood-free, simulation-based inference (SBI) methods. Two algorithms are currently implemented: Sequential Neural Likelihood Estimation (SNLE) and Sequential Neural Ratio Estimation (SNRE). The package offers a simple, functional API for carrying out approximate posterior inference.

The fully automated pipeline features:

  • Flax-based autoregressive normalizing flows with affine, piecewise affine, and piecewise rational quadratic splines
  • Flax-based classifiers with or without residual skip connections
  • Hamiltonian Monte Carlo sampling with NUTS kernels implemented in NumPyro (see the sketch after this list)
  • And more!
  • Probably some bugs too... Let me know what you find 😅
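To give a sense of how the sampling step fits together, here is a minimal, illustrative sketch of NUTS sampling over a custom log-density with NumPyro. The surrogate_log_likelihood and log_prior functions below are stand-ins for a trained surrogate and a prior, not SaxBI's internal API.

import jax
import jax.numpy as jnp
from numpyro.infer import MCMC, NUTS

def surrogate_log_likelihood(theta):
    # Stand-in for a trained flow or classifier evaluated at theta.
    return -0.5 * jnp.sum(theta ** 2)

def log_prior(theta):
    # Standard normal prior over a 2-dimensional parameter.
    return -0.5 * jnp.sum(theta ** 2)

def potential_fn(theta):
    # NUTS in NumPyro works with a potential, i.e. the negative log-density.
    return -(surrogate_log_likelihood(theta) + log_prior(theta))

kernel = NUTS(potential_fn=potential_fn)
mcmc = MCMC(kernel, num_warmup=500, num_samples=1000)
mcmc.run(jax.random.PRNGKey(0), init_params=jnp.zeros(2))
posterior_samples = mcmc.get_samples()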

Installation

saxbi requires Python 3.9 or higher. It can be installed from the repository's root directory with

python setup.py install

Basic Usage

The main workhorse of this package is the pipeline function, which takes five required arguments: rng, X_true, get_simulator, log_prior, and sample_prior. We recommend making a simulator.py file from which the latter four can be imported (a sketch of such a file follows below). The pipeline function returns the Flax model, its trained parameters, and samples from the posterior estimated in the final iteration.
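As a rough guide, a simulator.py might look something like the sketch below. This is only illustrative: the exact signatures that pipeline expects for get_simulator, log_prior, and sample_prior may differ, so consult the examples/ directory for working definitions.

# simulator.py -- illustrative sketch only; the signatures expected by
# saxbi.pipeline may differ (see the examples/ directory for real usage).
import jax
import jax.numpy as jnp

dim_theta = 2

def get_simulator():
    # Returns a simulator mapping parameters theta to synthetic data x.
    def simulator(rng, theta):
        noise = jax.random.normal(rng, shape=theta.shape)
        return theta + 0.1 * noise
    return simulator

def log_prior(theta):
    # Log-density of a standard normal prior over theta.
    return -0.5 * jnp.sum(theta ** 2, axis=-1)

def sample_prior(rng, num_samples):
    # Draw num_samples parameter vectors from the prior.
    return jax.random.normal(rng, shape=(num_samples, dim_theta))

# "Observed" data generated at a fiducial parameter value.
X_true = get_simulator()(jax.random.PRNGKey(0), jnp.ones((1, dim_theta)))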

import jax

from saxbi import pipeline
from simulator import X_true, get_simulator, log_prior, sample_prior
from simulator import X_true, get_simulator, log_prior, sample_prior

rng = jax.random.PRNGKey(16)

model, params, Theta_post = pipeline(rng, X_true, get_simulator, log_prior, sample_prior)
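As a quick follow-up, and assuming Theta_post is an array of posterior samples with shape (num_samples, dim_theta), simple summaries can be computed directly with jax.numpy:

import jax.numpy as jnp

# Posterior mean and standard deviation over the sampled parameters
# (assumes Theta_post has shape (num_samples, dim_theta)).
posterior_mean = jnp.mean(Theta_post, axis=0)
posterior_std = jnp.std(Theta_post, axis=0)
print(posterior_mean, posterior_std)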

The examples/ directory holds a few canonical examples from the literature to show off the syntax in greater detail.

SBI Algorithm References

Sequential Neural Likelihood Estimation (SNLE)

Sequential Neural Likelihood-to-Evidence Ratio Estimation (SNRE)

Todo

  • Add diagnostics (like MMD, ROC AUC)
  • Add support for Mining Gold (i.e. using simulator derivatives to improve likelihood estimators)