google/jaxopt

pytrees bounds for `jaxopt.ScipyBoundedMinimize`

lgrcia opened this issue · 1 comment

When using `jaxopt.ScipyBoundedMinimize` with the initial parameters specified as a dict, how do I specify the bounds using a matching dict-like structure?

The following fails:

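```python
import jaxopt

init = {"a": 0.1, "b": 0.2}

# Placeholder objective (hypothetical): the original `fun` was not shown.
def fun(params):
    return (params["a"] - 0.5) ** 2 + (params["b"] - 0.5) ** 2

solver = jaxopt.ScipyBoundedMinimize(fun=fun)
result = solver.run(
    init,
    bounds=(
        {"a": 0.0, "b": 0.0},  # lower bounds, same keys as init
        {"a": 1.0, "b": 1.0},  # upper bounds, same keys as init
    ),
)
```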

The documentation says: `bounds`: an optional tuple `(lb, ub)` of pytrees with structure identical to `init_params`, representing box constraints. So this is probably my misunderstanding of pytree structure rather than a bug. Thanks for your help!
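The documented requirement can be checked directly with `jax.tree_util`. The sketch below uses a hypothetical helper, `bounds_match_init`, that compares the pytree structure of each bound against `init_params`. It also compares leaf shapes, which the documentation does not mention, so treat that extra check as a conservative assumption rather than a description of jaxopt's own validation.

```python
import jax
import jax.numpy as jnp


def bounds_match_init(init_params, bounds):
    """Hypothetical helper: check that (lb, ub) mirror init_params.

    Returns True when both bounds have the same pytree structure as
    init_params and every leaf has the same shape as the matching init leaf.
    """
    lower, upper = bounds
    init_def = jax.tree_util.tree_structure(init_params)
    init_shapes = [jnp.shape(x) for x in jax.tree_util.tree_leaves(init_params)]
    for bound in (lower, upper):
        # Same containers and keys as init_params?
        if jax.tree_util.tree_structure(bound) != init_def:
            return False
        # Same shape for each corresponding leaf? (extra, assumed check)
        shapes = [jnp.shape(x) for x in jax.tree_util.tree_leaves(bound)]
        if shapes != init_shapes:
            return False
    return True


init = {"a": 0.1, "b": 0.2}
bounds = ({"a": 0.0, "b": 0.0}, {"a": 1.0, "b": 1.0})
print(bounds_match_init(init, bounds))  # True

# Same dict keys, but "a" is now a length-3 array while its bounds stay scalars.
init_vec = {"a": jnp.array([0.1, 0.2, 0.3]), "b": 0.2}
print(bounds_match_init(init_vec, bounds))  # False
```

The second call shows that two dicts can share keys yet still disagree leaf by leaf, which is easy to miss when the model parameters are arrays rather than scalars.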

I think this works as expected; the issue was in the model parameters!
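One way to keep bounds consistent with the model parameters, sketched here under the assumption that every parameter shares the same box [0, 1], is to derive both bounds from `init` with `jax.tree_util.tree_map`, so each bound inherits the container structure and leaf shapes of the parameters. The parameter values below are hypothetical.

```python
import jax
import jax.numpy as jnp

# Hypothetical model parameters mixing a vector leaf and a scalar leaf.
init = {"a": jnp.array([0.1, 0.2, 0.3]), "b": jnp.asarray(0.2)}

# Build lower/upper bounds with tree_map so every leaf automatically has
# the same container structure, shape and dtype as the matching init leaf.
lower = jax.tree_util.tree_map(lambda x: jnp.full_like(x, 0.0), init)
upper = jax.tree_util.tree_map(lambda x: jnp.full_like(x, 1.0), init)
bounds = (lower, upper)

print(jax.tree_util.tree_structure(lower) == jax.tree_util.tree_structure(init))  # True
print(lower["a"].shape)  # (3,)
```

The resulting `bounds` tuple can then be passed as `solver.run(init, bounds=bounds)`, in the same way as the dict literals in the snippet above.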