Example of EKF with an input?
cbrummitt opened this issue · 4 comments
Does anyone have any tips on how to modify the example of the extended Kalman filter/smoother to have an input, such as an external control signal?
I see that the type annotation of the `dynamics_function` argument of `ParamsNLGSSM` allows it to map the state and input to the next state: $f(x, u) = x_{\text{next}}$. Similarly, the `emission_function` argument can map the state and input, $(x, u)$, to the measurement.
We should be able to replicate the current version of that example with a trivial input: pair the `rngs` with a pre-computed array of inputs `u = jnp.zeros((num_steps, 2))` and pass both into the `xs` argument of `lax.scan`. However, I'm struggling to get `_step` to work. Is passing the input `u` alongside the `rngs` in the `xs` argument of `lax.scan` at all on the right track?
Based on google/jax#763, I see I can pass a tuple to `xs`, like so:
# Pass the per-step inputs alongside the PRNG keys as a tuple; `_step` then
# unpacks its second argument as `rng, u`.
_, (states, observations) = lax.scan(
    f=_step,
    init=params.initial_state,
    xs=(rngs, u)
)
Then the `_step` function just needs to unpack its second argument `x` as `rng, u = x`.
To learn how to add exogenous, time-dependent inputs, I'm considering a model with an applied torque.
I added `u` to the dynamics function:
# Explicit Euler step of the pendulum with the input u added:
# u[0] shifts the angle update and u[1] shifts the angular-velocity update.
dynamics_function: Callable = lambda x, u: jnp.array(
    [x[0] + x[1] * dt + u[0],
     x[1] - g * jnp.sin(x[0]) * dt + u[1]
    ]
)
I set up the full example as follows:
%matplotlib inline
import matplotlib.pyplot as plt
import jax.numpy as jnp
import jax.random as jr
from jax import lax
from jaxtyping import Float, Array
from typing import Callable, NamedTuple
from dynamax.nonlinear_gaussian_ssm import ParamsNLGSSM, UKFHyperParams
from dynamax.nonlinear_gaussian_ssm import extended_kalman_smoother, unscented_kalman_smoother
# Model constants shared by the parameter defaults and the plotting code.
dt: float = 0.0125  # time step; step index k corresponds to time k * dt
g: float = 9.8  # gravitational constant in the pendulum dynamics
q_c: int = 1  # process-noise intensity used to build the dynamics covariance
r: float = 0.3  # measurement-noise standard deviation (covariance is r**2)
class PendulumParamsWithInput(NamedTuple):
    """Parameters of a noisy pendulum whose dynamics and emissions take an input ``u``.

    The state is ``x = [angle, angular_velocity]``. Both model functions have the
    signature ``(x, u) -> array`` so they can be used with per-step control inputs.
    """

    # Start at angle pi/2 with zero angular velocity.
    initial_state: Float[Array, "state_dim"] = jnp.array(
        [jnp.pi / 2, 0])
    # Explicit Euler step of the pendulum ODE; u[0] is added to the angle and
    # u[1] (the torque channel) to the angular velocity, once per step.
    dynamics_function: Callable = lambda x, u: jnp.array(
        [
            x[0] + x[1] * dt + u[0],
            x[1] - g * jnp.sin(x[0]) * dt + u[1]
        ]
    )
    # Process-noise covariance for one step of length dt with intensity q_c:
    # [[q_c dt^3/3, q_c dt^2/2], [q_c dt^2/2, q_c dt]].
    dynamics_covariance: Float[Array, "state_dim state_dim"] = jnp.array(
        [
            [q_c * dt**3 / 3, q_c * dt**2 / 2],
            [q_c * dt**2 / 2, q_c * dt]
        ]
    )
    # The sensor measures sin(angle); the emission is also shifted by u[0].
    emission_function: Callable = lambda x, u: jnp.array([jnp.sin(x[0])]) + u[0]
    # Measurement-noise covariance: a square (emission_dim x emission_dim) matrix.
    emission_covariance: Float[Array, "emission_dim emission_dim"] = jnp.eye(1) * (r**2)
    # Torque at integer step k: 0.1 * sin(2*pi*t) evaluated at time t = k * dt.
    torque: Callable = lambda time_step: 0.1 * jnp.sin(2 * jnp.pi * time_step * dt)
# Pendulum simulation (Särkkä Example 3.7)
def simulate_pendulum_with_input(
    params, key=0, num_steps=400, inputs=None
):
    """Simulate the noisy pendulum driven by a sequence of per-step inputs.

    Args:
        params: Model parameters exposing ``initial_state``,
            ``dynamics_function`` ``f(x, u)``, ``dynamics_covariance``,
            ``emission_function`` ``h(x, u)``, ``emission_covariance`` and
            ``torque`` (the latter is only used when ``inputs`` is None).
        key: A JAX PRNG key, or an ``int`` seed that is turned into one.
        num_steps: Number of time steps to simulate.
        inputs: Optional array of inputs with one row per time step
            (leading dimension ``num_steps``). When omitted, a two-column input
            is built whose first column is zero and whose second column is
            ``params.torque(k)`` for ``k = 0, ..., num_steps - 1``. Passing the
            array explicitly lets callers reuse the exact same inputs elsewhere
            (e.g. for filtering).

    Returns:
        ``(states, observations)`` stacked along the leading (time) axis.
        ``states[k]`` is the state after the (k+1)-th transition, so
        ``states[0]`` is *not* ``params.initial_state``; ``observations[k]`` is
        the noisy emission of ``states[k]``.

    Raises:
        ValueError: If ``inputs`` does not have ``num_steps`` rows.
    """
    if isinstance(key, int):
        key = jr.PRNGKey(key)

    # Dimensions and model components.
    M, N = params.initial_state.shape[0], params.emission_covariance.shape[0]
    f, h = params.dynamics_function, params.emission_function
    Q, R = params.dynamics_covariance, params.emission_covariance

    if inputs is None:
        # Default inputs: nothing on channel 0, time-varying torque on channel 1.
        inputs = jnp.hstack(
            [
                jnp.zeros((num_steps, 1)),
                params.torque(jnp.arange(num_steps)).reshape(-1, 1),
            ]
        )
    else:
        inputs = jnp.asarray(inputs)
        # lax.scan needs one input row per PRNG key; fail early with a clear message.
        if inputs.ndim == 0 or inputs.shape[0] != num_steps:
            raise ValueError(
                f"inputs must have {num_steps} rows (one per time step), "
                f"got shape {inputs.shape}"
            )

    def _step(state, rng_and_input):
        # Each scan iteration receives one PRNG key and that step's input row.
        rng, u_t = rng_and_input
        rng_dyn, rng_obs = jr.split(rng, 2)
        next_state = f(state, u_t) + jr.multivariate_normal(rng_dyn, jnp.zeros(M), Q)
        obs = h(next_state, u_t) + jr.multivariate_normal(rng_obs, jnp.zeros(N), R)
        return next_state, (next_state, obs)

    rngs = jr.split(key, num_steps)
    _, (states, observations) = lax.scan(_step, params.initial_state, (rngs, inputs))
    return states, observations
# Simulate with the default parameters and the default seed (key=0).
params = PendulumParamsWithInput()
num_steps = 400
states, observations = simulate_pendulum_with_input(params, num_steps=num_steps)
# Time axis: step index k corresponds to time k * dt.
time_grid = jnp.arange(num_steps) * dt
# Torque component of the input, computed the same way the simulator builds it.
u = params.torque(jnp.arange(num_steps))
plt.plot(time_grid, states[:, 0], marker='o', markersize=2, label=r'angle $\alpha(t)$')
plt.plot(time_grid, u, label='torque $u(t)$')
# NOTE: observations are emission_function(x, u) + noise = sin(angle) + u[0] + noise
# (u[0] is zero here), not the angle itself, so they need not lie on the angle curve.
plt.scatter(time_grid, observations[:, 0], facecolors='none', edgecolors='k', label='observations')
plt.xlabel('time $t$')
plt.legend()
I'm not sure why the measurements diverge from the true angle:
Trying `extended_kalman_smoother` on this data
ekf_params = ParamsNLGSSM(
    initial_mean=params.initial_state,
    initial_covariance=jnp.eye(states.shape[-1]) * 0.1,
    dynamics_function=params.dynamics_function,
    dynamics_covariance=params.dynamics_covariance,
    emission_function=params.emission_function,
    emission_covariance=params.emission_covariance,
)
ekf_posterior = extended_kalman_smoother(ekf_params, observations)
# NOTE: no `inputs=` is passed, so dynamax wraps each model function as
# `lambda x, y: f(x)` (`_process_fn` in the traceback below). The two-argument
# lambdas `(x, u)` are then called with only `x`, raising
# "missing 1 required positional argument: 'u'". Passing `inputs=u` avoids this.
is giving an error I haven't yet figured out:
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
Cell In[22], line 10
1 ekf_params = ParamsNLGSSM(
2 initial_mean=params.initial_state,
3 initial_covariance=jnp.eye(states.shape[-1]) * 0.1,
(...)
7 emission_covariance=params.emission_covariance,
8 )
---> 10 ekf_posterior = extended_kalman_smoother(ekf_params, observations)
File ~/.pyenv/versions/3.10/envs/venv.data_assimilation/lib/python3.10/site-packages/dynamax/nonlinear_gaussian_ssm/inference_ekf.py:206, in extended_kalman_smoother(params, emissions, filtered_posterior, inputs)
204 # Get filtered posterior
205 if filtered_posterior is None:
--> 206 filtered_posterior = extended_kalman_filter(params, emissions, inputs=inputs)
207 ll = filtered_posterior.marginal_loglik
208 filtered_means = filtered_posterior.filtered_means
File ~/.pyenv/versions/3.10/envs/venv.data_assimilation/lib/python3.10/site-packages/dynamax/nonlinear_gaussian_ssm/inference_ekf.py:153, in extended_kalman_filter(params, emissions, num_iter, inputs, output_fields)
151 # Run the extended Kalman filter
152 carry = (0.0, params.initial_mean, params.initial_covariance)
--> 153 (ll, *_), outputs = lax.scan(_step, carry, jnp.arange(num_timesteps))
154 outputs = {"marginal_loglik": ll, **outputs}
155 posterior_filtered = PosteriorGSSMFiltered(
156 **outputs,
157 )
[... skipping hidden 9 frame]
File ~/.pyenv/versions/3.10/envs/venv.data_assimilation/lib/python3.10/site-packages/dynamax/nonlinear_gaussian_ssm/inference_ekf.py:129, in extended_kalman_filter.<locals>._step(carry, t)
126 y = emissions[t]
128 # Update the log likelihood
--> 129 H_x = H(pred_mean, u)
130 ll += MVN(h(pred_mean, u), H_x @ pred_cov @ H_x.T + R).log_prob(jnp.atleast_1d(y))
132 # Condition on this emission
File ~/.pyenv/versions/3.10/envs/venv.data_assimilation/lib/python3.10/site-packages/dynamax/nonlinear_gaussian_ssm/inference_ekf.py:16, in <lambda>(x, y)
14 # Helper functions
15 _get_params = lambda x, dim, t: x[t] if x.ndim == dim + 1 else x
---> 16 _process_fn = lambda f, u: (lambda x, y: f(x)) if u is None else f
17 _process_input = lambda x, y: jnp.zeros((y,1)) if x is None else x
20 def _predict(m, P, f, F, Q, u):
[... skipping hidden 5 frame]
File ~/.pyenv/versions/3.10/envs/venv.data_assimilation/lib/python3.10/site-packages/jax/_src/linear_util.py:191, in WrappedFun.call_wrapped(self, *args, **kwargs)
188 gen = gen_static_args = out_store = None
190 try:
--> 191 ans = self.f(*args, **dict(self.params, **kwargs))
192 except:
193 # Some transformations yield from inside context managers, so we have to
194 # interrupt them before reraising the exception. Otherwise they will only
195 # get garbage-collected at some later time, running their cleanup tasks
196 # only after this exception is handled, which can corrupt the global
197 # state.
198 while stack:
TypeError: PendulumParamsWithInput.<lambda>() missing 1 required positional argument: 'u'
Would it be worth adding an example with inputs to the documentation?
Hi! Did you manage to make it any further on this problem?
Hi! No, I haven't.
@cbrummitt
Apologies for the delayed response! A couple of points:
First:
I'm not sure why the measurements diverge from the true angle:
Can you clarify what you mean by this question? FYI, the observations are not the angles themselves but the sine transforms of the angles, and you can see that if you apply sine to the true angles, it matches what you see in the graph, modulo some noise:
# Overlay sin(angle): the emission function measures sin(angle), so this is the
# curve the noisy observations should follow.
plt.plot(time_grid, states[:, 0], marker='o', markersize=2, label=r'angle $\alpha(t)$')
plt.plot(time_grid, jnp.sin(states[:, 0]), label=r'$\sin(\alpha(t))$');
Second, to run `extended_kalman_smoother` (or `extended_kalman_filter`) with control inputs, simply pass them via the optional `inputs` argument. For example, in your code:
# Reuse the simulator's model functions; the inputs are supplied to the smoother
# explicitly through `inputs=` rather than through the parameters.
ekf_params = ParamsNLGSSM(
    initial_mean=params.initial_state,
    initial_covariance=0.1 * jnp.eye(states.shape[-1]),
    dynamics_function=params.dynamics_function,
    dynamics_covariance=params.dynamics_covariance,
    emission_function=params.emission_function,
    emission_covariance=params.emission_covariance,
)

# Rebuild the simulator's default input sequence, one row per step:
# column 0 is zero, column 1 is the torque at that step.
u = jnp.column_stack(
    [jnp.zeros(num_steps), params.torque(jnp.arange(num_steps))]
)

# NOTE(review): the simulator applies u[t] when moving from states[t-1] to
# states[t]; confirm dynamax applies inputs[t] with the same timing, otherwise
# the torque seen by the smoother is shifted by one step.
ekf_posterior = extended_kalman_smoother(ekf_params, observations, inputs=u)