probml/dynamax

Example of EKF with an input?

cbrummitt opened this issue · 4 comments

Does anyone have any tips on how to modify the example of the extended Kalman filter/smoother to have an input, such as $u(t) = 0.1 \cos(t)$?

I see that the type annotation of the dynamics_function argument of ParamsNLGSSM allows for it to map the state and input to the state: f(x, u) = x_next. Similarly, the emission_function argument can map the state and input, (x, u), to the measurement.

We should be able to replicate the current version of that example with the trivial input of $u(t) \equiv 0$. To that end, I tried to combine the random seeds rngs with a pre-computed array of inputs u = jnp.zeros((num_steps, 2)) and pass that into the xs argument of lax.scan, but I'm struggling to get the _step to work. Is this at all on the right track, to pass the input u by attaching it to the rngs in the xs argument of lax.scan?

Based on google/jax#763 I see I can pass a tuple to xs, like so:

_, (states, observations) = lax.scan(
    f=_step,
    init=params.initial_state,
    xs=(rngs, u)
)

Then the _step function just needs to unpack its second argument x as rng, u = x.
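As a minimal, self-contained sketch of that pattern (the toy update rule, the noise scale, and the names `u_t` and `final_state` are illustrative assumptions, not part of the dynamax example), `lax.scan` slices every array in the `xs` tuple along its leading axis and hands one slice of each to `_step`:

```python
import jax.numpy as jnp
import jax.random as jr
from jax import lax

num_steps = 5
rngs = jr.split(jr.PRNGKey(0), num_steps)  # one PRNG key per time step
u = jnp.ones((num_steps, 2))               # one input vector per time step

def _step(carry, x):
    # x is one slice of the xs tuple: (key for this step, input for this step)
    rng, u_t = x
    noise = 0.01 * jr.normal(rng, carry.shape)
    next_state = carry + u_t + noise       # toy update, only to show the plumbing
    return next_state, next_state

final_state, states = lax.scan(_step, jnp.zeros(2), (rngs, u))
print(states.shape)  # one row per time step
```

The only requirement is that every array in the tuple has the same leading dimension (`num_steps` here).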

To learn how to add exogenous, time-dependent inputs, I'm considering a model that has torque $u(t)$ applied:
$$\frac{d^2 \alpha}{dt^2} = - g \sin(\alpha) + w(t) + u(t).$$

I added u to the dynamics function:

    dynamics_function: Callable = lambda x, u: jnp.array(
        [x[0] + x[1] * dt + u[0],
         x[1] - g * jnp.sin(x[0]) * dt + u[1]
        ]
    )
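
For comparison, here is a sketch of what a forward-Euler discretization of the ODE above would look like, assuming $x_1 = \alpha$, $x_2 = d\alpha/dt$, step size $\Delta t$ (`dt` in the code), and discretized process noise $(q_{1,k}, q_{2,k}) \sim \mathcal{N}(0, Q)$:

$$\begin{aligned}
x_{1,k+1} &= x_{1,k} + x_{2,k} \, \Delta t + q_{1,k}, \\
x_{2,k+1} &= x_{2,k} - g \sin(x_{1,k}) \, \Delta t + u(t_k) \, \Delta t + q_{2,k}.
\end{aligned}$$

Under these assumptions the torque enters only the velocity update and is scaled by $\Delta t$. Adding `u[1]` directly, as in the function above, is equivalent to letting `u[1]` stand for $u(t_k) \, \Delta t$, so the effective torque is larger by a factor of $1 / \Delta t$ than the ODE suggests.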

I set $u(t) = 0.1 \sin(2 \pi t)$, as if the pendulum were on a boat. Below is a stand-alone script that incorporates the above into generating the true value of the angle:

%matplotlib inline
import matplotlib.pyplot as plt

import jax.numpy as jnp
import jax.random as jr
from jax import lax
from jaxtyping import Float, Array
from typing import Callable, NamedTuple

from dynamax.nonlinear_gaussian_ssm import ParamsNLGSSM, UKFHyperParams
from dynamax.nonlinear_gaussian_ssm import extended_kalman_smoother, unscented_kalman_smoother

dt = 0.0125
g = 9.8
q_c = 1
r = 0.3

class PendulumParamsWithInput(NamedTuple):
    initial_state: Float[Array, "state_dim"] = jnp.array(
        [jnp.pi / 2, 0])
    dynamics_function: Callable = lambda x, u: jnp.array(
        [
            x[0] + x[1] * dt + u[0],
            x[1] - g * jnp.sin(x[0]) * dt + u[1]
        ]
    )
    dynamics_covariance: Float[Array, "state_dim state_dim"] = jnp.array(
        [
            [q_c * dt**3 / 3, q_c * dt**2 / 2],
            [q_c * dt**2 / 2, q_c * dt]
        ]
    )
    # the emission is also perturbed by u[0]:
    emission_function: Callable = lambda x, u: jnp.array([jnp.sin(x[0])]) + u[0]
    emission_covariance: Float[Array, "emission_dim emission_dim"] = jnp.eye(1) * (r**2)

    # Torque applied as a function of the time-step index: u(t) = 0.1 sin(2πt), with t = time_step * dt
    torque: Callable = lambda time_step: 0.1 * jnp.sin(2 * jnp.pi * time_step * dt)

# Pendulum simulation (Särkkä Example 3.7)
def simulate_pendulum_with_input(
    params, key=0, num_steps=400
):
    if isinstance(key, int):
        key = jr.PRNGKey(key)
    # Unpack parameters
    M, N = params.initial_state.shape[0], params.emission_covariance.shape[0]
    f, h = params.dynamics_function, params.emission_function
    Q, R = params.dynamics_covariance, params.emission_covariance
    u = jnp.hstack(
        [
            jnp.zeros((num_steps, 1)),
            params.torque(jnp.arange(num_steps)).reshape(-1, 1),
        ]
    )
    
    def _step(carry, rng_u):
        state = carry
        rng, u = rng_u
        rng1, rng2 = jr.split(rng, 2)

        next_state = f(state, u) + jr.multivariate_normal(rng1, jnp.zeros(M), Q)
        obs = h(next_state, u) + jr.multivariate_normal(rng2, jnp.zeros(N), R)
        return next_state, (next_state, obs)

    rngs = jr.split(key, num_steps)
    _, (states, observations) = lax.scan(_step, params.initial_state, (rngs, u))
    return states, observations

params = PendulumParamsWithInput()
num_steps = 400
states, observations = simulate_pendulum_with_input(params, num_steps=num_steps)


time_grid = jnp.arange(num_steps) * dt
u = params.torque(jnp.arange(num_steps))
plt.plot(time_grid, states[:, 0], marker='o', markersize=2, label=r'angle $\alpha(t)$')
plt.plot(time_grid, u, label='torque $u(t)$')
plt.scatter(time_grid, observations[:, 0], facecolors='none', edgecolors='k', label='observations')
plt.xlabel('time $t$')
plt.legend()

I'm not sure why the measurements diverge from the true angle:
[image: plot of the angle α(t), the torque u(t), and the observations against time t]

Trying the extended_kalman_smoother on this data

ekf_params = ParamsNLGSSM(
    initial_mean=params.initial_state,
    initial_covariance=jnp.eye(states.shape[-1]) * 0.1,
    dynamics_function=params.dynamics_function,
    dynamics_covariance=params.dynamics_covariance,
    emission_function=params.emission_function,
    emission_covariance=params.emission_covariance,
)

ekf_posterior = extended_kalman_smoother(ekf_params, observations)

is giving an error I haven't yet figured out:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[22], line 10
      1 ekf_params = ParamsNLGSSM(
      2     initial_mean=params.initial_state,
      3     initial_covariance=jnp.eye(states.shape[-1]) * 0.1,
   (...)
      7     emission_covariance=params.emission_covariance,
      8 )
---> 10 ekf_posterior = extended_kalman_smoother(ekf_params, observations)

File ~/.pyenv/versions/3.10/envs/venv.data_assimilation/lib/python3.10/site-packages/dynamax/nonlinear_gaussian_ssm/inference_ekf.py:206, in extended_kalman_smoother(params, emissions, filtered_posterior, inputs)
    204 # Get filtered posterior
    205 if filtered_posterior is None:
--> 206     filtered_posterior = extended_kalman_filter(params, emissions, inputs=inputs)
    207 ll = filtered_posterior.marginal_loglik
    208 filtered_means = filtered_posterior.filtered_means

File ~/.pyenv/versions/3.10/envs/venv.data_assimilation/lib/python3.10/site-packages/dynamax/nonlinear_gaussian_ssm/inference_ekf.py:153, in extended_kalman_filter(params, emissions, num_iter, inputs, output_fields)
    151 # Run the extended Kalman filter
    152 carry = (0.0, params.initial_mean, params.initial_covariance)
--> 153 (ll, *_), outputs = lax.scan(_step, carry, jnp.arange(num_timesteps))
    154 outputs = {"marginal_loglik": ll, **outputs}
    155 posterior_filtered = PosteriorGSSMFiltered(
    156     **outputs,
    157 )

    [... skipping hidden 9 frame]

File ~/.pyenv/versions/3.10/envs/venv.data_assimilation/lib/python3.10/site-packages/dynamax/nonlinear_gaussian_ssm/inference_ekf.py:129, in extended_kalman_filter.<locals>._step(carry, t)
    126 y = emissions[t]
    128 # Update the log likelihood
--> 129 H_x = H(pred_mean, u)
    130 ll += MVN(h(pred_mean, u), H_x @ pred_cov @ H_x.T + R).log_prob(jnp.atleast_1d(y))
    132 # Condition on this emission

File ~/.pyenv/versions/3.10/envs/venv.data_assimilation/lib/python3.10/site-packages/dynamax/nonlinear_gaussian_ssm/inference_ekf.py:16, in <lambda>(x, y)
     14 # Helper functions
     15 _get_params = lambda x, dim, t: x[t] if x.ndim == dim + 1 else x
---> 16 _process_fn = lambda f, u: (lambda x, y: f(x)) if u is None else f
     17 _process_input = lambda x, y: jnp.zeros((y,1)) if x is None else x
     20 def _predict(m, P, f, F, Q, u):

    [... skipping hidden 5 frame]

File ~/.pyenv/versions/3.10/envs/venv.data_assimilation/lib/python3.10/site-packages/jax/_src/linear_util.py:191, in WrappedFun.call_wrapped(self, *args, **kwargs)
    188 gen = gen_static_args = out_store = None
    190 try:
--> 191   ans = self.f(*args, **dict(self.params, **kwargs))
    192 except:
    193   # Some transformations yield from inside context managers, so we have to
    194   # interrupt them before reraising the exception. Otherwise they will only
    195   # get garbage-collected at some later time, running their cleanup tasks
    196   # only after this exception is handled, which can corrupt the global
    197   # state.
    198   while stack:

TypeError: PendulumParamsWithInput.<lambda>() missing 1 required positional argument: 'u'

Would adding an example with $u(t)$ to the tutorial "Tracking a 1d pendulum using Extended / Unscented Kalman filter/ smoother" be of interest?

Hi! Did you manage to make it any further on this problem?

Hi! No, I haven't.

@cbrummitt
Apologies for the delayed response! A couple of points:

First:

> I'm not sure why the measurements diverge from the true angle:

Can you clarify what you mean by this? Note that the observations are not the angles themselves but the sine of the angles: the emission function is jnp.sin(x[0]) + u[0], and u[0] is zero in your inputs. If you apply sine to the true angles, the result matches the observations in your graph, up to noise:

plt.plot(time_grid, states[:, 0], marker='o', markersize=2, label=r'angle $\alpha(t)$')
plt.plot(time_grid, jnp.sin(states[:, 0]), label=r'$\sin(\alpha(t))$');

gets you:
[image: plot of the angle α(t) and sin(α(t)) against time]
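
The same point can be checked numerically. The following is a minimal sketch in which the angle trajectory is a hypothetical stand-in for states[:, 0] and r = 0.3 is taken from the script above; it compares the residual of the observations against sin(α) with the residual against α itself:

```python
import jax.numpy as jnp
import jax.random as jr

r = 0.3  # observation noise standard deviation, as in the script above

# Hypothetical stand-in for states[:, 0]: angles sweeping from pi/2 to -pi/2
angles = jnp.linspace(jnp.pi / 2, -jnp.pi / 2, 400)

# Observations generated the same way as in the simulation: sin(angle) + noise
noise = r * jr.normal(jr.PRNGKey(0), angles.shape)
observations = jnp.sin(angles) + noise

# Root-mean-square residual against sin(angle) and against the raw angle
rmse_vs_sine = jnp.sqrt(jnp.mean((observations - jnp.sin(angles)) ** 2))
rmse_vs_angle = jnp.sqrt(jnp.mean((observations - angles) ** 2))
print(rmse_vs_sine, rmse_vs_angle)
```

The residual against sin(α) should sit close to r, while the residual against α is larger because it also contains the gap between α and sin(α).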

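Second, to run extended_kalman_smoother (or extended_kalman_filter) with control inputs, pass them via the optional inputs argument. The TypeError in your traceback arises because inputs was left as None: in that case the _process_fn helper wraps dynamics_function and emission_function so that they are called with the state only, and your two-argument lambdas then fail with a missing 'u'. For example, with your model: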

ekf_params = ParamsNLGSSM(
    initial_mean=params.initial_state,
    initial_covariance=jnp.eye(states.shape[-1]) * 0.1,
    dynamics_function=params.dynamics_function,
    dynamics_covariance=params.dynamics_covariance,
    emission_function=params.emission_function,
    emission_covariance=params.emission_covariance,
)
u = jnp.hstack(
    [
        jnp.zeros((num_steps, 1)),
        params.torque(jnp.arange(num_steps)).reshape(-1, 1),
    ]
)
ekf_posterior = extended_kalman_smoother(ekf_params, observations, inputs=u)

would get you
[image: output]
Please let me know if anything isn't clear!
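
For anyone hitting the same TypeError, here is a small sketch of the mechanism. The `_process_fn` line is copied from the traceback above (inference_ekf.py, line 16); the function `f` and the arrays `x0` and `u0` are hypothetical stand-ins, and the rest of the library is not reproduced:

```python
import jax.numpy as jnp

# Copied from the traceback above (inference_ekf.py, line 16)
_process_fn = lambda f, u: (lambda x, y: f(x)) if u is None else f

f = lambda x, u: x + u  # hypothetical two-argument dynamics function
x0, u0 = jnp.zeros(2), jnp.ones(2)

# With inputs supplied, f is returned unchanged and receives (x, u)
f_with_inputs = _process_fn(f, u0)
print(f_with_inputs(x0, u0))

# With inputs=None, f is wrapped and called as f(x), which fails for a two-argument f
f_without_inputs = _process_fn(f, None)
try:
    f_without_inputs(x0, u0)
    error_message = None
except TypeError as err:
    error_message = str(err)
    print("TypeError:", error_message)
```

So a dynamics or emission function that takes `(x, u)` needs `inputs` to be passed, whereas a function that takes only `x` works when `inputs` is omitted.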