Neural State-Space Models and Latent Dynamic Functions
for High-Dimensional Generative Time-Series Modeling
This repository is meant to conceptually introduce and highlight implementation considerations for the recent class of models called Neural State-Space Models (Neural SSMs). They combine the classic state-space model with the flexibility of deep learning to approach high-dimensional generative time-series modeling and the learning of latent dynamics functions.
Included is an abstract PyTorch-Lightning training class with several latent dynamic functions that inherit it, as well as common metrics used in their evaluation and training examples on common datasets. Further broken down via implementation is the distinction between system identification and state estimation approaches, which are reminiscent of their classic SSM counterparts and arise from fundamental differences in the underlying choice of their probabilistic graphical model (PGM). This repository (currently) focuses primarily on considerations related to training dynamics models for system identification and forecasting rather than per-frame state estimation or filtering.
Note: This repo is not fully finished and some of the experiments/sections may be incomplete. This is released as public in order to maximize the potential benefit of this repo and hopefully inspire collaboration in improving it. Feel free to check out the "To-Do" section if you're interested in contributing!
Fig 1. Schematic of the two PGM forms of Neural SSMs.
If you found the information helpful for your work or use portions of this repo in research development, please consider citing one of the following works:
@misc{missel2022torchneuralssm,
title={TorchNeuralSSM},
author={Missel, Ryan},
publisher={Github},
journal={Github repository},
howpublished={\url{https://github.com/qu-gg/torchssm}},
year={2022},
}
@inproceedings{jiangsequentialLVM,
title={Sequential Latent Variable Models for Few-Shot High-Dimensional Time-Series Forecasting},
author={Jiang, Xiajun and Missel, Ryan and Li, Zhiyuan and Wang, Linwei},
booktitle={The Eleventh International Conference on Learning Representations}
}
This section provides an introduction to the concept of Neural SSMs, some common considerations and limitations, and active areas of research. This section assumes some familiarity with state-space models, though little background is needed to gain a conceptual understanding if one is already coming from a latent modeling perspective or Bayesian learning. Resources are available in abundance considering the breadth and depth of state-space usage; however, this video and modern textbook are good starting points.
Variational Auto-encoders (VAEs): VAEs[28] provide a principled and popular framework to learn the generative model pθ(x|z) behind data x, involving latent variables z that follow a prior distribution p(z). Variational inference over the generative model is facilitated by a variational approximation of the posterior density of latent variables z, in the form of a recognition model qφ(z|x). Parameters of both the generative and recognition models are optimized with the objective to maximize the evidence lower bound (ELBO) of the marginal data likelihood:
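In standard VAE notation, using the symbols defined above, this bound can be written as:

```latex
\log p_\theta(x) \;\ge\; \mathbb{E}_{q_\phi(z \mid x)}\big[\log p_\theta(x \mid z)\big] \;-\; \mathrm{KL}\big(q_\phi(z \mid x) \,\|\, p(z)\big)
```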
where the first term encourages the reconstruction of the observed data, and the second term of Kullback–Leibler (KL) divergence constrains the estimated posterior density of qφ(z|x) with a pre-defined prior p(z), often assumed to be a zero-mean isotropic Gaussian density.
An extension of classic state-space models, neural state-space models - at their core - consist of a dynamic function of some latent states z_k and their emission to observations x_k, realized through the equations:
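A generic way to write these two equations is sketched below. The function names f and g are placeholders chosen here for the transition and emission functions, and z_{<k} denotes the states preceding step k:

```latex
z_k = f(z_{<k};\, \theta_z), \qquad x_k = g(z_k)
```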
where θz represents the parameters of the latent dynamic function. The precise form of these functions can vary significantly: deterministic or stochastic, linear or non-linear, and discrete or continuous.
Due to their explicit separation of transition and emission and their use of structured equations, they have found success in learning interpretable latent dynamic spaces[1,2,3], identifying physical systems from non-direct features[4,5,6], and uses in counterfactual forecasting[7,8,14].
Given the fast pace of progress in latent dynamics modeling over recent years, many models have been presented under a variety of terminologies and proposed frameworks - examples being variational latent recurrent models[5,9,10,11,12,22], deep state-space models[1,2,3,7,13,14], and deterministic encoding-decoding models[4,15,16]. Despite differences in appearance, they all adhere to the same conceptual framework of latent variable modeling and state-space disentanglement. As such, here we unify them under the terminology of Neural SSMs and segment them into the two base choices of probabilistic graphical models that they adhere to: system identification and state estimation. We highlight each PGM's properties and limitations with experimental evaluations on benchmark datasets.
The PGM associated with each approach is determined by the latent variable chosen for inference.
Fig 2. Schematic of latent variable PGMs in Neural SSMs.
System states as latent variables (State Estimation): The intuitive choice for the latent variable is the latent state z_k that underlies x_k, given that it is already latent in the system and is directly associated with the observations. The PGM of this form is shown under Fig. 1A where its marginal likelihood over an observed sequence x0:T is written as:
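One way to write this likelihood that is consistent with the terms described in this section is sketched below. The convention that the i = 0 transition term reduces to an initial-state prior p(z_0) is an assumption of this sketch:

```latex
p(x_{0:T}) = \int \prod_{i=0}^{T} p(x_i \mid z_i)\, p(z_i \mid z_{<i}, x_{<i}) \; dz_{0:T}
```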
where p(xi | zi) describes the emission model and p(zi | z<i, x<i) describes the latent dynamics function. Given the common intractability of the posterior, parameter inference is performed through a variational approximation of the posterior density q(z0:T | x0:T), expressed as:
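A factorization consistent with the q(z_i | z_{<i}, x_{<i}) notation used in this section is sketched below. Individual works differ in whether the current observation x_i is also included in the conditioning set:

```latex
q(z_{0:T} \mid x_{0:T}) = \prod_{i=0}^{T} q(z_i \mid z_{<i}, x_{<i})
```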
Given these two components, the standard training objective, the evidence lower bound (ELBO), is derived in the form:
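A sketch of this bound, built from the emission, dynamics, and posterior terms named in this section, is given below. The expectation is taken under q(z_{0:T} | x_{0:T}), and the dynamics term p(z_i | z_{<i}, x_{<i}) plays the role of the prior:

```latex
\log p(x_{0:T}) \;\ge\; \sum_{i=0}^{T} \mathbb{E}_{q}\Big[ \log p(x_i \mid z_i) \;-\; \mathrm{KL}\big(q(z_i \mid z_{<i}, x_{<i}) \,\|\, p(z_i \mid z_{<i}, x_{<i})\big) \Big]
```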
where the first term represents a reconstruction likelihood term over the sequence and the second is a Kullback-Leibler divergence loss between the variational posterior approximation and some assumed prior of the latent dynamics. This prior can come in many forms, such as the standard Gaussian used in variational inference, flow-based priors from ODE settings[5], or physics-based priors in problem-specific situations[20]. This is the primary design choice that separates current works in this area, specifically the modeling of the dynamics prior and its learned approximation. Many works draw inspiration for modeling this interaction from filtering techniques in standard SSMs, where a divergence term is constructed between the dynamics-predicted latent state and the data-corrected observation[7,18].
With this formulation, it is easy to see how dynamics models of this type can have a strong reconstructive capacity for the high-dimensional outputs and contain strong short-term predictions. In addition, input-influenced dynamics are inherent to the prediction task, as errors in the predictions of the latent dynamics are corrected by true observations every step. However, given this data-based correction, the resulting inference of q(zi | z<i, x<i) is weakened, and without near-term observations to guide the dynamics function, its long-horizon forecasting is limited[1,3].
System parameters as latent variables (System Identification): Rather than the system states, one can instead choose the parameters of the dynamics function (denoted as θz in Equation 1) as the latent variable. With this change, the resulting PGM is represented in Fig. 1B and its marginal likelihood over x0:T is now given by:
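A sketch of this likelihood is given below, assuming a placeholder transition function f that produces each state from the preceding ones:

```latex
p(x_{0:T}) = \iint \prod_{i=0}^{T} p(x_i \mid z_i)\; p(z_0)\, p(\theta_z) \; dz_0 \, d\theta_z, \qquad z_i = f(z_{<i};\, \theta_z) \ \text{for } i \ge 1
```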
where the resulting output observations are derived from an initial latent state z0 and the dynamics parameters θz. As before, a variational approximation is considered for inference in place of an intractable posterior but now for the density q(θz, z0) instead. Given prior density assumptions of p(θz) and p(z0) in a similar vein as above, the ELBO function for this PGM is constructed as:
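A sketch of this bound is given below, under the additional assumption that the approximate posterior factorizes as q(θ_z, z_0) = q(θ_z) q(z_0). The conditioning of q on the observed frames is left implicit, as in the text:

```latex
\log p(x_{0:T}) \;\ge\; \mathbb{E}_{q(\theta_z, z_0)}\Big[ \sum_{i=0}^{T} \log p(x_i \mid z_i) \Big] \;-\; \mathrm{KL}\big(q(z_0) \,\|\, p(z_0)\big) \;-\; \mathrm{KL}\big(q(\theta_z) \,\|\, p(\theta_z)\big)
```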
where again the first term is a reconstruction likelihood and the terms following represent KL-divergence losses over the inferred variables.
The formulation given here is the most general form for this line of models, and other works can be covered under special assumptions or restrictions on how q(θz) and p(θz) are modeled. Original Neural SSM parameter works consider Linear-Gaussian SSMs as the transition function and introduce non-linearity by varying the transition parameters over time as θz0:T[1,2,3]. However, as shown in Fig. [2B1](#latentSchematic), this results in convoluted temporal modeling and devolves into the same state estimation problem, as the time-varying parameters now rely on near-term observations for correctness[8,20]. Rather than time-varying, the system parameters could be considered an optimized global variable, in which case the underlying dynamics function becomes a Bayesian neural network in a VAE's latent space[5], as shown in Fig. [2B2](#latentSchematic). Restricting these parameters to be deterministic results in a model of the form presented in Latent ODE[10]. The furthest restriction, forgoing stochasticity in the inference of z0, results in the suite of models presented in [4].
Regardless of the precise assumptions, this framework builds a strong latent dynamics function that enables long-term forecasting and, in some settings, even full-scale system identification[1,4] of physical systems. This comes at the cost of a harder inference task, given no access to dynamics correction during generation, and full identification tasks often require a large number of training samples over the potential system state space[4,5].
As the transition dynamics and the observation space are intentionally disconnected in this framework, the problem of inferring a strong initial latent state from which to forecast is an important consideration when designing a neural state-space model[30]. This is primarily a task- and data-dependent choice, in which the architecture follows the data structure. Thankfully, much work has been done in other research directions on designing good latent encoding models. As such, works in this area often draw from them. This section is split into three parts - one on the usual architecture for high-dimensional image tasks, one on lower-dimensional and/or miscellaneous encoders, and one on the different forms of inference for the initial state depending on which sequence portions are observed.
Image-based Encoders: Unsurprisingly, the common architecture used in latent image encoding is a convolutional neural network (CNN) given its inherent bias toward spatial feature extraction[1,3,4,5]. Works are mixed between either having the sequential input reshaped as frames stacked over the channel dimension or simply running the CNN over each observed frame separately and passing the concatenated embedding into an output layer. Regardless of methodology, a few frames are assumed as observations for initialization, as multiple timesteps are required to infer the initial system movement. A subset of works considers second-order latent vector spaces, in which the encoder is explicitly split into two individual position and momenta functions, taking single and multiple frames respectively[5].
Fig N. Visualization of the stacked initial state encoder, modified from [23].
Alternate Encoders: In settings with non-image-based inputs, the initial latent encoder can take on a large variety of forms, ranging anywhere from simple linear/MLP networks in physical systems[5] to graph convolution networks for latent medical image forecasting[20]. Multi-modal and dynamics conditioning inputs can be leveraged via combinations of encoders whose embeddings go through a shared linear function.
Fig N. Visualization of the stacked graph convolutional encoder, modified from [24].
Variables z0, zk, and zinit: Beyond just the inference of this latent variable, there is one more variation that can be seen throughout literature - that of which portions of the input sequence are observed and used in the initial state inference.
Generally, there are 3 forms seen:
- z0 - which uses a sequence x0:k to get z0.
- zk - which uses previous frames x-k:k to get zk to go forward past observed frames.
- zinit - which uses the sequence x0:k to get an abstract initial vector state that the dynamics function starts from.
Throughout the literature, these variable names aren't used as shown here (as most works just call it z0 and describe its inference), but we differentiate them specifically to highlight the distinctions. For training purposes, it is a subtle distinction but potentially has implications for the resulting likelihood optimization and learned vector space.
Fig N. Schematic of the difference between z0 and zinit formulations.
That said, there is generally a lack of work exploring the considerations for each approach, besides ad-hoc solutions to bridge the gap between the latent encoder and dynamics function distributions[5]. This gap can stem from optimization problems caused by imbalanced reconstruction terms between dynamics and initial states or in cases where the initial state distribution is far enough away from the data distribution of downstream frames.
However, a recent work "Learning Neural State-Space Models: Do we need a state estimator?" [30] is the first detailed study into the considerations of initial state inference, providing ablations across increasing difficulties of datasets and inference forms. In their work, they found that to get competitive performance of neural SSMs on some dynamical systems, more advanced architectures were required (feed-forward or LSTM networks). Notably, they only evaluate on the zk form, varying architectural choices.
A variety of empirical techniques have been proposed to tackle this gap, much in the same spirit of empirical VAE stability 'tricks.' These include separated x0 and x1:T terms (where x0 has a positive weighting coefficient), VAE pre-training for x0, and KL-regularization terms between the output distributions of the encoder and the dynamics flow[1,5]. One personal intuition regarding these two variable approaches and the tricks applied is that there exists a theoretical trade-off between the two formulations and the tricks applied help to empirically alleviate the shortcomings of either approach. This, however, requires experimentation and validation before any claims can be made.
There are three important phases during forecasting with a neural SSM: initial state inference, reconstruction, and extrapolation.
Fig N. Breakdown of the three forecasting phases - initial state inference, reconstruction, and extrapolation.
Initial State Inference: Inferring the initial state and how many frames are required to get a good initialization is fairly domain/problem specific, as each problem may require more or less time to highlight distinctive patterns that enable effective dynamics separation.
Reconstruction: This phase refers to the timesteps that are used in training, from which the likelihood term is calculated. So far, there is no generally agreed-upon standard for how many steps to use, and works can be seen using anywhere from 1 (i.e. next-step prediction) to 60 frames in this portion[4]. Some works frame this as a hyper-parameter to tune in experiments, and there is a consideration of computational cost when scaling up to longer sequences. In our experiments, we've noticed a linear scaling in training time w.r.t. this sequence length. (TO-DO) In the Experiments section, we perform an ablation study on how fixed lengths of reconstruction affect the extrapolation ability of models on Hamiltonian systems.
Extrapolation: This phase refers to the arbitrarily long forecasting of frames that goes beyond the length used in the likelihood term during training. It represents whether a model has captured the system dynamics sufficiently to enable long-term forecasting or model energy decay in non-conserving systems. For specific dynamical systems, this can be a difficult task as, at base, there is no training signal to inform the model to learn good extrapolation. Works often highlight metrics independently on reconstruction and extrapolation phases to highlight a model's strength of identification[4].
Training Considerations: It is important to note that the exact structure of how the likelihood loss is formulated plays a role in how this sequence length may affect extrapolation. Having your likelihood incorporate temporal information (e.g. summation over the sequence, trajectory mean, etc.) can have a detrimental effect on extrapolation as the model optimizes with respect to the fixed reconstruction length. Figure N highlights an example of using temporal information in a likelihood term, where there is near flawless reconstruction but immediate forecasting failure when going towards extrapolation.
Fig N. Example of failed extrapolation given an incorrect likelihood term. Red highlights the beginning of extrapolation.
As well, it is often the case that the reconstruction training metrics (e.g. likelihood/pixel MSE) and visualizations show strong convergence despite still-poor extrapolation. It can sometimes be the case, especially in Neural ODE latent dynamics, that more training than expected is required to enable strong extrapolation. One intuition is that the vector field may require a longer optimization than just the reconstruction convergence to be robust against error accumulation that impacts long-horizon forecasting.
Fig N. Training vs. Validation pixel MSE metrics, highlighting the continued extrapolation learning past training "convergence."
Tips for training good extrapolation in these models include:
- Perform extrapolation in your validation steps such that there is a metric to highlight extrapolation learning over training.
- Use per-frame averages in the likelihood function rather than any form with temporal information.
- Use variable lengths of reconstruction during training, sampling 1-T frames to reconstruct in a given batch.
- If you have long sequences, especially in non-conserving systems, sample a random starting point per batch.
- Train for longer than you might expect, even when training metrics have converged for "reconstruction."
- The integrator choice can affect this, as non-symplectic integrators have known error accumulation which affects the vector state over long horizons[4]
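Two of the tips above (variable reconstruction lengths and random starting points) can be sketched with the standard library alone. The function name `sample_training_window` and its arguments are hypothetical and not part of this repository; in practice the same slicing would be applied to a batch of image tensors rather than a list of frame indices.

```python
import random


def sample_training_window(seq_len, max_recon, rng):
    """Sample a random start index and a random reconstruction length.

    seq_len   : number of frames stored for the trajectory
    max_recon : upper bound T on how many frames are reconstructed
    rng       : a random.Random instance, so that runs are reproducible
    """
    # Sample between 1 and T frames to reconstruct (capped by the sequence length).
    recon_len = rng.randint(1, min(max_recon, seq_len))
    # Sample a start index such that the whole window fits inside the sequence.
    start = rng.randint(0, seq_len - recon_len)
    return start, recon_len


rng = random.Random(0)
trajectory = list(range(200))  # stand-in for 200 frames, identified by index
start, recon_len = sample_training_window(len(trajectory), max_recon=20, rng=rng)
window = trajectory[start:start + recon_len]
```

Drawing a new window for every batch means the likelihood is never tied to a single fixed horizon or a single phase of the trajectory.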
So far we have ignored another common and important component of state-space modeling, the incorporation of external controls u that affect the transition function of the state. Controls represent factors that influence the trajectory of a system but are not direct features of the object/system being modeled. For example, an external force such as friction acting on a moving ball or medications given to a patient could be considered controls[8,14]. These allow an additional layer of interpretability in SSMs and even enable counterfactual reasoning; i.e., given the current state, what does its trajectory look like under varying control inputs going forward? This has myriad uses in medical modeling with counterfactual medicine[14] or physical system simulations[8].
For Neural SSMs, a variety of approaches have been taken thus far depending on the type of latent transition function used.
Linear Dynamics: In latent dynamics still parameterized by traditional linear-Gaussian transition functions, control incorporation is as easy as the addition of another transition matrix Bt that is applied to a control input ut at each timestep[1,2,4,7].
Fig N. Example of control input in a linear transition function[1].
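As a minimal illustration of the linear case, the sketch below rolls out the noise-free update z_{t+1} = A z_t + B u_t in plain Python. The matrices, the two-dimensional state, and the helper names are illustrative assumptions; for simplicity A and B are held fixed here rather than varying per timestep, and the Gaussian noise term is omitted.

```python
def matvec(matrix, vector):
    """Multiply a matrix (list of rows) by a vector (list of floats)."""
    return [sum(m * v for m, v in zip(row, vector)) for row in matrix]


def linear_step(A, B, z, u):
    """One noise-free transition: z_next = A z + B u."""
    return [az + bu for az, bu in zip(matvec(A, z), matvec(B, u))]


def rollout(A, B, z0, controls):
    """Apply the transition once per control input and collect all states."""
    states = [z0]
    for u in controls:
        states.append(linear_step(A, B, states[-1], u))
    return states


A = [[1.0, 0.1], [0.0, 1.0]]  # position integrates velocity
B = [[0.0], [0.1]]            # the control pushes on the velocity only
z0 = [0.0, 1.0]

controlled = rollout(A, B, z0, [[0.0], [1.0], [1.0]])
counterfactual = rollout(A, B, z0, [[0.0], [0.0], [0.0]])
```

Running the same initial state under two control sequences, as in the last two lines, is the basic mechanism behind the counterfactual questions mentioned earlier.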
Non-Linear Dynamics: In discrete non-linear transition functions using either multi-layer perceptrons or recurrent cells, controls can be incorporated either by concatenating them to the input vector before the network forward pass or as a data transformation in the form of element-wise addition and a weighted combination[10].
Fig N. Example of control input in a non-linear transition function[1].
Continuous Dynamics: For incorporation into continuous latent dynamics functions, finding the best approaches is an ongoing topic of interest. Thus far, the reigning approaches are:
- Directly jumping the vector field state with recurrent cells[18]
- Influencing the vector field gradient (e.g. neural controlled differential equations)[17]
- Introducing another dynamics mechanism, continuous or otherwise (e.g. neural ODE or attention blocks), that is combined with the latent trajectory z1:T into an auxiliary state h1:T[8,14,25].
Fig N. Visualization of the IMODE architecture, taken from [8].
In this section, specifics on model implementation and the datasets/metrics used are detailed. Specific data generation details are available in the URLs provided for each dataset. The models and datasets used throughout this repo are solely grayscale physics datasets with underlying Hamiltonian laws, such as pendulum and mass-spring sets. Extensions to color images and non-pixel-based tasks (or even graph-based data!) are easily done in this framework, as the only architecture change needed is the structure of the encoder and decoder networks as the state propagation happens solely in a latent space.
The project's folder structure is as follows:
torchssm/
│
├── train.py # Training entry point that takes in user args and handles boilerplate
├── test.py # Testing script to get reconstructions and metrics on a testing set
├── tune.py # Performs a hyperparameter search for a given dataset using Ray[Tune]
├── README.md # What you're reading right now :^)
├── requirements.txt # Anaconda requirements file to enable easy setup
|
├── data/
| ├── <dataset_type> # Name of the stored dynamics dataset (e.g. pendulum)
| ├── generate_bouncingball.py # Dataset generation script for Bouncing Ball
| ├── generate_hamiltonian.py # Dataset generation script for Hamiltonian Dynamics
| └── tar_directory.py # WebDataset generation script
├── experiments/
| └── <model_name> # Name of the dynamics model run
| └── <experiment_type> # Given name for the ran experiment
| └── <version_x>/ # Each experiment type has its sequential lightning logs saved
├── lightning_logs/
| ├── version_0/ # Placeholder lightning log folder
| └── ... # Subsequent runs
├── models/
│ ├── CommonDynamics.py # Abstract PyTorch-Lightning Module to handle train/test loops
│ ├── CommonVAE.py # Shared encoder/decoder Modules for the VAE portion
│ ├── system_identification/
│ └── ... # Specific System-Identification dynamics functions
│ └── state_estimation/
│ └── ... # Specific State-Estimation dynamics functions
├── utils/
│ ├── dataloader.py # WebDataset class to return the dataloaders used in train/val/testing
│ ├── layers.py # PyTorch Modules that represent general network layers
│ ├── metrics.py # Metric functions for evaluation
│ ├── plotting.py # Plotting functions for visualization
| └── utils.py # General utility functions (e.g. argparsing, experiment number tracking, etc)
└──
All data used throughout these experiments are available for download here on Google Drive, in which they already come in their WebDataset forms. The total sizes of all sets are under a modest 2GB. However, feel free to generate your own sets using the provided data scripts!
Hamiltonian Dynamics: Provided for evaluation are a WebDataset dataloader and generation scripts for DeepMind's Hamiltonian Dynamics suite[4], a simulation library for 17 different physics datasets that have known underlying Hamiltonian dynamics. It comes in the form of color image sequences of arbitrary length, coupled with the system's ground truth states (e.g., for pendulum, angular velocity and angle). It is well-benchmarked and customizable, making it a perfect testbed for latent dynamics function evaluation. For each setting, the physical parameters are tweakable alongside an optional friction coefficient to construct non-energy conserving systems. The location of focal points and the color of the objects are all individually tuneable, enabling mixed and complex visual datasets of varying latent dynamics.
Fig N. Pendulum-Colors Examples.
For the base presented experiments of this dataset, we consider and evaluate grayscale versions of pendulum and mass-spring - which conveniently are just the sliced red channel of the original sets. Each set has 50000 training and 5000 testing trajectories sampled at Δt = 1 time intervals. Energy conservation is preserved without friction and we assume constant placement of focal points for simplicity. Note that the modification to color outputs in this framework is as simple as modifying the number of channels in the encoder and decoder.
Bouncing Balls: Additionally, we provide a dataloader and generation scripts for the standard latent dynamics dataset of bouncing balls[1,2,5,7,8], modified from the implementation in [1]. It consists of a ball or multiple balls moving within a bounding box while being affected by potential external effects, e.g. gravitational forces[1,2,5], pong[2], and interventions[8]. The starting position, angle, and velocity of the ball(s) are sampled uniformly within a set range. It is generated with the PyMunk and PyGame libraries. In this repository, we consider two sets - a simple set of one gravitational force and a mixed set of 4 gravitational forces in the cardinal directions with varying strengths. We similarly generate 50000 training and 5000 testing trajectories, but sample them at Δt = 0.1 intervals.
Fig N. Single Gravity Bouncing Ball Example.
Notably, this system is surprisingly difficult to successfully perform long-term generation on, especially in cases of mixed gravities or multiple objects. Most works only report on generation within 5-15 timesteps following a period of 3-5 observation timesteps[1,2,7] and farther timesteps show lost trajectories and/or incoherent reconstructions.
Meta-Learning Datasets: One of the latest research directions for neural SSMs is evaluating the potential of meta-learning to build domain-adaptable latent dynamics functions[26,27,29]. A representative dataset example for this task is the Turbulent Flow dataset that is affected by various buoyancy forces, highlighting a task with partially shared yet heterogeneous dynamics[27].
Fig N. Turbulent Flow Example, sourced from [27].
Multi-System Dynamics: So far in the literature, the majority of works only consider training Neural SSMs on one system of dynamics at a time - with the most variety lying in that of differing trajectory hyper-parameters. The ability to infer multiple dynamical systems under one model (or learn to output dynamical functions given system observations) and leverage similarities between the sets is an ongoing research pursuit - with applications of neural unit hypernetworks[27] and dynamics functions conditioned on sequences via meta-learning[26,29] being the first dives into this.
Other Sets in Literature: Outside of the previous sets, there are a plethora of other datasets that have been explored in relevant work. The popular task of human motion prediction in both the pose estimation and video generation setting has been considered via the datasets Human3.6M, CMU Mocap, and Weizmann-Action[5,19], though proper experimentation into this area would require problem-specific architectures given the depth of the existing field. Beyond high-dimensional and image-space data, standard benchmark datasets in time-series forecasting have also been considered, including the M4, Electricity Transformer Temperature (ETT), and the NASA Turbofan Degradation set. Recent works have begun looking at medical applications in inverse image reconstruction and the incorporation of physics-inspired priors[20,29]. Regardless of the success of Neural SSMs on these tasks, the task-agnostic factor and principled structure of this framework make it a versatile and appealing option for generative time-series modeling.
Here, details are given on how the model implementation is structured and how to run experiments locally. As well, an overview is given of the abstract class implementation for a general Neural SSM and its types.
Provided within this repository is a PyTorch class structure in which an abstract PyTorch-Lightning Module is shared across all the given models, from which the specific VAE and dynamics functions inherit and override the relevant forward functions for training and evaluation. Swapping between dynamics functions and PGM type is as easy as passing in the model's name as an argument, e.g. `python3 train.py --model node`. As the implementation is provided in PyTorch-Lightning, an optimization and boilerplate library for PyTorch, at least a surface-level familiarity with it is recommended.
For every model run, a new `lightning_logs/` version folder is created as well as a new experiment version under `experiments` related to the passed-in naming arguments. Hyperparameters passed in for this run are stored both in the TensorBoard instance created and in the local files `hparams.yaml` and `config.json`. Default values and available options can be found in `train.py` or by running `python3 train.py -h`. During training and validation sequences, all of the metrics below are automatically tracked and saved into a TensorBoard instance, which can be used to compare different model runs afterwards. Every 5 epochs, reconstruction sequences against the ground truth for a set of samples are saved to the experiments folder. Currently, only one checkpoint is saved based on the last epoch run rather than checkpoints based on the best validation score or at set epochs. Restarting training from a checkpoint or loading in a model for testing is currently done by specifying the `ckpt_path` to the base experiment folder and the `checkpt` filename.
The implemented dynamics functions are each separated into their respective PGM groups; however, they can still share the same general classes. Each dynamics subclass has a `model_specific_loss` function that allows it to return additional loss values without interrupting the abstract flow. For example, this could be used in a flow-based prior that has additional KL terms over ODE flow density without needing to override the `training_step` function with a duplicate copy. As well, there is additionally `model_specific_plotting` to enable custom plots at the end of every training epoch.
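The hook pattern described above can be sketched without any PyTorch-Lightning dependency. Only the method names `model_specific_loss` and `training_step` come from the text; the class names, argument lists, and the scalar stand-ins for tensors are hypothetical and will differ from the actual classes in `models/CommonDynamics.py`.

```python
class CommonDynamicsSketch:
    """Hypothetical stand-in for the shared abstract training class."""

    def reconstruction_loss(self, preds, targets):
        # Mean squared error over a flat list of predictions.
        return sum((p - t) ** 2 for p, t in zip(preds, targets)) / len(targets)

    def model_specific_loss(self, preds, targets):
        # Default hook: subclasses add nothing unless they override this.
        return 0.0

    def training_step(self, preds, targets):
        # Shared flow: the base loss plus whatever the subclass contributes.
        return (self.reconstruction_loss(preds, targets)
                + self.model_specific_loss(preds, targets))


class KLRegularizedSketch(CommonDynamicsSketch):
    """Hypothetical subclass that adds a weighted KL-style penalty."""

    def __init__(self, kl_value, beta):
        self.kl_value = kl_value  # stand-in for a KL term computed elsewhere
        self.beta = beta          # weighting coefficient on the extra term

    def model_specific_loss(self, preds, targets):
        return self.beta * self.kl_value


base_loss = CommonDynamicsSketch().training_step([1.0, 2.0], [1.0, 4.0])
extra_loss = KLRegularizedSketch(kl_value=0.5, beta=0.1).training_step([1.0, 2.0], [1.0, 4.0])
```

The subclass changes the total loss without touching `training_step`, which is the point of keeping the hook separate from the shared loop.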
System Identification Models: For the system identification models, we provide a variety of dynamics functions that resemble the general and special cases detailed above, which are provided in Fig N. below. The most general version is that of the Bayesian Neural ODE, in which a neural ordinary differential equation[21] is sampled from a set of optimized distributional parameters and used as the latent dynamics function z'_t = f_{p(θ)}(z_s)[5]. A deterministic version of a standard Neural ODE is similarly provided, e.g. z'_t = f_θ(z_s)[10,21]. Following that, two forms of a Recurrent Generative Network are provided, a residual version (RGN-Res) and a full-step version (RGN), that represent deterministic and discrete non-linear transition functions. RGN-Res is the equivalent of a Neural ODE using a fixed-step Euler integrator while RGN is just a recurrent forward step function.
Additionally, a representation of the time-varying Linear-Gaussian SSM transition dynamics[1,2] (LGSSM) is provided. And finally, a set of autoregressive models are considered (i.e. Recurrent Neural Networks, Long Short-Term Memory networks, Gated Recurrent Unit networks) as baselines. Their PyTorch Cell implementations are used and evaluated over the entire sequence, passing in the previously predicted state and observation as their inputs.
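The relationship between RGN, RGN-Res, and a fixed-step Euler integrator mentioned above can be checked on a toy scalar example. The function `f` below is an arbitrary stand-in for a learned network, and a step size of 1 is assumed for the Euler update; none of these names are taken from the repository.

```python
def f(z):
    """Stand-in for a learned dynamics network (here a simple linear decay)."""
    return -0.5 * z


def rgn_step(z):
    """Full-step recurrent update: the network output is the next state."""
    return f(z)


def rgn_res_step(z):
    """Residual update: the network output is added to the current state."""
    return z + f(z)


def euler_step(z, dt=1.0):
    """Fixed-step Euler update of dz/dt = f(z)."""
    return z + dt * f(z)


def rollout(step_fn, z0, steps):
    """Apply a step function repeatedly and collect the visited states."""
    states = [z0]
    for _ in range(steps):
        states.append(step_fn(states[-1]))
    return states


res_traj = rollout(rgn_res_step, 1.0, 3)
euler_traj = rollout(euler_step, 1.0, 3)
rgn_traj = rollout(rgn_step, 1.0, 3)
```

With a unit step, the residual and Euler rollouts coincide, while the full-step rollout follows a different trajectory for the same `f`.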
Training for these models has one mode: several observational frames are taken in to infer z0, and the full sequence is then output autonomously without access to subsequent observations. A likelihood function is evaluated over the full reconstructed sequence and optimized. Testing and generation in this setting can easily be done out to any horizon, and we provide small sample datasets of 200 timesteps to evaluate long horizons.
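The training mode described above can be sketched roughly as follows. This is only an illustration under stated assumptions: the linear encoder and decoder, the residual tanh transition, the sizes, and names such as z_amort are all hypothetical, and a mean squared error stands in for the likelihood term.

```python
import numpy as np

rng = np.random.default_rng(0)
# Hypothetical sizes: sequence length, flattened frame size, latent size,
# and the number of initial frames used to infer z0
T, obs_dim, latent_dim, z_amort = 20, 16, 3, 5

# Hypothetical (untrained) weights for the encoder, transition, and decoder
W_enc = rng.normal(scale=0.1, size=(z_amort * obs_dim, latent_dim))
W_dyn = rng.normal(scale=0.3, size=(latent_dim, latent_dim))
W_dec = rng.normal(scale=0.1, size=(latent_dim, obs_dim))

def infer_z0(x_init):
    """Map the first z_amort frames to an initial latent state."""
    return x_init.reshape(-1) @ W_enc

def forecast(x, n_steps):
    """Infer z0, then roll out autonomously without seeing later frames."""
    z = infer_z0(x[:z_amort])
    preds = []
    for _ in range(n_steps):
        preds.append(z @ W_dec)          # decode the current latent state
        z = z + np.tanh(z @ W_dyn)       # latent transition, no observation input
    return np.stack(preds)

x = rng.normal(size=(T, obs_dim))        # placeholder observation sequence
x_hat = forecast(x, n_steps=T)
loss = np.mean((x_hat - x) ** 2)         # likelihood surrogate over the full sequence

# The same function generates out to any horizon, e.g. 200 timesteps
x_long = forecast(x, n_steps=200)
```

Because only the first frames enter the model, extending the horizon changes nothing about the earlier predictions.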
Fig N. Model schematics for system identification's implemented dynamics functions.
State Estimation Models: For the state estimation line, we provide a reimplementation of the classic Neural SSM work Deep Kalman
Filter[7] alongside state estimation versions of the above, provided in Fig. N below. The DKF model modifies
the standard Kalman Filter Gaussian transition function to incorporate non-linearity and expressivity by parameterizing
the distribution parameters with neural networks,
$z_t \sim \mathcal{N}(G(z_{t-1}, \Delta t), S(z_{t-1}, \Delta t))$[7].
The autoregressive versions for this setting can be viewed as a reimplementation of the Variational Recurrent Neural
Network (VRNN), one of the earliest state estimation works in Neural SSMs[22]. For the latent correction
step, we leverage a standard Gated Recurrent Unit (GRU) cell, and the corrected latent state is what is passed to the
decoder and likelihood function. Notably, these models can be run under two settings: reconstruction
and generation. Reconstruction is used for training and incorporates ground truth observations to correct
the latent state, while generation is used to test the forecasting abilities of the model, both short- and long-term.
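As an illustration of the DKF-style transition described above, the following NumPy sketch samples latent states from a Gaussian whose mean G and variance S come from a small network. The shared hidden layer, the weights, and the diagonal-covariance choice are assumptions made for this example, not the repository's implementation.

```python
import numpy as np

rng = np.random.default_rng(0)
latent_dim, hidden = 3, 8

# Hypothetical weights: a shared hidden layer followed by mean and variance heads
W_h = rng.normal(scale=0.3, size=(latent_dim + 1, hidden))
W_mu = rng.normal(scale=0.3, size=(hidden, latent_dim))
W_sig = rng.normal(scale=0.3, size=(hidden, latent_dim))

def transition_params(z_prev, dt):
    """Return the mean G(z_{t-1}, dt) and diagonal variance S(z_{t-1}, dt)."""
    h = np.tanh(np.append(z_prev, dt) @ W_h)
    mean = h @ W_mu
    var = np.log1p(np.exp(h @ W_sig))    # softplus keeps the variance positive
    return mean, var

def sample_transition(z_prev, dt):
    """Draw z_t ~ N(G, S) with the reparameterization z = mean + std * eps."""
    mean, var = transition_params(z_prev, dt)
    eps = rng.standard_normal(latent_dim)
    return mean + np.sqrt(var) * eps

# Generation setting: roll the stochastic transition forward without correction
z = np.zeros(latent_dim)
traj = [z]
for _ in range(5):
    z = sample_transition(z, dt=1.0)
    traj.append(z)
traj = np.stack(traj)
```

In the reconstruction setting, each sampled state would additionally be corrected with the current observation (e.g. through the GRU cell mentioned above) before decoding.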
Fig N. Model schematics for state estimation's implemented dynamics functions.
Mean Squared Error (MSE): A common metric in video and image tasks, computed as the per-frame average of the individual pixel errors. While many papers use only plots of frame MSE over time as an evaluation metric, this is insufficient for comparing models, especially when the dataset contains a small object for reconstruction[4]. The issue is especially prominent in system identification tasks, where a model that fails to predict long-term may end up with a lower average MSE than a model that generates better sequences but is slightly off in its object placement.
Fig N. Per-Frame MSE Equation.
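A small NumPy sketch of per-frame MSE, together with a toy case of the small-object problem described above. The frame size, object size, and the make_frame helper are hypothetical choices for illustration.

```python
import numpy as np

def per_frame_mse(preds, targets):
    """preds, targets: (batch, time, height, width). Returns (batch, time)."""
    return ((preds - targets) ** 2).mean(axis=(2, 3))

def make_frame(row, col, size=32, obj=3):
    """Hypothetical frame: a small square object on an empty background."""
    frame = np.zeros((size, size))
    frame[row:row + obj, col:col + obj] = 1.0
    return frame

target = make_frame(10, 10)[None, None]      # shape (1, 1, 32, 32)
shifted = make_frame(10, 13)[None, None]     # object generated, but 3 pixels off
blank = np.zeros_like(target)                # no object generated at all

mse_shifted = per_frame_mse(shifted, target)[0, 0]
mse_blank = per_frame_mse(blank, target)[0, 0]
```

Here the blank prediction scores 9/1024 while the slightly misplaced object scores 18/1024, so MSE alone would rank the model that generated nothing as the better one.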
Valid Prediction Time (VPT): Introduced in [4], the VPT metric improves on pure pixel-based MSE metrics for evaluating latent dynamics. For each predicted sequence, the per-pixel MSE is computed for each frame individually, and the first timestep at which the MSE surpasses a pre-defined epsilon is considered the 'valid prediction time.' The resulting mean over the samples is often normalized by the total number of prediction timesteps to get a percentage of valid predictions.
Fig N. Per-Sequence VPT Equation.
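One possible way to compute the normalized VPT in NumPy is sketched below. The epsilon value is an arbitrary example, and treating sequences that never exceed epsilon as valid for the whole horizon is an assumption of this sketch.

```python
import numpy as np

def valid_prediction_time(preds, targets, epsilon):
    """preds, targets: (batch, time, height, width). Returns normalized VPT in [0, 1]."""
    mse = ((preds - targets) ** 2).mean(axis=(2, 3))    # per-frame MSE, (batch, time)
    exceeded = mse > epsilon
    n_steps = mse.shape[1]
    # argmax returns the first True; sequences that never exceed epsilon get n_steps
    first_bad = np.where(exceeded.any(axis=1), exceeded.argmax(axis=1), n_steps)
    return first_bad.mean() / n_steps

targets = np.zeros((2, 10, 4, 4))
preds = np.zeros((2, 10, 4, 4))
preds[1, 4:] = 1.0       # the second sequence diverges from timestep 4 onward
vpt = valid_prediction_time(preds, targets, epsilon=0.01)
```

The first sequence is valid for all 10 steps and the second for 4, giving a mean of 7 steps, or 0.7 after normalization.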
Object Distance (DST): Another potential metric to support evaluation (useful in image-based physics forecasting tasks) is using the Euclidean distance between the estimated center of the predicted object and its ground truth center. Otsu's Thresholding method can be applied to grayscale output images to get binary predictions of each pixel and then the average pixel location of all the "active" pixels can be calculated. This approach can help alleviate the prior MSE issues of metric imbalance as the maximum Euclidean error of a given image space can be applied to model predictions that fail to have any pixels over Otsu's threshold.
Fig N. Per-Frame DST Equation.
where $R^N$ is the dimension of the output (e.g. number of image channels) and $s$, $\hat{s}$ are the subsets of "active" predicted pixels.
Valid Prediction Distance (VPD): Similar in spirit to how VPT leverages MSE, VPD is the minimum timestep in which the DST metric surpasses a pre-defined epsilon[29]. This is useful in tracking how long a model can generate an object in a physical system before either incorrect trajectories and/or error accumulation cause significant divergence.
Fig N. Per-Sequence VPD Equation.
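The DST and VPD descriptions above can be sketched in NumPy as follows. Otsu's threshold is written out by hand to keep the example self-contained; the image diagonal as the maximum error, the epsilon value, and the helper names are assumptions for illustration.

```python
import numpy as np

def otsu_threshold(image, n_bins=256):
    """Otsu's method: pick the bin center maximizing between-class variance."""
    hist, edges = np.histogram(image, bins=n_bins)
    centers = (edges[:-1] + edges[1:]) / 2
    w0 = np.cumsum(hist)                    # pixel count at or below each bin
    w1 = w0[-1] - w0                        # pixel count above each bin
    cum_sum = np.cumsum(hist * centers)
    mu0 = cum_sum / np.maximum(w0, 1)       # mean intensity of the lower class
    mu1 = (cum_sum[-1] - cum_sum) / np.maximum(w1, 1)
    return centers[np.argmax(w0 * w1 * (mu0 - mu1) ** 2)]

def object_center(frame):
    """Mean (row, col) of the active pixels, or None if nothing is active."""
    if frame.max() - frame.min() < 1e-8:    # constant frame: no object to find
        return None
    mask = frame > otsu_threshold(frame)
    return np.argwhere(mask).mean(axis=0)

def object_distance(pred_frame, true_frame):
    """Per-frame DST; falls back to the image diagonal when no object is found."""
    max_dist = float(np.hypot(*true_frame.shape))
    c_pred, c_true = object_center(pred_frame), object_center(true_frame)
    if c_pred is None or c_true is None:
        return max_dist
    return float(np.linalg.norm(c_pred - c_true))

def valid_prediction_distance(preds, targets, epsilon):
    """Per-sequence VPD: first timestep whose DST exceeds epsilon."""
    for t, (pred, true) in enumerate(zip(preds, targets)):
        if object_distance(pred, true) > epsilon:
            return t
    return len(preds)

def make_frame(row, col, size=32, obj=3):
    """Hypothetical frame: a small square object on an empty background."""
    frame = np.zeros((size, size))
    frame[row:row + obj, col:col + obj] = 1.0
    return frame

# The ground truth moves 1 pixel per step; the prediction moves 2 pixels per step
targets = np.stack([make_frame(10, 5 + t) for t in range(8)])
preds = np.stack([make_frame(10, 5 + 2 * t) for t in range(8)])
vpd = valid_prediction_distance(preds, targets, epsilon=3.5)
blank_dst = object_distance(np.zeros((32, 32)), targets[0])
```

In this toy sequence the distance between centers grows by one pixel per step, so the prediction stays valid until timestep 4, and a blank prediction receives the maximum distance instead of a misleadingly small error.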
R2 Score: For evaluating systems where the full underlying latent system is available and known (e.g. image translations of Hamiltonian systems), the goodness-of-fit score R2 can be used per dimension to show how well the latent system of the Neural SSM captures the dynamics in an interpretable way[1,3]. This is easiest to leverage in linear transition dynamics. Ref. [1], while containing linear transition dynamics, mentioned the possibility of non-linear regression via vanilla neural networks, though this may run into concerns of regressor capacity and data sizes. Additionally, incorporating metrics derived from latent disentanglement learning may provide stronger evaluation capabilities.
Fig N. DVBF Latent Space Visualization for R2 scores, sourced from [1,3].
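A minimal sketch of a per-dimension R2 evaluation, assuming a linear least-squares regression from the latent states to the known ground-truth states. The synthetic data, dimensions, and function name are hypothetical and only illustrate the computation.

```python
import numpy as np

def latent_r2(latents, states):
    """Fit states ~ latents linearly and return R^2 per ground-truth dimension.

    latents: (n_samples, latent_dim); states: (n_samples, state_dim).
    """
    design = np.hstack([latents, np.ones((len(latents), 1))])   # add a bias column
    coef, *_ = np.linalg.lstsq(design, states, rcond=None)
    residual = states - design @ coef
    ss_res = (residual ** 2).sum(axis=0)
    ss_tot = ((states - states.mean(axis=0)) ** 2).sum(axis=0)
    return 1.0 - ss_res / ss_tot

rng = np.random.default_rng(0)
states = rng.normal(size=(500, 2))      # e.g. angle and angular velocity
mixing = rng.normal(size=(2, 3))

# A latent space that linearly encodes the true state, plus a little noise
aligned = states @ mixing + 0.01 * rng.normal(size=(500, 3))
# A latent space unrelated to the true state
unrelated = rng.normal(size=(500, 3))

r2_aligned = latent_r2(aligned, states)
r2_unrelated = latent_r2(unrelated, states)
```

Scores near 1 indicate that the latent space captures the true state up to a linear map, while scores near 0 indicate that no linear readout exists.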
This section details some experiments that evaluate the fundamental aspects of Neural SSMs and the effects of the
framework decisions one can make. Trained model checkpoints and hyperparameter files are provided for each experiment
under experiments/model. Evaluations are done with the metrics discussed above, as well as visualizations of
animated trajectories over time and latent walk visualizations.
As is common in deep learning and variational inference tasks, the specific choices of hyper-parameters can have a significant impact on the resulting performance and generalization of the model. As such, we first perform a hyper-parameter tuning task for each model on a shared validation set to get its optimized hyper-parameter set. The optimal set for each model is then carried across the various tasks of similar complexity.
We provide a Ray[Tune] tuning script to handle training and formatting the PyTorch-Lightning outputs for each model,
found in tune.py. It automatically parallelizes across GPUs and has a convenient Tensorboard output
interface to compare the tuning runs. In order to run custom tuning tasks, simply create a local folder in the
repository root directory and rename the tune run "name" to redirect the output there. Please refer to Ray Tune's
relevant documentation for more information.
Here we report the results of tuning each of the models on the Hamiltonian physics dataset Pendulum. For each model, we highlight its best-performing hyperparameters with respect to the validation extrapolation MSE. Going forward, these hyperparameters will be used in experiments of similar complexity.
We test two environments for the Pendulum dataset, a fixed-point one-color pendulum and a multi-point multi-color pendulum set of increased complexity. As described in [4], each individual sequence is sampled from a uniform distribution over physical parameters like mass, gravity, and pivot length. We describe data generation above in the Data section.
[TO-DO: Click to show the results for Fixed-Point Pendulum]
Coming soon.
[TO-DO: Click to show the results for Multi-Point Pendulum tuning]
Coming soon.
Similar to above, we highlight the results and hyperparameters of each model for the Bouncing Ball dataset.
[TO-DO: Click to show the results for Bouncing Ball tuning]
Coming soon.
ODE Solvers: To measure the impact that ODE Solvers have on the optimized dynamics models, we performed an ablation on the available solvers that exist within the torchdiffeq library, including both fixed and adaptive solvers. We make note of their respective training times due to increased solution complexity and train each ODE solver over a variety of parameters depending on their type (e.g. step size or solution tolerances).
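To give a feel for what such an ablation measures, here is a toy NumPy comparison of two fixed-step solvers at two step sizes on a known pendulum ODE. The hand-written Euler and RK4 steppers are stand-ins for illustration and are not the torchdiffeq solvers; the pendulum parameters, step sizes, and the fine-step RK4 reference are assumptions of this sketch.

```python
import time
import numpy as np

def pendulum(state, g=9.81, length=1.0):
    """Ideal pendulum: state = (angle, angular velocity)."""
    theta, omega = state
    return np.array([omega, -(g / length) * np.sin(theta)])

def euler_step(f, y, h):
    return y + h * f(y)

def rk4_step(f, y, h):
    k1 = f(y)
    k2 = f(y + 0.5 * h * k1)
    k3 = f(y + 0.5 * h * k2)
    k4 = f(y + h * k3)
    return y + (h / 6.0) * (k1 + 2 * k2 + 2 * k3 + k4)

def integrate(step, f, y0, t_end, h):
    """Advance y0 to t_end with a fixed step size h."""
    y = np.asarray(y0, dtype=float)
    for _ in range(int(round(t_end / h))):
        y = step(f, y, h)
    return y

y0, t_end = [1.0, 0.0], 2.0
# A very fine RK4 solution serves as the reference trajectory endpoint
reference = integrate(rk4_step, pendulum, y0, t_end, h=1e-4)

results = {}
for name, step in [("euler", euler_step), ("rk4", rk4_step)]:
    for h in (0.1, 0.01):
        start = time.perf_counter()
        y = integrate(step, pendulum, y0, t_end, h)
        results[(name, h)] = {
            "error": float(np.linalg.norm(y - reference)),
            "seconds": time.perf_counter() - start,
        }
```

The results dictionary pairs each solver and step size with its final-state error and wall-clock time, which is the same accuracy versus cost trade-off the ablation above tracks for learned dynamics.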
[Click to show the results for the ODE Solver ablation]
Coming soon.
This section just consists of to-do's within the repo, contribution guidelines, and a section on how to find the references used throughout the repo.
- Generation lengths used in training (e.g. 1/2/3/5/10 frames)
- Fixed vs variable generation lengths
- z0 inference scheme (no overlap, overlap-by-one, full overlap)
- Use of ODE solvers (fixed, adaptive, tolerances)
- Different forms of learning rate schedulers
- Linear versus CNN decoder
- Activation functions in the latent dynamics function
- Implement latent walk visualizations against data-space observations (like in DVBF)
- Add guidelines for an Experiment section highlighting experiments performed in validating the models
- Add a section explaining Meta-Learning works in Neural SSMs
- Add a section explaining ODE Integrator considerations in Neural SSMs
Contributions are welcome and encouraged! If you have an implementation of a latent dynamics function you think would be relevant and add to the conversation, feel free to submit an Issue or PR and we can discuss its incorporation. Similarly, if you feel an area of the README is lacking or contains errors, please put up a README editing PR with your suggested updates. Even tackling items on the To-Do would be massively helpful!
- Maximilian Karl, Maximilian Soelch, Justin Bayer, and Patrick van der Smagt. Deep variational bayes filters: Unsupervised learning of state space models from raw data. In International Conference on Learning Representations, 2017.
- Marco Fraccaro, Simon Kamronn, Ulrich Paquet, and Ole Winther. A disentangled recognition and nonlinear dynamics model for unsupervised learning. In Advances in Neural Information Processing Systems, 2017.
- Alexej Klushyn, Richard Kurle, Maximilian Soelch, Botond Cseke, and Patrick van der Smagt. Latent matters: Learning deep state-space models. Advances in Neural Information Processing Systems, 34, 2021.
- Aleksandar Botev, Andrew Jaegle, Peter Wirnsberger, Daniel Hennes, and Irina Higgins. Which priors matter? benchmarking models for learning latent dynamics. In Advances in Neural Information Processing Systems, 2021.
- C. Yildiz, M. Heinonen, and H. Lahdesmaki. ODE2VAE: Deep generative second order odes with bayesian neural networks. In Neural Information Processing Systems, 2020.
- Batuhan Koyuncu. Analysis of ode2vae with examples. arXiv preprint arXiv:2108.04899, 2021.
- Rahul G. Krishnan, Uri Shalit, and David Sontag. Structured inference networks for nonlinear state space models. In Association for the Advancement of Artificial Intelligence, 2017.
- Daehoon Gwak, Gyuhyeon Sim, Michael Poli, Stefano Massaroli, Jaegul Choo, and Edward Choi. Neural ordinary differential equations for intervention modeling. arXiv preprint arXiv:2010.08304, 2020.
- Junyoung Chung, Kyle Kastner, Laurent Dinh, Kratarth Goel, Aaron Courville, and Yoshua Bengio. A recurrent latent variable model for sequential data. In Advances in Neural Information Processing Systems, 2015.
- Yulia Rubanova, Ricky T. Q. Chen, and David Duvenaud. Latent odes for irregularly-sampled time series. In Neural Information Processing Systems, 2019.
- Tsuyoshi Ishizone, Tomoyuki Higuchi, and Kazuyuki Nakamura. Ensemble kalman variational objectives: Nonlinear latent trajectory inference with a hybrid of variational inference and ensemble kalman filter. arXiv preprint arXiv:2010.08729, 2020.
- Justin Bayer, Maximilian Soelch, Atanas Mirchev, Baris Kayalibay, and Patrick van der Smagt. Mind the gap when conditioning amortised inference in sequential latent-variable models. arXiv preprint arXiv:2101.07046, 2021.
- Đorđe Miladinović, Muhammad Waleed Gondal, Bernhard Schölkopf, Joachim M Buhmann, and Stefan Bauer. Disentangled state space representations. arXiv preprint arXiv:1906.03255, 2019.
- Zeshan Hussain, Rahul G. Krishnan, and David Sontag. Neural pharmacodynamic state space modeling, 2021.
- Francesco Paolo Casale, Adrian Dalca, Luca Saglietti, Jennifer Listgarten, and Nicolo Fusi. Gaussian process prior variational autoencoders. Advances in neural information processing systems, 31, 2018.
- Yingzhen Li and Stephan Mandt. Disentangled sequential autoencoder. arXiv preprint arXiv:1803.02991, 2018.
- Patrick Kidger, James Morrill, James Foster, and Terry Lyons. Neural controlled differential equations for irregular time series. Advances in Neural Information Processing Systems, 33:6696-6707, 2020.
- Edward De Brouwer, Jaak Simm, Adam Arany, and Yves Moreau. Gru-ode-bayes: Continuous modeling of sporadically-observed time series. Advances in neural information processing systems, 32, 2019.
- Ruben Villegas, Jimei Yang, Yuliang Zou, Sungryull Sohn, Xunyu Lin, and Honglak Lee. Learning to generate long-term future via hierarchical prediction. In International Conference on Machine Learning, pages 3560–3569. PMLR, 2017.
- Xiajun Jiang, Ryan Missel, Maryam Toloubidokhti, Zhiyuan Li, Omar Gharbia, John L Sapp, and Linwei Wang. Label-free physics-informed image sequence reconstruction with disentangled spatial-temporal modeling. In International Conference on Medical Image Computing and Computer-Assisted Intervention, pages 361–371. Springer, 2021.
- Ricky TQ Chen, Yulia Rubanova, Jesse Bettencourt, and David K Duvenaud. Neural ordinary differential equations. Advances in neural information processing systems, 31, 2018.
- Junyoung Chung, Kyle Kastner, Laurent Dinh, Kratarth Goel, Aaron C Courville, and Yoshua Bengio. A recurrent latent variable model for sequential data. Advances in neural information processing systems, 28, 2015.
- Junbo Zhang, Yu Zheng, and Dekang Qi. Deep spatio-temporal residual networks for citywide crowd flows prediction. In Thirty-first AAAI conference on artificial intelligence, 2017.
- Yong Han, Shukang Wang, Yibin Ren, Cheng Wang, Peng Gao, and Ge Chen. Predicting station-level short-term passenger flow in a citywide metro network using spatiotemporal graph convolutional neural networks. ISPRS International Journal of Geo-Information, 8(6):243, 2019.
- Maryam Toloubidokhti, Ryan Missel, Xiajun Jiang, Niels Otani, and Linwei Wang. Neural state-space modeling with latent causal-effect disentanglement. In International Workshop on Machine Learning in Medical Imaging, 2022.
- Matthieu Kirchmeyer, Yuan Yin, Jérémie Donà, Nicolas Baskiotis, Alain Rakotomamonjy, and Patrick Gallinari. Generalizing to new physical systems via context-informed dynamics model. arXiv preprint arXiv:2202.01889, 2022.
- Rui Wang, Robin Walters, and Rose Yu. Meta-learning dynamics forecasting using task inference. arXiv preprint arXiv:2102.10271, 2021.
- Diederik P. Kingma and Max Welling. Auto-encoding variational bayes. arXiv preprint arXiv:1312.6114, 2013.
- Xiajun Jiang, Zhiyuan Li, Ryan Missel, Md Shakil Zaman, Brian Zenger, Wilson W Good, Rob S MacLeod, John L Sapp, and Linwei Wang. Few-shot generation of personalized neural surrogates for cardiac simulation via bayesian meta-learning. In International Conference on Medical Image Computing and Computer-Assisted Intervention, pages 46–56. Springer, 2022.
- Marco Forgione, Manas Mejari, and Dario Piga. Learning neural state-space models: do we need a state estimator? arXiv preprint arXiv:2206.12928, 2022.