probml/dynamax

add support for HMM inference with sparse transition matrices

murphyk opened this issue · 6 comments

+1. Lots of models have constrained transition matrices (e.g. left-to-right HMMs and change-point models).

🙋 I can start looking at this!

I've started just by using the Gaussian HMM example from the docs.

import jax.numpy as jnp
import jax.random as jr
from dynamax.hidden_markov_model import GaussianHMM

num_states = 6
num_emissions = 35

# Construct the HMM
hmm = GaussianHMM(num_states, num_emissions)

# Specify parameters of the HMM
initial_probs = jnp.ones(num_states) / num_states
transition_matrix = 0.8 * jnp.eye(num_states) + jnp.diag(
    jnp.tile(0.2, num_states - 1), k=1)
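# NB: the last row of this matrix sums to 0.8, not 1, since there is no state to advance to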

emission_means = jnp.column_stack([
    jnp.cos(jnp.linspace(0, 2 * jnp.pi, num_states + 1))[:-1],
    jnp.sin(jnp.linspace(0, 2 * jnp.pi, num_states + 1))[:-1],
    jnp.zeros((num_states, num_emissions - 2)),
    ])
emission_covs = jnp.tile(
    0.1**2 * jnp.eye(num_emissions),
    (num_states, 1, 1))

# Initialize the parameters struct with known values
params, _ = hmm.initialize(
    initial_probs=initial_probs,
    transition_matrix=transition_matrix,
    emission_means=emission_means,
    emission_covariances=emission_covs)
true_states, emissions = hmm.sample(params, jr.PRNGKey(42), 100)

posterior = hmm.smoother(params, emissions)

This MWE runs without any errors, suggesting that sparse transition matrices aren't the issue per se, since this one is already sparse:

DeviceArray([[0.8, 0.2, 0. , 0. , 0. , 0. ],
             [0. , 0.8, 0.2, 0. , 0. , 0. ],
             [0. , 0. , 0.8, 0.2, 0. , 0. ],
             [0. , 0. , 0. , 0.8, 0.2, 0. ],
             [0. , 0. , 0. , 0. , 0.8, 0.2],
             [0. , 0. , 0. , 0. , 0. , 0.8]], dtype=float32)

I can create an error by modifying the transition matrix to add a new, "dummy absorbing" state like so:

# Set up transition matrix with final dummy-absorbing state
transition_matrix = 0.8 * jnp.eye(num_states + 1) + jnp.diag(
    jnp.tile(0.2, num_states), k=1)
transition_matrix = transition_matrix.at[-1, -1].set(1)
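# NB: to run this end-to-end, the HMM and emission parameters above would
# also need to be rebuilt with num_states + 1 states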
transition_matrix
DeviceArray([[0.8, 0.2, 0. , 0. , 0. , 0. , 0. ],
             [0. , 0.8, 0.2, 0. , 0. , 0. , 0. ],
             [0. , 0. , 0.8, 0.2, 0. , 0. , 0. ],
             [0. , 0. , 0. , 0.8, 0.2, 0. , 0. ],
             [0. , 0. , 0. , 0. , 0.8, 0.2, 0. ],
             [0. , 0. , 0. , 0. , 0. , 0.8, 0.2],
             [0. , 0. , 0. , 0. , 0. , 0. , 1. ]], dtype=float32)

But that might be a bit too far afield from the original issue to discuss here!

Please let me know if I misunderstood the original posting, too 🙏

What I had in mind is to exploit sparsity to speed up the O(K^2) computation at each step of forwards-backwards.
Currently the forwards pass just computes a dense matrix-vector product, alpha(t) ∝ A^T alpha(t-1) (rescaled by the local evidence), and ignores any structure in A.
My idea was to use
https://jax.readthedocs.io/en/latest/jax.experimental.sparse.html
to exploit the sparsity in A.
(The linked JSL code claims it works with jax.experimental.sparse, but there are no demos
or unit tests, so I am not sure that is true... Besides, JSL is deprecated; we only want to support dynamax.)
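
For concreteness, here is a minimal sketch of what one step of the forwards pass could look like with a BCOO matrix. This is illustrative, not dynamax's internals: forward_step and likelihood_t are made-up names, and I haven't benchmarked whether BCOO actually beats the dense product at these sizes.

import jax.numpy as jnp
from jax.experimental import sparse

# Store the transition matrix in BCOO (batched COO) format,
# which keeps only the nonzero entries.
A_sp = sparse.BCOO.fromdense(transition_matrix)

def forward_step(alpha_prev, likelihood_t):
    # Predict: alpha_pred(j) = sum_i A(i, j) * alpha_prev(i).
    # The sparse matrix-vector product only touches the nonzeros of A.
    alpha_pred = A_sp.T @ alpha_prev
    # Update with the local evidence p(y_t | z_t = j) and renormalize.
    alpha = alpha_pred * likelihood_t
    return alpha / alpha.sum()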

An alternative approach is to implement the algorithms in the paper below,
which apply only to certain banded transition matrices, such as those that arise
from discretizing an underlying continuous system (see the sketch after the reference).

@inproceedings{Felzenszwalb03,
  title     = {Fast Algorithms for Large State Space {HMMs} with Applications to Web Usage Analysis},
  booktitle = {Advances in Neural Information Processing Systems (NIPS)},
  year      = {2003},
  author    = {P. Felzenszwalb and D. Huttenlocher and J. Kleinberg}
}
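
As a rough illustration of why banded structure helps (this is the basic observation, not the paper's actual algorithm): if A(i, j) depends only on the offset j - i, as in the left-to-right matrix at the top of this thread, then the K^2 predict step collapses to a 1-D convolution costing only O(K * B) for bandwidth B.

import jax.numpy as jnp

# Hypothetical kernel reproducing the left-to-right matrix above:
# A(i, j) = kernel[j - i], with 0.8 on the diagonal and 0.2 on the superdiagonal.
kernel = jnp.array([0.8, 0.2])

def banded_predict(alpha_prev):
    # (alpha_prev @ A)(j) = sum_i alpha_prev(i) * kernel[j - i],
    # i.e. a convolution truncated to the first K entries.
    K = alpha_prev.shape[0]
    return jnp.convolve(alpha_prev, kernel, mode="full")[:K]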

Ah, there's a slight confusion then. One issue is a bug where the current message-passing code returns NaNs in some cases, as @emdupre showed above. The other is a feature request to support the experimental sparse library for faster matrix-vector multiplies.

@emdupre, why don't you create a new issue for the bug in the current message-passing code and reference this one?

Thank you both, and sorry for the noise! I've opened #290 for that discussion 🙇