

Primary language: Jupyter Notebook. License: MIT.

SSM: Bayesian learning and inference for state space models


Bayesian learning and inference for state space models (SSMs) using Google Research's JAX as a backend.

Example

A quick demonstration of the basic SSM workflow: sample data from a Gaussian HMM, then fit a second, independently initialized HMM to that data with EM. Check out the example notebooks for more!

from ssm.hmm import GaussianHMM
import jax.random as jr

# create a true HMM model
hmm = GaussianHMM(num_states=5, num_emission_dims=10, seed=jr.PRNGKey(0))
states, data = hmm.sample(key=jr.PRNGKey(1), num_steps=500, num_samples=5)

# create a test HMM model
test_hmm = GaussianHMM(num_states=5, num_emission_dims=10, seed=jr.PRNGKey(32))

# fit it to our sampled data
log_probs, fitted_model, posteriors = test_hmm.fit(data, method="em")
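To make the example less of a black box, here is a minimal standalone JAX sketch of the two computations it relies on: drawing a state and emission sequence from a Gaussian HMM, and evaluating the marginal log-likelihood log p(x_1:T) with the forward algorithm in log space. EM never decreases this quantity from one iteration to the next; we assume, without having checked the library, that the `log_probs` returned by `fit` tracks it. The helpers `sample_gaussian_hmm` and `hmm_log_likelihood` and their parameters (`initial_probs`, `transition_matrix`, `means`, `covs`) are hypothetical names for illustration, not part of the `ssm` API, and this is not the package's own implementation.

```python
import jax
import jax.numpy as jnp
import jax.random as jr
from jax.scipy.special import logsumexp
from jax.scipy.stats import multivariate_normal


# Hypothetical helper, not part of the ssm package: draws one state sequence
# and one emission sequence from a Gaussian HMM.
def sample_gaussian_hmm(key, initial_probs, transition_matrix, means, covs, num_steps):
    key_init, key_trans, key_emit = jr.split(key, 3)
    num_states = initial_probs.shape[0]
    z0 = jr.choice(key_init, num_states, p=initial_probs)

    # z_t ~ Categorical(transition_matrix[z_{t-1}])
    def next_state(z_prev, k):
        z = jr.choice(k, num_states, p=transition_matrix[z_prev])
        return z, z

    _, z_rest = jax.lax.scan(next_state, z0, jr.split(key_trans, num_steps - 1))
    states = jnp.concatenate([z0[None], z_rest])

    # x_t ~ N(means[z_t], covs[z_t]), drawn independently given the states
    emit_keys = jr.split(key_emit, num_steps)
    data = jax.vmap(lambda k, z: jr.multivariate_normal(k, means[z], covs[z]))(emit_keys, states)
    return states, data


# Hypothetical helper: log p(x_{1:T}) via the forward algorithm in log space.
def hmm_log_likelihood(initial_probs, transition_matrix, means, covs, data):
    # Emission log-likelihoods log p(x_t | z_t = k), shape (T, K)
    log_likes = jax.vmap(
        lambda m, c: multivariate_normal.logpdf(data, m, c), out_axes=1
    )(means, covs)
    log_A = jnp.log(transition_matrix)

    # log alpha_t(k) = logsumexp_j [log alpha_{t-1}(j) + log A[j, k]] + log p(x_t | k)
    def forward_step(log_alpha_prev, ll_t):
        log_alpha = logsumexp(log_alpha_prev[:, None] + log_A, axis=0) + ll_t
        return log_alpha, None

    log_alpha = jnp.log(initial_probs) + log_likes[0]
    log_alpha, _ = jax.lax.scan(forward_step, log_alpha, log_likes[1:])
    return logsumexp(log_alpha)
```

Working in log space with `logsumexp` keeps the recursion numerically stable for long sequences (such as the 500-step sequences above), where multiplying raw probabilities would underflow. On a short sequence the result can be checked against brute-force summation over all K^T state paths.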

Installation for Development

# use your favorite venv system
conda create -n ssm_jax python=3.9
conda activate ssm_jax

# in repo root directory...
pip install -r requirements.txt

Project Structure

.
├── docs                      # [documentation]
├── notebooks                 # [example jupyter notebooks]
├── ssm                       # [main code repository]
│   ├── hmm                       # hidden Markov models (HMMs)
│   ├── factorial_hmm             # factorial HMMs
│   ├── arhmm                     # autoregressive HMMs
│   ├── twarhmm                   # time-warped autoregressive HMMs
│   ├── lds                       # linear dynamical systems (LDS)
│   ├── slds                      # switching linear dynamical systems
│   ├── inference                 # inference code
│   └── distributions             # distributions (generally, extensions of TensorFlow Probability distributions)
└── tests                     # [tests]
    ├── [unit tests]              # unit test files mirroring the structure of ssm directory
    │   ...
    └── timing_comparisons        # benchmarking code (including comparisons to SSM_v0)

Documentation

Click here for documentation