
reformulating model I/O


This is a big one, and only an idea for now.
This is linked to: #68 #45

Option 1)
We could make model I/O completely flat: just one big set of inputs that contains everything. In that sense, even a state-space model is just a mathematical function with multiple inputs and outputs. We would then mark the inputs and outputs with different meanings, such as states, controls, etc. (see the sketch below).
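
A minimal sketch of what option 1 could look like (hypothetical names, not the existing lumos API): the model becomes a plain function of one flat array, and the meaning of each slice is carried as separate metadata:

import jax.numpy as jnp

# hypothetical metadata marking which slice of the flat vector means what
flat_meta = {
    "states": slice(0, 4),
    "inputs": slice(4, 6),
}

def flat_model(z):
    # z is one big flat vector; the slices are interpreted via flat_meta
    states = z[flat_meta["states"]]
    inputs = z[flat_meta["inputs"]]
    # toy dynamics, only to show the shape of the interface
    states_dot = -states + jnp.sum(inputs)
    outputs = jnp.tanh(states)
    # the return is also flat, with its own metadata kept elsewhere
    return jnp.concatenate([states_dot, outputs])

flat_model(jnp.zeros(6))  # -> array of shape (8,)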

Option 2)
Similar to option 1, but instead of making everything flat, we could use hierarchical inputs (dictionaries?), as sketched below.
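
A minimal sketch of option 2 (again with hypothetical names): the I/O becomes a (possibly nested) dictionary, which JAX already treats as a pytree, so transformations apply directly:

import jax
import jax.numpy as jnp

def dict_model(inputs):
    # inputs is a nested dict, e.g. {"states": {...}, "controls": {...}}
    v = inputs["states"]["v"]
    throttle = inputs["controls"]["throttle"]
    return {
        "states_dot": {"v": -0.1 * v + throttle},
        "outputs": {"power": v * throttle},
    }

example = {
    "states": {"v": jnp.array(10.0)},
    "controls": {"throttle": jnp.array(0.3)},
}
# jax transformations work directly on the dict structure
jac = jax.jacrev(dict_model)(example)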

Tested pure dictionary inputs and outputs, then took the 2nd derivative of the Lagrangian; see the code snippet below:

import jax.numpy as jnp
from jax import jit, hessian
from jax.config import config

# By default we use 64-bit, as over/underflow is quite likely to happen with 32-bit
# and 2nd-derivative autograd without a lot of careful management...
config.update("jax_enable_x64", True)

# For the initial phase, we also report verbosely when jax is recompiling.
config.update("jax_log_compiles", 1)

import lumos.numpy as lnp
from lumos.models import ModelMaker

model = ModelMaker.make_model_from_name("SimpleVehicle")
model.names = model.names._replace(
    con_outputs=(
        "slip_ratio_fl",
        "slip_ratio_fr",
        "slip_ratio_rl",
        "slip_ratio_rr",
        "slip_angle_fl",
        "slip_angle_fr",
        "slip_angle_rl",
        "slip_angle_rr",
    )
)

# Build constant-valued dictionaries to use as test inputs
states = model.make_const_dict("states", 1.0)
inputs = model.make_const_dict("inputs", 0.3)
states_dot = model.make_const_dict("states", 2.0)
con_outputs = model.make_const_dict("con_outputs", 1.2)


# Run the forward model once outside the jitted hessian (result not used below)
model_return = model.forward(states, inputs, 0.0)


def lagrange(states, inputs, states_dot, con_outputs):
    model_return = model.forward(states, inputs, 0.0)

    # build a Lagrangian with all multipliers equal to 1.0
    lagr = 0.0

    # dynamics defects: model states_dot vs. the given states_dot
    for name in states_dot:
        lagr += model_return.states_dot[name] - states_dot[name]

    # constraint output defects
    for name in con_outputs:
        lagr += model_return.outputs[name] - con_outputs[name]

    # implicit residuals
    for val in model_return.residuals.values():
        lagr += val

    return lagr


with lnp.use_backend("jax"):
    # hessian differentiates w.r.t. the first argument (states) by default
    hlagr = jit(hessian(lagrange))
    test_outputs = hlagr(states, inputs, states_dot, con_outputs)
print("done")

The resulting JIT compilation time is much longer than what we had before, suggesting that pure dict I/O is also not efficient for the Hessian in JAX (runtime was not tested). The test was also repeated with elements removed from the Lagrangian, but the compilation time showed no change (which is maybe a bit suspicious?).
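
For reference, a rough sketch of how the compile-time comparison could be made explicit (reusing lagrange, states, inputs, states_dot and con_outputs from above; flat_lagrange and time_first_call are illustrative helpers, not lumos API, and since hessian only differentiates w.r.t. the first argument by default, only states is flattened here to keep the comparison like-for-like):

import time

import jax
from jax import jit, hessian
from jax.flatten_util import ravel_pytree

# flatten the states dict into one array, keeping the inverse mapping
flat_states, unravel_states = ravel_pytree(states)

def flat_lagrange(z):
    return lagrange(unravel_states(z), inputs, states_dot, con_outputs)

def time_first_call(fn, *args):
    # the first call of a jitted function triggers compilation
    tic = time.perf_counter()
    out = fn(*args)
    jax.tree_util.tree_map(lambda x: x.block_until_ready(), out)
    return time.perf_counter() - tic

with lnp.use_backend("jax"):
    print("dict I/O:", time_first_call(jit(hessian(lagrange)), states, inputs, states_dot, con_outputs))
    print("flat I/O:", time_first_call(jit(hessian(flat_lagrange)), flat_states))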