probml/dynamax

NaNs returned by `lgssm_posterior_sample`

calebweinreb opened this issue · 1 comment

We have been trying to incorporate dynamax into jax-moseq, a tool for unsupervised analysis of animal behavior. Specifically, we would like to replace our custom Kalman sampling code with the `lgssm_posterior_sample` method in dynamax. @ezhang94 has already done all the heavy lifting and tested it on some small-scale examples. However, we are still getting all-NaN outputs for more realistically sized datasets.

It seems like the problem can be solved by adding a small amount to the diagonal of the posterior covariance during each backward sampling step. Below is a brief recipe to reproduce the issue and a diagnosis of where the NaNs first appear.

  • The issue can be reproduced by running the keypoint-moseq tutorial, using the version of jax-moseq as of this commit. At some point during fitting, lgssm_posterior_sample returns all NaNs for some of the dataset.
  • The NaNs first appear during the forward filtering pass. This can be solved by forcing the output of the conditioning function to be symmetric. I recently submitted a separate issue and PR that implements this change.
  • Even after the above fix, NaNs still appear, but now during the backward sampling pass. They are rare and seem to be caused by sharp discontinuities in the emissions. Once a NaN appears, though, it propagates through the rest of the backward pass.
  • These sampling NaNs specifically appear during the MVN sampling step when the covariance is non-PSD (min eigenvalue < -1e-4). The problem can be solved by padding the diagonal of the covariance matrix before passing it to the MVN sampler.
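
For readers who want to see the two fixes from the list above in code, here is a minimal JAX sketch. It is not the dynamax implementation or the contents of the PRs. The helper names (`symmetrize`, `pad_diagonal`, `min_eigenvalue`, `safe_mvn_sample`) and the `jitter` value are hypothetical and chosen for illustration only. The sketch assumes sampling goes through a Cholesky factor, which in JAX returns NaNs rather than raising when the matrix is not positive definite.

```python
import jax.numpy as jnp
import jax.random as jr


def symmetrize(cov):
    """Force exact symmetry by averaging the matrix with its transpose."""
    return 0.5 * (cov + jnp.swapaxes(cov, -1, -2))


def pad_diagonal(cov, jitter=1e-3):
    """Add a small constant to the diagonal (jitter value is an assumption)."""
    return cov + jitter * jnp.eye(cov.shape[-1], dtype=cov.dtype)


def min_eigenvalue(cov):
    """Diagnostic: smallest eigenvalue of the symmetrized matrix."""
    return jnp.linalg.eigvalsh(symmetrize(cov)).min()


def safe_mvn_sample(key, mean, cov, jitter=1e-3):
    """Draw one MVN sample after symmetrizing and padding the covariance."""
    cov = pad_diagonal(symmetrize(cov), jitter)
    chol = jnp.linalg.cholesky(cov)  # NaN here if cov is still not PD
    eps = jr.normal(key, mean.shape, dtype=mean.dtype)
    return mean + chol @ eps


# A toy covariance with one slightly negative eigenvalue.
bad_cov = jnp.array([[1.0, 0.0], [0.0, -5e-5]])
unpadded_has_nan = bool(jnp.isnan(jnp.linalg.cholesky(bad_cov)).any())
sample = safe_mvn_sample(jr.PRNGKey(0), jnp.zeros(2), bad_cov)
```

Padding only helps when `jitter` exceeds the magnitude of the most negative eigenvalue, so a check like `min_eigenvalue` is useful for choosing it. With eigenvalues below -1e-4, as reported above, the padding would need to be larger than 1e-4.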

Thanks for digging into this and finding/fixing these issues. I've merged your PRs. Hope we can get jax-moseq to work with dynamax!