probml/dynamax

refactor LGSSM class into subclasses

Closed this issue · 5 comments

Consider the affine LGSSM:
[Screenshot: equations of the affine LGSSM]
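The screenshot itself is not reproduced in the text. As a rough reconstruction, assuming the usual affine form and the symbols used in the rest of this thread (F, B, b, Q for the dynamics and H, D, d, R for the observations), the model presumably looks like this. The noise symbols q_t and r_t and the input symbol u_t are notational assumptions, not taken from the image:

```latex
\begin{aligned}
z_t &= F z_{t-1} + B u_t + b + q_t, & q_t &\sim \mathcal{N}(0, Q) \\
y_t &= H z_t + D u_t + d + r_t,     & r_t &\sim \mathcal{N}(0, R)
\end{aligned}
```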

Suppose we keep the inference code as is.
This code is purely functional and takes in an LGSSMParams dataclass. Support for the offsets is useful for some of the posterior linearization code Peter has written. However, I propose we simplify the model code so that it does not include explicit bias terms b or d. If the user wants them, they have to provide their own input vector of ones and modify the B and D matrices accordingly. This will simplify the MAP code and the Gibbs sampling / VB code.
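To make the "input vector of ones" trick concrete, here is a minimal NumPy sketch. The helper names `fold_bias` and `augment_inputs` are hypothetical and are not part of the library; the point is only that appending a constant 1 to each input and appending the bias as an extra column of B (or D) reproduces the affine term exactly.

```python
import numpy as np


def fold_bias(W, bias):
    """Append `bias` as an extra column, so that W_aug @ [u; 1] == W @ u + bias."""
    return np.hstack([W, bias[:, None]])


def augment_inputs(inputs):
    """Append a constant-one column to a (num_timesteps, input_dim) array."""
    ones = np.ones((inputs.shape[0], 1))
    return np.hstack([inputs, ones])


rng = np.random.default_rng(0)
B = rng.normal(size=(3, 2))   # state_dim x input_dim
b = rng.normal(size=3)        # state bias
u = rng.normal(size=(5, 2))   # five timesteps of inputs

B_aug = fold_bias(B, b)       # shape (3, 3)
u_aug = augment_inputs(u)     # shape (5, 3)

# The affine term B u_t + b equals the purely linear term B_aug [u_t; 1].
affine = u @ B.T + b
linear = u_aug @ B_aug.T
assert np.allclose(affine, linear)
```

The same construction applies to D and d on the observation side.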

Furthermore, I propose we make various specializations of the lgssm/models/LinearGaussianSSM class. We remove all the EM code from the parent class (no E step, no M step), so the parent class only supports inference. Then we create 3 child classes:

  • The LGSSM_MLE class has no priors and just computes the MLE (see sec 29.8.1 of book2 for the M step).
  • The LGSSM_MNIW class uses the fully conjugate prior in sec 29.8.4.1, MNIW(Q, [F,B]) * MNIW(R, [H,D]) (see below)
  • The LGSSM_MNIW_STS class uses IW(Q) * MNIW(R, [D]) with B=0 and F, H fixed (for STS models, sec 29.12)

[Screenshot: definition of the conjugate MNIW prior]
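A rough sketch of how the proposed split could look, written in plain NumPy rather than the library's actual code. The class names follow the proposal, but the constructor, the `m_step` signature, and the `_regress` helper are assumptions made for illustration. The M step shown is the standard least-squares update for a linear-Gaussian regression given sufficient statistics; in EM these sums would be expectations computed by the smoother.

```python
import numpy as np


class LinearGaussianSSM:
    """Parent class: stores parameters only. In the proposal it would expose
    inference (filtering, smoothing) but no E step or M step."""

    def __init__(self, F, Q, H, R):
        self.F, self.Q, self.H, self.R = F, Q, H, R


class LGSSM_MLE(LinearGaussianSSM):
    """Child class with no priors: the M step is a plain MLE."""

    @staticmethod
    def _regress(Sxx, Syx, Syy, n):
        # Hypothetical helper. Given sum x x^T, sum y x^T, sum y y^T and the
        # number of terms n, return the MLE weights and noise covariance.
        W = np.linalg.solve(Sxx, Syx.T).T   # W = Syx Sxx^{-1} (Sxx symmetric)
        cov = (Syy - W @ Syx.T) / n         # residual covariance
        return W, cov

    def m_step(self, dynamics_stats, emission_stats):
        self.F, self.Q = self._regress(*dynamics_stats)
        self.H, self.R = self._regress(*emission_stats)


# Usage with noiseless, fully observed pairs (x, y = W_true x).
rng = np.random.default_rng(0)
W_true = np.array([[0.9, 0.1], [0.0, 0.8]])
X = rng.normal(size=(50, 2))
Y = X @ W_true.T
stats = (X.T @ X, Y.T @ X, Y.T @ Y, X.shape[0])

model = LGSSM_MLE(np.eye(2), np.eye(2), np.eye(2), np.eye(2))
model.m_step(stats, stats)
```

An LGSSM_MNIW child would override `m_step` with the corresponding MAP update under the MNIW prior, while leaving the parent untouched.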

In the future we might also want to create the following additional child classes:

  • LGSSM_NIW_SpikeSlab_STS class with a prior of the form IW(Q) * IW(R) * SpikeSlab(D).
  • LGSSM_MNIW_Identifiable, with a prior of the form N(F) * N(B) * MNIW(R, [H,D]) with Q=I fixed (for identifiability), and N(F) a prior that keeps the eigenvalues of F within the unit circle, so the system is dynamically stable (see sec 28.14.3)
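The stability condition in the last bullet can be checked numerically: a discrete-time linear system z_t = F z_{t-1} is stable when the spectral radius of F (the largest eigenvalue magnitude) is below one. A minimal sketch, where `spectral_radius` and `is_stable` are hypothetical helper names:

```python
import numpy as np


def spectral_radius(F):
    """Largest absolute value among the eigenvalues of F."""
    return float(np.max(np.abs(np.linalg.eigvals(F))))


def is_stable(F):
    """True if every eigenvalue of F lies strictly inside the unit circle."""
    return spectral_radius(F) < 1.0


F_stable = np.array([[0.9, 0.1], [0.0, 0.8]])    # eigenvalues 0.9 and 0.8
F_unstable = np.array([[1.1, 0.0], [0.0, 0.5]])  # eigenvalues 1.1 and 0.5
```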

I'd like to start by modifying the parent class LGSSM and the class LGSSM_MNIW, based on the code I already have.
If I understand correctly, at this stage we only modify the parameter structure in models (i.e., combine B, b and D, d), but pass the current parameter class (separate B, b and D, d) when we call the inference algorithms?

yes, we keep the inference code as-is.

  • Rewrite the base class LinearGaussianSSM that has no learning (EM) method
  • Rewrite the subclass LGSSM_MLE
  • Rewrite the subclass LGSSM_MNIW
  • Rewrite the subclass LGSSM_MNIW_STS

I think we can close this since STS moved to its own package.