pass inputs into the LDS model
weigcdsb opened this issue · 3 comments
Hello,
I have a very basic question: how do I pass N × T × D inputs ("X") into the LDS model (N trials, T time steps, and D-dimensional inputs)?
In the linear_gaussian_ssm models.py file, the `inputs` argument is typed `Optional[Float[Array, "ntime input_dim"]]`, so there is no dimension for trials (N)?
I tried to follow the Kalman filter/smoother example, but the problem is that I also need to include d latent trajectories in the model (i.e., the state dimension would be D + d if I encode the covariates into the emission matrix).
Not sure how to do it correctly...
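To make the "state dimension D + d" idea concrete, here is the augmented-state emission mean written out. The symbols $x_t$ (the D observed covariates at time t), $C_x$ and $C_z$ (emission weights for covariates and latents) are introduced here for illustration only, and biases are omitted:

$$\tilde z_t = \begin{bmatrix} x_t \\ z_t \end{bmatrix} \in \mathbb{R}^{D+d}, \qquad \mathbb{E}[y_t \mid \tilde z_t] = \begin{bmatrix} C_x & C_z \end{bmatrix} \tilde z_t = C_x x_t + C_z z_t$$

The right-hand side has the same form as an LDS whose state is only the d-dimensional $z_t$ and whose input is $u_t = x_t$: $C_z$ plays the role of the emission matrix $H_t$ and $C_x$ the role of the emission input weights $D_t$. In that form the covariates enter through `inputs` instead of the state, so they do not have to be kept equal to their observed values inside the latent dynamics.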
Hi @weigcdsb, I'm not sure I fully understand your use case. Would you be able to explain it in more detail? Then we'll see if I can help 😄.
In general, it should be possible to use `jax.vmap` to map filtering/smoothing over additional dimensions (as described here); however, this might not be suitable for all scenarios.
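A minimal sketch of the `jax.vmap` pattern for N trials. To avoid relying on any dynamax API, it uses a small hand-rolled Kalman filter instead of the library's implementation; the function name `kalman_filter`, the parameter tuple, and the timing convention (the input $u_t$ drives both $y_t$ and the transition to $z_{t+1}$) are assumptions for illustration. The per-trial function only ever sees `(T, input_dim)` inputs, and `vmap` adds the trial axis.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import multivariate_normal


def kalman_filter(params, emissions, inputs):
    """Filter a single trial: emissions (T, emission_dim), inputs (T, input_dim).

    Convention assumed by this sketch (biases omitted):
        z_1 ~ N(m0, S0)
        y_t     = H z_t + D u_t + N(0, R)
        z_{t+1} = F z_t + B u_t + N(0, Q)
    Returns filtered means (T, state_dim) and the marginal log-likelihood.
    """
    F, B, Q, H, D, R, m0, S0 = params

    def step(carry, y_u):
        m, S, ll = carry                 # predictive mean/cov of z_t, running log-lik
        y, u = y_u
        y_mean = H @ m + D @ u           # predicted emission mean
        y_cov = H @ S @ H.T + R          # predicted emission covariance
        ll = ll + multivariate_normal.logpdf(y, y_mean, y_cov)
        K = jnp.linalg.solve(y_cov, H @ S).T   # Kalman gain S H^T y_cov^{-1}
        m_f = m + K @ (y - y_mean)             # filtered mean of z_t
        S_f = S - K @ y_cov @ K.T              # filtered covariance of z_t
        m_next = F @ m_f + B @ u               # predict z_{t+1}
        S_next = F @ S_f @ F.T + Q
        return (m_next, S_next, ll), m_f

    init = (m0, S0, jnp.zeros(()))
    (_, _, ll), filtered_means = jax.lax.scan(step, init, (emissions, inputs))
    return filtered_means, ll


# Map over the leading trial axis of emissions and inputs; parameters are shared.
batched_filter = jax.vmap(kalman_filter, in_axes=(None, 0, 0))

# Toy sizes and random parameters, purely for illustration.
num_trials, num_timesteps, input_dim, state_dim, emission_dim = 5, 50, 3, 2, 4
keys = jax.random.split(jax.random.PRNGKey(0), 5)
params = (
    0.9 * jnp.eye(state_dim),                                  # F
    0.1 * jax.random.normal(keys[0], (state_dim, input_dim)),  # B
    0.1 * jnp.eye(state_dim),                                  # Q
    jax.random.normal(keys[1], (emission_dim, state_dim)),     # H
    jax.random.normal(keys[2], (emission_dim, input_dim)),     # D
    0.5 * jnp.eye(emission_dim),                               # R
    jnp.zeros(state_dim),                                      # m0
    jnp.eye(state_dim),                                        # S0
)
inputs = jax.random.normal(keys[3], (num_trials, num_timesteps, input_dim))        # N x T x D
emissions = jax.random.normal(keys[4], (num_trials, num_timesteps, emission_dim))  # placeholder data

means, lls = batched_filter(params, emissions, inputs)
print(means.shape, lls.shape)  # (5, 50, 2) (5,)
```

The same pattern should carry over to any per-trial filtering or smoothing function with a `(T, input_dim)` input signature: the trial dimension lives entirely in `vmap`, so the function itself never needs a trial axis.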
@gileshd, thanks for replying, and sorry for the confusion.
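Just using the notation in the docstring of your models.py file:

$$p(z_t \mid z_{t-1}, u_t) = \mathcal{N}(z_t \mid F_t z_{t-1} + B_t u_t + b_t, Q_t)$$
$$p(y_t \mid z_t, u_t) = \mathcal{N}(y_t \mid H_t z_t + D_t u_t + d_t, R_t)$$

where $u_t$ is an input of size input_dim (assume input_dim $= D$; it defaults to 0). If there are […] emission_dim $= N$. So the total inputs (stacking all […]) […].
My question is how I can pass these through `inputs: Optional[Float[Array, "ntime input_dim"]]=None`, which means the dimension should be […].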
Hope this clarifies my question.
Correct. The input vector $u_t$ at each time step must be a $D$-dimensional vector, so `inputs` has shape `(T, D)` (or is `None`). You can always flatten your 3D inputs outside of dynamax.
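One possible reading of "flatten your 3D inputs", sketched under an assumption: the 3D array holds a separate D-dimensional covariate vector for each of the N emission dimensions at each time step, i.e. shape `(T, N, D)`. Flattening to `(T, N * D)` gives a plain `(T, input_dim)` array, and a block-diagonal emission input matrix (hypothetical name `D_mat`, built from hypothetical per-emission weights `W`) makes each emission see only its own covariates. The snippet checks that the flattened product matches the per-emission dot products on the original 3D array.

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import block_diag

num_timesteps, emission_dim, input_dim = 100, 4, 3
key_u, key_w = jax.random.split(jax.random.PRNGKey(1))

# Hypothetical 3D inputs: one input_dim-vector of covariates per emission
# dimension per time step, shape (T, N, D) with N = emission_dim.
inputs_3d = jax.random.normal(key_u, (num_timesteps, emission_dim, input_dim))
# Hypothetical per-emission covariate weights, shape (N, D).
W = jax.random.normal(key_w, (emission_dim, input_dim))

# Flatten outside the model: each u_t becomes a single (N * D)-vector, so the
# array has shape (T, N * D), matching the "ntime input_dim" signature.
inputs_flat = inputs_3d.reshape(num_timesteps, emission_dim * input_dim)

# Block-diagonal input weights of shape (N, N * D): row n is W[n] placed in
# columns n*D ... (n+1)*D - 1, so emission n only sees its own covariates.
D_mat = block_diag(*[W[n][None, :] for n in range(emission_dim)])

# Input contribution D u_t for every t, computed from the flattened inputs...
contrib_flat = inputs_flat @ D_mat.T                   # (T, N)
# ...and the same quantity computed directly on the 3D array.
contrib_3d = jnp.einsum("tnd,nd->tn", inputs_3d, W)    # (T, N)
```

The block structure is a modelling choice, not something the `(T, D)` interface enforces: a fitting routine that treats the input weight matrix as a free dense parameter would not keep the off-block entries at zero on its own.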