Error when taking gradient wrt parameters in BoxOSQP
deasmhumhna opened this issue · 0 comments
deasmhumhna commented
import jax
import jax.numpy as jnp
import jaxopt

# Minimize 0.5 * z @ z + a @ z - 1 subject to 0 <= z <= 1 (A = identity).
fun = lambda z, params_obj: 0.5 * z @ z + params_obj[0] @ z - 1
matvec_A = lambda params_A, z: (z,)
solver = jaxopt.BoxOSQP(matvec_A=matvec_A, fun=fun, tol=1e-5)

def test_loss(a):
    params_obj = (jnp.atleast_1d(a),)
    l = (jnp.array([0.]),)
    u = (jnp.array([1.]),)
    init_params = solver.init_params(
        init_x=jnp.array([0.]),
        params_obj=params_obj,
        params_eq=None,
        params_ineq=(l, u),
    )
    sol = solver.run(
        init_params=init_params,
        params_obj=params_obj,
        params_eq=None,
        params_ineq=(l, u),
    )
    zopt = sol.params.primal[-1][-1]
    return fun(zopt, params_obj)

print(test_loss(jnp.array([-0.5])))  # -1.125
print(jax.grad(test_loss)(jnp.array([1.])))  # error
Relevant traceback:
JaxStackTraceBeforeTransformation: TypeError: unsupported operand type(s) for @: 'tuple' and 'tuple'
The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.
--------------------
The above exception was the direct cause of the following exception:
TypeError Traceback (most recent call last)
<ipython-input-24-f2a2a3471989> in <cell line: 28>()
26
27 print(test_loss(jnp.array([-0.5]))) # -1.0625
---> 28 print(jax.grad(test_loss)(jnp.array([1.]))) # error
[... skipping hidden 12 frame]
/usr/local/lib/python3.10/dist-packages/jaxopt/_src/implicit_diff.py in solver_fun_bwd(tup, cotangent)
234
235 # Compute VJPs w.r.t. args.
--> 236 vjps = root_vjp(optimality_fun=optimality_fun, sol=sol,
237 args=ba_args[1:], cotangent=cotangent, solve=solve)
238 # Prepend None as the vjp for init_params.
/usr/local/lib/python3.10/dist-packages/jaxopt/_src/implicit_diff.py in root_vjp(optimality_fun, sol, args, cotangent, solve)
58 return optimality_fun(sol, *args)
59
---> 60 _, vjp_fun_sol = jax.vjp(fun_sol, sol)
61
62 # Compute the multiplication A^T u = (u^T A)^T.
[... skipping hidden 7 frame]
/usr/local/lib/python3.10/dist-packages/jaxopt/_src/implicit_diff.py in fun_sol(sol)
56 def fun_sol(sol):
57 # We close over the arguments.
---> 58 return optimality_fun(sol, *args)
59
60 _, vjp_fun_sol = jax.vjp(fun_sol, sol)
/usr/local/lib/python3.10/dist-packages/jaxopt/_src/implicit_diff.py in optimality_fun(params, params_obj, params_eq, params_ineq)
352 primal_var, eq_dual_var, ineq_dual_var = params
353
--> 354 stationarity = grad_fun(primal_var, params_obj)
355
356 if eq_dual_var is not None:
[... skipping hidden 10 frame]
<ipython-input-24-f2a2a3471989> in <lambda>(z, params_obj)
1 import jaxopt
2
----> 3 fun = lambda z, params_obj: 0.5 * z @ z + params_obj[0] @ z - 1
4 matvec_A = lambda params_A, z: (z, )
5 solver = jaxopt.BoxOSQP(matvec_A=matvec_A, fun=fun, tol=1e-5)
TypeError: unsupported operand type(s) for @: 'tuple' and 'tuple'
Does optimality_fun/grad_fun not adapt the original function fun to handle tangents properly? From the traceback it looks like optimality_fun unpacks primal_var from params and passes it straight to grad_fun, so fun seems to receive a primal pytree rather than the flat array it was written for.
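A minimal sketch of what I suspect is happening (the (x, z) layout below is my guess from the traceback, not taken from the jaxopt source):

# Hypothetical repro of the suspected failure mode: jax.grad(fun) applied to
# the full primal pytree (x, (z,)) rather than x alone. fun's arithmetic only
# makes sense for a flat array, so it raises a TypeError inside the lambda.
grad_fun = jax.grad(fun)
primal_pair = (jnp.array([0.5]), (jnp.array([0.5]),))  # assumed (x, z) layout
grad_fun(primal_pair, (jnp.atleast_1d(-0.5),))         # TypeError inside fun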
I can successfully get the gradient using the (Q, c) and matvec_Q paths. I can write my actual function using either of these, but I imagine such a rewrite might be difficult for other operations, which I assume is the logic for including fun in the first place.
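For reference, here is a minimal sketch of the matvec_Q version that differentiates for me (reading the x block out of sol.params.primal[0] is my assumption about the KKTSolution layout):

# Same box-constrained QP, expressed as 0.5 * x @ (Q @ x) + c @ x with Q = I
# via matvec_Q instead of fun; the linear term c carries the parameter a.
# matvec_A is reused from the repro above.
matvec_Q = lambda params_Q, x: x  # Qx with Q = identity; params_Q is unused
solver_Q = jaxopt.BoxOSQP(matvec_Q=matvec_Q, matvec_A=matvec_A, tol=1e-5)

def test_loss_Q(a):
    c = jnp.atleast_1d(a)
    l = (jnp.array([0.]),)
    u = (jnp.array([1.]),)
    init_params = solver_Q.init_params(
        init_x=jnp.array([0.]),
        params_obj=(None, c),  # (params_Q, c); params_Q unused here
        params_eq=None,
        params_ineq=(l, u),
    )
    sol = solver_Q.run(
        init_params=init_params,
        params_obj=(None, c),
        params_eq=None,
        params_ineq=(l, u),
    )
    xopt = sol.params.primal[0]
    return 0.5 * xopt @ xopt + c @ xopt - 1

print(jax.grad(test_loss_Q)(jnp.array([1.])))  # gradient computes without error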