numagic/lumos

improve ParameterTree handling


TODO:

  • general documentation improvements
  • make the flat-dict version accessible to the user, together with methods to flatten and unflatten it? One use case: if we export a model for Simulink, it only accepts arrays as parameter inputs. How does the user know how to flatten a large parameter tree into a flat array?
  • in some places we call apply_and_forward_with_arrays during JAX tracing or CasADi code-gen. This sets the model's parameters to JAX tracers or CasADi symbolics, which are undesired artifacts, and we should make sure the parameters are set back to their numeric values afterwards. Otherwise, calling the model again won't work. We usually don't notice this in an OCP because the user doesn't call the model directly there; only the jitted or compiled functions are called.
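For the flattening question, a minimal sketch of a deterministic flatten/unflatten pair, assuming the parameter tree can be represented as a nested dict of string keys to scalars or arrays. The helpers `flatten_tree` and `unflatten_tree` are hypothetical and not part of lumos; the point is that the flat vector comes with a `layout` (dotted path and shape per leaf, in sorted-key order) that could be exported alongside a Simulink model so the user knows which entry is which.

```python
import numpy as np


def flatten_tree(tree):
    """Flatten a nested dict of scalars/arrays into (flat_vector, layout).

    Hypothetical helper. Leaves are visited in sorted-key order so the layout is
    deterministic; keys are assumed to be strings without '.' in them.
    """
    chunks, layout = [], []

    def _walk(node, path):
        if isinstance(node, dict):
            for key in sorted(node):
                _walk(node[key], f"{path}.{key}" if path else key)
        else:
            arr = np.asarray(node, dtype=float)
            layout.append((path, arr.shape))
            chunks.append(arr.ravel())

    _walk(tree, "")
    flat = np.concatenate(chunks) if chunks else np.zeros(0)
    return flat, layout


def unflatten_tree(flat, layout):
    """Rebuild the nested dict from a flat vector and the layout from flatten_tree."""
    tree, offset = {}, 0
    for path, shape in layout:
        size = int(np.prod(shape))  # np.prod(()) == 1 for scalar leaves
        *parents, leaf = path.split(".")
        node = tree
        for key in parents:
            node = node.setdefault(key, {})
        # Copy so the rebuilt tree does not alias the caller's flat vector.
        node[leaf] = flat[offset:offset + size].reshape(shape).copy()
        offset += size
    if offset != flat.size:
        raise ValueError(f"layout covers {offset} entries, vector has {flat.size}")
    return tree


params = {"vehicle": {"mass": 1500.0, "tyre": {"mu": [1.0, 0.9]}}, "dt": 0.05}
flat, layout = flatten_tree(params)
# layout -> [("dt", ()), ("vehicle.mass", ()), ("vehicle.tyre.mu", (2,))]
# flat   -> [0.05, 1500.0, 1.0, 0.9]
rebuilt = unflatten_tree(flat, layout)
```

If JAX is already a dependency, `jax.flatten_util.ravel_pytree` does a similar job for pytrees, but its ordering follows JAX's pytree conventions rather than an explicit, human-readable layout.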

To ensure the model parameters are set back to numpy types, we should add a test:

  • in the OCP test
  • create compiled model calls (or maybe do this as a model test instead), using both JAX and CasADi
  • after compilation, check that the model parameters are still numpy types
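A minimal sketch of both the fix and the test, assuming a guard that snapshots the numeric parameters before tracing and restores them in a `finally` block. Everything here is hypothetical: `ToyModel` stands in for a lumos model, `FakeSymbol` stands in for a JAX tracer or CasADi symbol, and `fake_codegen` mimics a call like apply_and_forward_with_arrays leaking symbols into the model. A real test would replace `fake_codegen` with an actual `jax.jit` or CasADi code-gen call and walk the full ParameterTree.

```python
import contextlib
import copy

import numpy as np


class ToyModel:
    """Hypothetical stand-in for a lumos model: it only holds a dict of parameters."""

    def __init__(self, params):
        self.params = dict(params)

    def set_params(self, params):
        self.params = dict(params)


class FakeSymbol:
    """Stand-in for a JAX tracer or CasADi symbol, so the sketch runs without either."""

    def __init__(self, name):
        self.name = name


def fake_codegen(model, fail=False):
    """Mimics a trace/code-gen step that leaks symbols into the model parameters."""
    model.set_params({name: FakeSymbol(name) for name in model.params})
    if fail:
        raise RuntimeError("code-gen failed")
    return "compiled_fn"


@contextlib.contextmanager
def restore_params(model):
    """Snapshot the numeric parameters and put them back, even if tracing raises."""
    snapshot = copy.deepcopy(model.params)
    try:
        yield model
    finally:
        model.set_params(snapshot)


def non_numeric_params(model):
    """Names of parameters that are no longer numpy arrays or numpy scalars."""
    return [
        name
        for name, value in model.params.items()
        if not isinstance(value, (np.ndarray, np.generic))
    ]


def make_model():
    return ToyModel({"mass": np.float64(1500.0), "mu": np.array([1.0, 0.9])})


# Without the guard, the symbolic artifacts stay in the model.
leaky = make_model()
fake_codegen(leaky)

# With the guard, numeric values come back after successful code-gen...
guarded = make_model()
with restore_params(guarded):
    fake_codegen(guarded)

# ...and also after code-gen that raises.
guarded_after_error = make_model()
try:
    with restore_params(guarded_after_error):
        fake_codegen(guarded_after_error, fail=True)
except RuntimeError:
    pass
```

The test then reduces to asserting that `non_numeric_params(model)` is empty after compilation. Putting the `try/finally` inside the function that does the tracing (rather than relying on every caller to use the guard) would be one way to make the fix hard to forget.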

model_params.set_param(...) modifies, in place, the parameters of the model that model_params came from (probably already dangerous behavior?! Maybe we should force the use of model.set_recursive_params afterwards?)
model_params.replace_subtree(...) does not modify the parameters used in the model in place!
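A toy sketch contrasting the two behaviors, assuming the inconsistency comes from the tree holding a reference to the model's own parameter storage. All names (`ToyParameterTree`, `set_param_inplace`, `with_param`, `get_params`, `set_params`) are hypothetical and do not reflect lumos' actual implementation. `set_param_inplace` mimics what set_param is described as doing; `with_param` mimics a copy-returning update, which is what replace_subtree might be if it returns a new tree; `set_params` plays the role of an explicit apply step such as model.set_recursive_params.

```python
import copy


class ToyParameterTree:
    """Hypothetical toy, NOT lumos' ParameterTree: contrasts two update styles."""

    def __init__(self, data):
        # Holds a reference, so it aliases whatever dict it was built from.
        self._data = data

    def set_param_inplace(self, name, value):
        # Mutates the shared dict: the owning model changes silently.
        self._data[name] = value

    def with_param(self, name, value):
        # Returns an updated deep copy; the original tree and model are untouched.
        new_data = copy.deepcopy(self._data)
        new_data[name] = value
        return ToyParameterTree(new_data)

    def to_dict(self):
        return copy.deepcopy(self._data)


class ToyModel:
    """Hypothetical model that hands out a view of its own parameter dict."""

    def __init__(self, params):
        self._params = params

    def get_params(self):
        return ToyParameterTree(self._params)

    def set_params(self, tree):
        # The only intended way to change the model's parameters.
        self._params = tree.to_dict()

    def param(self, name):
        return self._params[name]


# In-place style: the model changes without any explicit call on the model.
aliased = ToyModel({"mass": 1500.0})
aliased.get_params().set_param_inplace("mass", 1600.0)
mass_after_inplace = aliased.param("mass")  # 1600.0

# Copy-then-apply style: nothing changes until set_params is called.
explicit = ToyModel({"mass": 1500.0})
new_tree = explicit.get_params().with_param("mass", 1600.0)
mass_before_apply = explicit.param("mass")  # 1500.0
explicit.set_params(new_tree)
mass_after_apply = explicit.param("mass")  # 1600.0
```

Making every tree update return a copy and requiring an explicit apply step would make the two methods behave consistently, at the cost of an extra call for the user.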