Bayesian learning and inference for state space models (SSMs) using Google Research's JAX as a backend.
A quick demonstration of the basics of SSM: sample data from a Gaussian HMM, then fit a second, differently seeded HMM to that data with EM. Check out the example notebooks for more!
from ssm.hmm import GaussianHMM
import jax.random as jr
# create the ground-truth HMM and sample data from it
hmm = GaussianHMM(num_states=5, num_emission_dims=10, seed=jr.PRNGKey(0))
states, data = hmm.sample(key=jr.PRNGKey(1), num_steps=500, num_samples=5)
# create a second HMM with a different random seed
test_hmm = GaussianHMM(num_states=5, num_emission_dims=10, seed=jr.PRNGKey(32))
# fit it to our sampled data
log_probs, fitted_model, posteriors = test_hmm.fit(data, method="em")
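For intuition about what fitting with EM involves: the E-step of EM for an HMM relies on a forward pass over the data, which also yields the marginal log-likelihood. The block below is a minimal pure-Python sketch of that forward pass for a toy HMM with 1-D Gaussian emissions. It is not SSM's implementation; the names `log_normal_pdf`, `logsumexp`, `forward_log_likelihood`, and all parameter values are hypothetical and chosen only for illustration.

```python
import math


def log_normal_pdf(x, mean, std):
    """Log density of a 1-D Gaussian N(mean, std^2) evaluated at x."""
    z = (x - mean) / std
    return -0.5 * z * z - math.log(std) - 0.5 * math.log(2.0 * math.pi)


def logsumexp(values):
    """Numerically stable log(sum(exp(v) for v in values))."""
    m = max(values)
    if m == -math.inf:
        return -math.inf
    return m + math.log(sum(math.exp(v - m) for v in values))


def forward_log_likelihood(obs, init_probs, trans_probs, means, stds):
    """Marginal log-likelihood log p(obs) of a Gaussian HMM (forward pass).

    Assumes all initial and transition probabilities are strictly positive.
    """
    num_states = len(init_probs)
    # log_alpha[k] = log p(obs[0..t], state_t = k)
    log_alpha = [
        math.log(init_probs[k]) + log_normal_pdf(obs[0], means[k], stds[k])
        for k in range(num_states)
    ]
    for x in obs[1:]:
        log_alpha = [
            logsumexp(
                [log_alpha[j] + math.log(trans_probs[j][k]) for j in range(num_states)]
            )
            + log_normal_pdf(x, means[k], stds[k])
            for k in range(num_states)
        ]
    # Sum out the final state to get log p(obs).
    return logsumexp(log_alpha)


# Toy 2-state example (all values are made up for illustration).
init_probs = [0.5, 0.5]
trans_probs = [[0.9, 0.1], [0.2, 0.8]]
means, stds = [0.0, 3.0], [1.0, 1.0]
obs = [0.1, -0.3, 2.9, 3.2]
log_lik = forward_log_likelihood(obs, init_probs, trans_probs, means, stds)
```

Working in log space with `logsumexp` avoids the numerical underflow that multiplying many small probabilities would cause on long sequences such as the 500-step samples above.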
# use your favorite venv system
conda create -n ssm_jax python=3.9
conda activate ssm_jax
# in repo root directory...
pip install -r requirements.txt
.
├── docs # [documentation]
├── notebooks # [example jupyter notebooks]
├── ssm # [main code repository]
│ ├── hmm # hmm models
│ ├── factorial_hmm # factorial hmm models
│ ├── arhmm # arhmm models
│ ├── twarhmm # twarhmm models
│ ├── lds # lds models
│ ├── slds # slds models
│ ├── inference # inference code
│ └── distributions # distributions (generally, extensions of tfp distributions)
└── tests # [tests]
  ├── [unit tests] # unit test files mirroring the structure of the ssm directory
  │ ...
  └── timing_comparisons # benchmarking code (including comparisons to SSM_v0)