pytrees bounds for `jaxopt.ScipyBoundedMinimize`
lgrcia opened this issue · 1 comment
lgrcia commented
When using `jaxopt.ScipyBoundedMinimize`, if the initial parameters are specified as a `dict`, how do I specify the `bounds` using a dict-like structure?
The following fails:
    import jaxopt

    # `fun` is the objective function, defined elsewhere.
    init = {"a": 0.1, "b": 0.2}
    solver = jaxopt.ScipyBoundedMinimize(fun=fun)
    result = solver.run(
        init,
        bounds=(
            {"a": 0.0, "b": 0.0},  # lower bounds
            {"a": 1.0, "b": 1.0},  # upper bounds
        ),
    )
The documentation says: "`bounds`: an optional tuple `(lb, ub)` of pytrees with structure identical to `init_params`, representing box constraints", so it's probably my misunderstanding of pytree structure rather than a bug. Thanks for your help!
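One way to sanity-check the "structure identical to `init_params`" requirement is to compare tree structures directly with `jax.tree_util.tree_structure`. This is only a sketch: the `bad_lb` dict is a hypothetical counterexample, not something from the report above.

```python
import jax

init = {"a": 0.1, "b": 0.2}
lb = {"a": 0.0, "b": 0.0}
ub = {"a": 1.0, "b": 1.0}

# The PyTreeDef records the container type and its keys, not the leaf values.
init_def = jax.tree_util.tree_structure(init)
same_structure = (
    jax.tree_util.tree_structure(lb) == init_def
    and jax.tree_util.tree_structure(ub) == init_def
)

# Hypothetical mismatch: a bounds dict with a missing key has a different structure.
bad_lb = {"a": 0.0}
bad_structure = jax.tree_util.tree_structure(bad_lb) == init_def

print(same_structure, bad_structure)
```

If both comparisons for `lb` and `ub` hold, the bounds have the same pytree structure as the initial parameters, and any failure is more likely to come from elsewhere (for example, the objective function).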
lgrcia commented
I think this works as expected, and the issue was in the model parameters!