
Inference Combinators in JAX

Primary language: Jupyter Notebook. License: Apache License 2.0 (Apache-2.0).

coix


Coix (COmbinators In jaX) is a flexible and backend-agnostic implementation of inference combinators (Stites and Zimmermann et al., 2021), a set of program transformations for compositional inference with probabilistic programs. Coix ships with backends for numpyro and oryx, along with a set of pre-implemented losses and utility functions that let you implement and run a wide variety of inference algorithms out of the box.
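To make "program transformation" concrete, here is a minimal pure-Python sketch of the idea behind a `propose`-style combinator: it takes a target program p and a proposal program q and returns a new program that samples from q, rescores the sample under p, and reports the log importance weight log p − log q. This is a simplified illustration, not coix's API. The program convention (a callable taking an RNG and optional fixed values and returning a trace plus a log density), the names `target`, `proposal`, `propose`, `Y_OBS`, and the Gaussian model are all hypothetical. Averaging the weights estimates the target's normalizing constant, which for this toy model is known in closed form.

```python
import math
import random

def normal_logpdf(x, loc, scale):
    """Log density of Normal(loc, scale) evaluated at x."""
    return (-0.5 * ((x - loc) / scale) ** 2
            - math.log(scale) - 0.5 * math.log(2.0 * math.pi))

Y_OBS = 1.0  # hypothetical observed value

# Hypothetical convention for this sketch: a program takes an RNG and an
# optional dict of fixed values and returns (trace, log_density_of_trace).
def target(rng, fixed=None):
    """Unnormalized target p(z, y=Y_OBS) = Normal(z; 0, 1) * Normal(Y_OBS; z, 1)."""
    z = fixed["z"] if fixed is not None else rng.gauss(0.0, 1.0)
    log_p = normal_logpdf(z, 0.0, 1.0) + normal_logpdf(Y_OBS, z, 1.0)
    return {"z": z}, log_p

def proposal(rng, fixed=None):
    """Proposal q(z) = Normal(z; 0.5, 1.5)."""
    z = fixed["z"] if fixed is not None else rng.gauss(0.5, 1.5)
    return {"z": z}, normal_logpdf(z, 0.5, 1.5)

def propose(p, q):
    """Combinator sketch: run q, rescore its trace under p, return log weight."""
    def program(rng):
        q_trace, log_q = q(rng)
        p_trace, log_p = p(rng, fixed=q_trace)  # reuse q's sample inside p
        return p_trace, log_p - log_q
    return program

def log_mean_exp(xs):
    """Numerically stable log of the mean of exp(xs)."""
    m = max(xs)
    return m + math.log(sum(math.exp(x - m) for x in xs) / len(xs))

rng = random.Random(0)
is_program = propose(target, proposal)
log_ws = [is_program(rng)[1] for _ in range(20_000)]
log_z_hat = log_mean_exp(log_ws)
# Marginally y ~ Normal(0, sqrt(2)), so the exact log normalizer is:
log_z_true = normal_logpdf(Y_OBS, 0.0, math.sqrt(2.0))
```

The point of the combinator form is that `propose(target, proposal)` is itself a program with the same shape as its inputs, so it can be composed further.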

Coix is a lightweight framework which includes the following main components:

  • coix.api: Implementation of the program combinators.
  • coix.core: Basic program transformations used to modify the behavior of a stochastic program.
  • coix.loss: Common objectives for variational inference.
  • coix.algo: Example inference algorithms.

Currently, we support the numpyro and oryx backends; other backends can be added via the coix.register_backend utility.

This is not an officially supported Google product.

Installation

To install Coix, you can use pip:

pip install coix

or you can clone the repository:

git clone https://github.com/jax-ml/coix.git
cd coix
pip install -e ".[dev,doc]"
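To confirm which version of the package pip installed, a small check using only the standard library's `importlib.metadata` works in either case. The helper name `installed_version` is hypothetical and not part of coix; it returns `None` when the distribution is not installed.

```python
from importlib.metadata import PackageNotFoundError, version

def installed_version(dist: str = "coix"):
    """Return the installed version string of `dist`, or None if absent."""
    try:
        return version(dist)
    except PackageNotFoundError:
        return None

coix_version = installed_version()
```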

Many examples run faster on accelerators. To install JAX with GPU or TPU support, follow the JAX installation instructions.
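After installing an accelerator-enabled JAX, one way to check which backend JAX actually picked up is `jax.default_backend()` together with `jax.devices()`. The wrapper below is a hypothetical convenience (not part of coix or JAX) that returns `None` when JAX is not importable, so it runs in any environment.

```python
import importlib.util

def accelerator_summary():
    """Return (default_backend, [device platforms]) if JAX is installed, else None."""
    if importlib.util.find_spec("jax") is None:
        return None
    import jax
    return jax.default_backend(), [d.platform for d in jax.devices()]

summary = accelerator_summary()
```

A result whose backend is `"cpu"` on a machine with a GPU usually means the CPU-only JAX wheel is installed.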