google/jaxopt

Error when taking gradient wrt parameters in BoxOSQP

deasmhumhna opened this issue

```python
import jax
import jax.numpy as jnp
import jaxopt

# Objective f(z) = 0.5 * z.z + c.z - 1, with c = params_obj[0].
fun = lambda z, params_obj: 0.5 * z @ z + params_obj[0] @ z - 1
# A = identity; the output is a one-element tuple (a single constraint block).
matvec_A = lambda params_A, z: (z, )
solver = jaxopt.BoxOSQP(matvec_A=matvec_A, fun=fun, tol=1e-5)

def test_loss(a):
    params_obj = (jnp.atleast_1d(a),)
    # Box bounds 0 <= z <= 1, structured like the output of matvec_A.
    l = (jnp.array([0.]),)
    u = (jnp.array([1.]),)

    init_params = solver.init_params(
        init_x=jnp.array([0.]),
        params_obj=params_obj,
        params_eq=None,
        params_ineq=(l, u)
    )
    sol = solver.run(
        init_params=init_params,
        params_obj=params_obj,
        params_eq=None,
        params_ineq=(l, u)
    )
    zopt = sol.params.primal[-1][-1]
    return fun(zopt, params_obj)

print(test_loss(jnp.array([-0.5])))          # -1.125
print(jax.grad(test_loss)(jnp.array([1.])))  # raises TypeError
```

Relevant traceback:

```
JaxStackTraceBeforeTransformation: TypeError: unsupported operand type(s) for @: 'tuple' and 'tuple'

The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.

--------------------

The above exception was the direct cause of the following exception:

TypeError                                 Traceback (most recent call last)
<ipython-input-24-f2a2a3471989> in <cell line: 28>()
     26 
     27 print(test_loss(jnp.array([-0.5]))) # -1.0625
---> 28 print(jax.grad(test_loss)(jnp.array([1.]))) # error

    [... skipping hidden 12 frame]

/usr/local/lib/python3.10/dist-packages/jaxopt/_src/implicit_diff.py in solver_fun_bwd(tup, cotangent)
    234 
    235       # Compute VJPs w.r.t. args.
--> 236       vjps = root_vjp(optimality_fun=optimality_fun, sol=sol,
    237                       args=ba_args[1:], cotangent=cotangent, solve=solve)
    238       # Prepend None as the vjp for init_params.

/usr/local/lib/python3.10/dist-packages/jaxopt/_src/implicit_diff.py in root_vjp(optimality_fun, sol, args, cotangent, solve)
     58     return optimality_fun(sol, *args)
     59 
---> 60   _, vjp_fun_sol = jax.vjp(fun_sol, sol)
     61 
     62   # Compute the multiplication A^T u = (u^T A)^T.

    [... skipping hidden 7 frame]

/usr/local/lib/python3.10/dist-packages/jaxopt/_src/implicit_diff.py in fun_sol(sol)
     56   def fun_sol(sol):
     57     # We close over the arguments.
---> 58     return optimality_fun(sol, *args)
     59 
     60   _, vjp_fun_sol = jax.vjp(fun_sol, sol)

/usr/local/lib/python3.10/dist-packages/jaxopt/_src/implicit_diff.py in optimality_fun(params, params_obj, params_eq, params_ineq)
    352     primal_var, eq_dual_var, ineq_dual_var = params
    353 
--> 354     stationarity = grad_fun(primal_var, params_obj)
    355 
    356     if eq_dual_var is not None:

    [... skipping hidden 10 frame]

<ipython-input-24-f2a2a3471989> in <lambda>(z, params_obj)
      1 import jaxopt
      2 
----> 3 fun = lambda z, params_obj: 0.5 * z @ z + params_obj[0] @ z - 1
      4 matvec_A = lambda params_A, z: (z, )
      5 solver = jaxopt.BoxOSQP(matvec_A=matvec_A, fun=fun, tol=1e-5)

TypeError: unsupported operand type(s) for @: 'tuple' and 'tuple'
```

Do `optimality_fun`/`grad_fun` fail to adapt the user-supplied `fun` so that it handles the tangents properly?
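For comparison, `fun` itself differentiates without trouble when `z` is an array. The sketch below assumes that `grad_fun` in the traceback behaves like `jax.grad(fun)`, i.e. the gradient with respect to the first argument (the frame names suggest this, but it is not confirmed here), and evaluates it at the solution for `a = -0.5`, where `z* = 0.5` lies inside the box:

```python
import jax
import jax.numpy as jnp

# The objective from the report, unchanged.
fun = lambda z, params_obj: 0.5 * z @ z + params_obj[0] @ z - 1

# Assumption: the solver's grad_fun behaves like jax.grad(fun),
# i.e. the gradient of fun with respect to its first argument z.
grad_fun = jax.grad(fun)

z_opt = jnp.array([0.5])           # interior solution for a = -0.5
params_obj = (jnp.array([-0.5]),)

# Gradient of the objective is z + params_obj[0]; with no active bound
# this stationarity term should vanish at the optimum.
stationarity = grad_fun(z_opt, params_obj)
print(stationarity)
```

Since this works with array inputs, the `'tuple' and 'tuple'` operands presumably come from the arguments the implicit-differentiation path passes to `fun` rather than from `fun` itself; that is an inference from the traceback only.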

I can successfully get the gradient using the `(Q, c)` and `matvec_Q` paths.
I could write my actual function using either of these, but I imagine this might be difficult for other operations, which I assume is the reason for including `fun`.
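When comparing the `fun` path against the `(Q, c)` and `matvec_Q` paths, it may help that this toy problem has a closed form: minimizing `0.5 * z**2 + a * z - 1` over `0 <= z <= 1` gives `z* = clip(-a, 0, 1)`. A minimal pure-JAX reference, where `reference_loss` is a hypothetical helper and not part of jaxopt:

```python
import jax
import jax.numpy as jnp

def reference_loss(a):
    """Closed-form stand-in for test_loss (hypothetical helper).

    min_z 0.5 * z**2 + a * z - 1  subject to 0 <= z <= 1
    is solved by z* = clip(-a, 0, 1).
    """
    a = jnp.atleast_1d(a)
    z_opt = jnp.clip(-a, 0.0, 1.0)
    return 0.5 * z_opt @ z_opt + a @ z_opt - 1

print(reference_loss(jnp.array([-0.5])))            # -1.125
print(jax.grad(reference_loss)(jnp.array([-0.5])))  # [0.5]
print(jax.grad(reference_loss)(jnp.array([1.])))
```

At `a = -0.5` the solution is interior and the gradient equals `z* = 0.5`, as the envelope theorem predicts; at `a = 1` the bound `z >= 0` is active, so `z* = 0` and the reference gradient is 0. A solver-based `jax.grad(test_loss)` should match these values up to the solver tolerance.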