
Using a neural network as the terminal cost with l4casadi + acados

Closed this issue · 3 comments


Not an issue but a question. I'm hoping to get confirmation that it is possible to use a neural network as the terminal cost for an MPC built with l4casadi + acados. This would be similar to what was done in the AC4MPC paper; however, I couldn't find their code.

I'd like to do something similar to `ocp.model.cost_expr_ext_cost_e = l4c_y_expr`.

Thanks in advance!


I set the external cost to an l4casadi model using the code below. Can someone confirm that this is the right approach?

import casadi as cs
import numpy as np
import torch
import l4casadi as l4c
from acados_template import AcadosOcpSolver, AcadosOcp, AcadosModel
import time


class MultiLayerPerceptron(torch.nn.Module):
    """Small tanh MLP mapping a 2-dimensional input to a single output.

    Architecture: ``Linear(2, 512)`` -> 2 x (``Linear(512, 512)`` + ``tanh``)
    -> ``Linear(512, 1)``. No activation is applied after the input layer.

    The output layer's weight and bias are zeroed at construction, so an
    untrained instance returns exactly 0 for every input.
    """

    def __init__(self):
        # Must run before any submodule is assigned: torch.nn.Module's
        # __setattr__ relies on the internal registries created here.
        super().__init__()

        self.input_layer = torch.nn.Linear(2, 512)

        self.hidden_layer = torch.nn.ModuleList(
            torch.nn.Linear(512, 512) for _ in range(2)
        )
        self.out_layer = torch.nn.Linear(512, 1)

        # Model is not trained -- setting output to zero.
        # no_grad() keeps these in-place writes out of the autograd graph.
        with torch.no_grad():
            self.out_layer.weight.zero_()
            self.out_layer.bias.zero_()

    def forward(self, x):
        """Evaluate the network.

        Args:
            x: tensor whose last dimension has size 2 (``Linear`` layers act
                on the last dimension).

        Returns:
            Tensor with the same leading dimensions as ``x`` and a last
            dimension of size 1.
        """
        x = self.input_layer(x)
        for layer in self.hidden_layer:
            x = torch.tanh(layer(x))
        return self.out_layer(x)

class DoubleIntegratorWithLearnedDynamics:
    """Double-integrator plant with an additive learned dynamics term.

    The state is ``x = [s, s_dot]`` and the control is the scalar ``u``.
    The explicit dynamics are the nominal double integrator
    (``d/dt s = s_dot``, ``d/dt s_dot = u``) plus ``learned_dyn(x)``.
    """

    def __init__(self, learned_dyn):
        """Store the learned dynamics term.

        Args:
            learned_dyn: callable taking the symbolic CasADi state ``x`` and
                returning an expression that is added to the nominal
                right-hand side (in this script, an ``l4c.L4CasADi`` wrapper
                around a PyTorch model).
        """
        self.learned_dyn = learned_dyn

    def model(self):
        """Build the symbolic model description.

        Returns:
            ``SimpleNamespace`` holding the symbolic state ``x``, its
            derivative ``xdot``, control ``u``, empty ``z``/``p``/``constraints``,
            the explicit dynamics ``f_expl``, the initial state ``x_start``
            (zeros) and the model ``name``.
        """
        s = cs.MX.sym('s', 1)
        s_dot = cs.MX.sym('s_dot', 1)
        s_dot_dot = cs.MX.sym('s_dot_dot', 1)
        u = cs.MX.sym('u', 1)
        x = cs.vertcat(s, s_dot)
        x_dot = cs.vertcat(s_dot, s_dot_dot)

        res_model = self.learned_dyn(x)

        # One entry per state: d/dt s = s_dot, d/dt s_dot = u; the learned
        # residual is added on top of this nominal model.
        f_expl = cs.vertcat(s_dot, u) + res_model

        x_start = np.zeros((2,))

        # store to struct
        model = cs.types.SimpleNamespace()
        model.x = x
        model.xdot = x_dot
        model.u = u
        model.z = cs.vertcat([])
        model.p = cs.vertcat([])
        model.f_expl = f_expl
        model.x_start = x_start
        model.constraints = cs.vertcat([])
        model.name = "wr"

        return model

class MPC:
    """Build an acados OCP (and its solver) for a symbolic model namespace.

    Two cost formulations are supported, selected by ``cost_type``:

    * ``'LINEAR_LS'``: linear least-squares tracking of a reference on the
      first state ``s``.
    * ``'EXTERNAL'``: a scalar cost computed by an l4casadi-wrapped PyTorch
      function, used for both the stage cost and the terminal cost.
    """

    def __init__(self, model, N, external_shared_lib_dir, external_shared_lib_name,
                 cost_type='EXTERNAL'):
        """
        Args:
            model: namespace with ``x``, ``xdot``, ``u``, ``p``, ``f_expl``,
                ``x_start`` and ``name`` (as built by
                ``DoubleIntegratorWithLearnedDynamics.model()``).
            N: number of shooting intervals.
            external_shared_lib_dir: passed to acados as
                ``model_external_shared_lib_dir`` (presumably where the
                l4casadi-generated libraries live -- confirm with caller).
            external_shared_lib_name: passed to acados as
                ``model_external_shared_lib_name``.
            cost_type: ``'LINEAR_LS'`` or ``'EXTERNAL'``.

        Raises:
            ValueError: if ``cost_type`` is not one of the supported values.
        """
        if cost_type not in ('LINEAR_LS', 'EXTERNAL'):
            raise ValueError(
                f"cost_type must be 'LINEAR_LS' or 'EXTERNAL', got {cost_type!r}"
            )
        self.N = N
        self.model = model
        self.external_shared_lib_dir = external_shared_lib_dir
        self.external_shared_lib_name = external_shared_lib_name
        self.cost_type = cost_type

    def solver(self):
        """Generate, compile and return an ``AcadosOcpSolver`` for ``self.ocp()``."""
        return AcadosOcpSolver(self.ocp())

    def ocp(self):
        """Formulate the optimal control problem.

        Returns:
            AcadosOcp with horizon ``tf = 1``, ``N`` stages, input bounds
            ``-10 <= u <= 10`` and an SQP-RTI / full-condensing HPIPM setup.
        """
        model = self.model

        t_horizon = 1.
        N = self.N

        # Get model
        model_ac = self.acados_model(model=model)
        model_ac.p = model.p

        # Dimensions
        nx = 2
        nu = 1

        # Create OCP object to formulate the optimization
        ocp = AcadosOcp()
        ocp.model = model_ac
        ocp.dims.N = N
        ocp.dims.nx = nx
        ocp.dims.nu = nu
        ocp.solver_options.tf = t_horizon

        if self.cost_type == 'LINEAR_LS':
            # Residual y = Vx @ x + Vu @ u selects s only.
            ny = 1
            ocp.cost.cost_type = 'LINEAR_LS'
            ocp.cost.cost_type_e = 'LINEAR_LS'

            ocp.cost.W = np.array([[1.]])
            ocp.cost.Vx = np.zeros((ny, nx))
            ocp.cost.Vx[0, 0] = 1.
            ocp.cost.Vu = np.zeros((ny, nu))
            ocp.cost.Vz = np.array([[]])
            ocp.cost.Vx_e = np.zeros((ny, nx))

            # Terminal residual is defined but weighted by zero.
            ocp.cost.W_e = np.array([[0.]])
            ocp.cost.yref_e = np.array([0.])

            # Initial reference trajectory (will be overwritten)
            ocp.cost.yref = np.zeros(ny)

            l4c_cost = None
        else:
            ocp.cost.cost_type = 'EXTERNAL'
            ocp.cost.cost_type_e = 'EXTERNAL'

            x = ocp.model.x

            # An EXTERNAL cost must be a scalar expression (acados builds its
            # gradient and Hessian), so the torch function reduces the state
            # to a single value. Placeholder: squared norm of the state --
            # swap in a trained single-output network to learn the cost.
            l4c_cost = l4c.L4CasADi(
                lambda z: (z ** 2).sum().reshape((1, 1)), name='x_expr'
            )
            ocp.model.cost_expr_ext_cost = l4c_cost(x)
            ocp.model.cost_expr_ext_cost_e = l4c_cost(x)

        # Initial state (will be overwritten)
        ocp.constraints.x0 = model.x_start

        # Set constraints
        a_max = 10
        ocp.constraints.lbu = np.array([-a_max])
        ocp.constraints.ubu = np.array([a_max])
        ocp.constraints.idxbu = np.array([0])

        # Solver options
        ocp.solver_options.qp_solver = 'FULL_CONDENSING_HPIPM'
        ocp.solver_options.hessian_approx = 'GAUSS_NEWTON'
        ocp.solver_options.integrator_type = 'ERK'
        ocp.solver_options.nlp_solver_type = 'SQP_RTI'
        ocp.solver_options.model_external_shared_lib_dir = self.external_shared_lib_dir

        lib_name = self.external_shared_lib_name
        if l4c_cost is not None:
            # The cost's l4casadi library is appended with a ``-l`` prefix so
            # it is linked together with the dynamics library.
            lib_name += ' -l' + l4c_cost.name
        ocp.solver_options.model_external_shared_lib_name = lib_name

        return ocp

    def acados_model(self, model):
        """Copy a model namespace into an ``AcadosModel``.

        Both the explicit (``f_expl``) and implicit (``xdot - f_expl``)
        dynamics forms are filled in.
        """
        model_ac = AcadosModel()
        model_ac.f_impl_expr = model.xdot - model.f_expl
        model_ac.f_expl_expr = model.f_expl
        model_ac.x = model.x
        model_ac.xdot = model.xdot
        model_ac.u = model.u
        model_ac.name = model.name
        return model_ac

def run():
    """Run a 50-step closed-loop MPC simulation and print the mean step time.

    An MLP wrapped with l4casadi supplies the learned dynamics term. After
    each solve, the state predicted at stage 1 becomes the next initial state.
    """
    N = 10
    learned_dyn_model = l4c.L4CasADi(MultiLayerPerceptron(), model_expects_batch_dim=True, name='learned_dyn')

    model = DoubleIntegratorWithLearnedDynamics(learned_dyn_model)
    mpc = MPC(model=model.model(), N=N,
              external_shared_lib_dir=learned_dyn_model.shared_lib_dir,
              external_shared_lib_name=learned_dyn_model.name)
    solver = mpc.solver()

    # "yref" only exists for least-squares costs; an EXTERNAL cost has no
    # reference to set.
    tracks_reference = solver.acados_ocp.cost.cost_type != 'EXTERNAL'

    ts = 1. / N
    xt = np.array([1., 0.])
    opt_times = []

    for step in range(50):
        start = time.time()
        if tracks_reference:
            t_grid = np.linspace(step * ts, step * ts + 1., N)
            yref = np.sin(0.5 * t_grid + np.pi / 2)
            for stage, ref in enumerate(yref):
                solver.set(stage, "yref", ref)
        # Pin the initial state by collapsing the stage-0 bounds onto xt.
        solver.set(0, "lbx", xt)
        solver.set(0, "ubx", xt)
        status = solver.solve()
        if status != 0:
            print(f'acados returned status {status} at step {step}')
        xt = solver.get(1, "x")

        opt_times.append(time.time() - start)

    print(f'Mean iteration time: {1000*np.mean(opt_times):.1f}ms -- {1/np.mean(opt_times):.0f}Hz)')

if __name__ == '__main__':
    run()


Thanks for reaching out! Could you clarify what your problem is? Is the code giving you an error, or are the results not what you expect?


Hey Tim,

The solver was throwing an error, but I wasn't sure whether it was caused by the additional PyTorch model or by other parts of my implementation. It turns out it was a bug on my end. Thanks, though!