Apparently wrong gradients for simple softmax-as-optimization formulation
currymj opened this issue · 9 comments
While trying to use cvxpylayers to get derivatives, I ran into some apparent errors. I've boiled it down to a small minimal working example (MWE).
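This is essentially

$$\max_{x \ge 0} \; v x_1 + b x_2 + \epsilon \left( -x_1 \log x_1 - x_2 \log x_2 \right) \quad \text{s.t.} \quad x_1 + x_2 = 1,$$

where $\epsilon$ is a small smoothing coefficient (`smooth_coeff = 0.01` in the code below).
I would like to get derivatives of the solution with respect to $v$ and $b$.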
I'm using SCS as the solver. If I switch to ECOS (which I'm not even sure is "allowed" because of the entropy term), I still get wrong derivatives, though they differ from the SCS ones. I don't get any solver warnings, and there are no obvious problems in the verbose output.
I suspect I've seen similar issues with PyTorch, although I don't have an MWE for that case.
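For reference, the softmax form of the solution (used as `analytic_solution` in the code) follows from the KKT conditions of maximizing $v x_1 + b x_2 - \epsilon \sum_i x_i \log x_i$ subject to $x_1 + x_2 = 1$, writing $\theta = (v, b)$ and $\epsilon$ = `smooth_coeff`. This short derivation assumes the entropy term keeps both $x_i > 0$, so the nonnegativity constraint is inactive:

```math
\begin{aligned}
\mathcal{L}(x, \nu) &= \theta_1 x_1 + \theta_2 x_2 - \epsilon \sum_{i=1}^{2} x_i \log x_i - \nu \,(x_1 + x_2 - 1) \\
0 = \frac{\partial \mathcal{L}}{\partial x_i} &= \theta_i - \epsilon (\log x_i + 1) - \nu
\;\;\Longrightarrow\;\; x_i^\star = \frac{e^{\theta_i/\epsilon}}{e^{\theta_1/\epsilon} + e^{\theta_2/\epsilon}} \\
\frac{\partial x_1^\star}{\partial b} &= -\frac{x_1^\star x_2^\star}{\epsilon},
\qquad \frac{\partial x_1^\star}{\partial v} = \frac{x_1^\star x_2^\star}{\epsilon}
\end{aligned}
```

With $v = 0.6$ and $\epsilon = 0.01$, the reference curve $\partial x_1^\star / \partial b$ is a narrow dip centered at $b = v$, where $x_1^\star = x_2^\star = 1/2$ and the derivative is $-1/(4\epsilon) = -25$; its magnitude drops below 1% of that once $|b - v| \gtrsim 0.06$.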
Below is some code that formulates this problem, solves it, computes derivatives, and plots the values and derivatives to show that the cvxpylayers derivatives disagree with those of the analytic solution.
```python
import cvxpy as cp
import jax
import jax.numpy as jnp
from cvxpylayers.jax import CvxpyLayer
import matplotlib.pyplot as plt

if __name__ == '__main__':
    # define the problem, which is
    # maximize v*x1 + b*x2 + smooth_coeff*sum(entr(x1, x2))
    # s.t. sum(x) == 1
    v = cp.Parameter(1)
    b = cp.Parameter(1)
    x = cp.Variable(2, nonneg=True)
    constraints = [
        cp.sum(x) == 1.0,
    ]
    smooth_coeff = 0.01
    objective = cp.Maximize(v * x[0] + b * x[1] + smooth_coeff * cp.sum(cp.entr(x)))
    problem = cp.Problem(objective, constraints)
    layer = CvxpyLayer(problem, parameters=[v, b], variables=[x])

    # we'll look at the output in coordinate 0 only for testing
    def cvx_solution(v_arr, b_arr):
        result, = layer(v_arr, b_arr, solver_args={"solve_method": "SCS"})
        return result[0]

    # optimal solution to this problem is known to be softmax( v / smooth_coeff, b / smooth_coeff )
    def analytic_solution(v_arr, b_arr):
        return jax.nn.softmax(jnp.concatenate([v_arr, b_arr]) / smooth_coeff)[0]

    analytic_value_and_grad = jax.value_and_grad(analytic_solution, argnums=(0, 1))
    cvx_value_and_grad = jax.value_and_grad(cvx_solution, argnums=(0, 1))

    # test on fixed value of v
    v_arr = jnp.array([0.6])
    # vary b from 0 to 1
    b_values = jnp.linspace(0, 1, 100)
    softmax_values = []
    softmax_b_grads = []
    cvx_values = []
    cvx_b_grads = []
    for b_val in b_values:
        value, grads = cvx_value_and_grad(v_arr, jnp.array([b_val]))
        analytic_value, analytic_grads = analytic_value_and_grad(v_arr, jnp.array([b_val]))
        softmax_values.append(analytic_value.item())
        softmax_b_grads.append(analytic_grads[1].item())
        cvx_values.append(value.item())
        cvx_b_grads.append(grads[1].item())

    fig, (ax1, ax2, ax4) = plt.subplots(3, 1)
    ax1.plot(b_values, softmax_values)
    ax1.plot(b_values, cvx_values)
    ax1.set_xlabel("b")
    ax1.set_title("softmax solution = cvx_solution (both are overlapping on this plot)")
    ax2.plot(b_values, softmax_b_grads)
    ax2.set_xlabel("b")
    ax2.set_title("d softmax(v, b) / db")
    ax4.plot(b_values, cvx_b_grads)
    ax4.set_xlabel("b")
    ax4.set_title("d cvx_solution(v, b) / db")
    fig.tight_layout()
    plt.show()
```
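The second panel uses JAX autodiff of the softmax as its reference. A quick way to confirm that reference without autodiff is a central-difference check against the closed form $\partial x_1^\star/\partial b = -x_1^\star x_2^\star/\epsilon$. The NumPy sketch below does this; `analytic_x1`, `closed_form_dx1_db`, and `central_difference` are illustrative helpers, not part of cvxpylayers.

```python
import numpy as np

SMOOTH_COEFF = 0.01  # same value as smooth_coeff in the script above


def analytic_x1(v, b, eps=SMOOTH_COEFF):
    """First coordinate of softmax([v, b] / eps), computed stably."""
    z = np.array([v, b], dtype=float) / eps
    z -= z.max()  # shift so the largest exponent is 0 (avoids overflow)
    e = np.exp(z)
    return e[0] / e.sum()


def closed_form_dx1_db(v, b, eps=SMOOTH_COEFF):
    """d x1* / d b = -x1* x2* / eps, using x2* = 1 - x1*."""
    x1 = analytic_x1(v, b, eps)
    return -x1 * (1.0 - x1) / eps


def central_difference(f, b, h=1e-5):
    """Approximate f'(b) by (f(b + h) - f(b - h)) / (2h)."""
    return (f(b + h) - f(b - h)) / (2.0 * h)


v = 0.6
for b in np.linspace(0.0, 1.0, 11):
    fd = central_difference(lambda bb: analytic_x1(v, bb), b)
    exact = closed_form_dx1_db(v, b)
    print(f"b={b:.2f}  finite-diff={fd:+.6f}  closed-form={exact:+.6f}")
```

The same `central_difference` helper could be pointed at the layer, e.g. `lambda bb: float(cvx_solution(v_arr, jnp.array([bb])))`; a larger step such as `h=1e-3` is probably needed there, because SCS only solves to a finite tolerance.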
Hi! Thanks for the detailed report. This is caused by an update to cvxpy/SCS: cvxpylayers doesn't yet support derivatives through the new quadratic objective. I just pushed this PR, which should fix it. Can you install that version of the code and confirm that it fixes the problem?
It doesn't seem to be fixed: I still get similar-looking plots when running the same code above (with both SCS and ECOS). Can you share what output you see on your end?
Just to confirm I'm using the right version: I installed from master (after the commit in #146 was merged), and the relevant packages in the output of `pip freeze` are:

```
cvxpy==1.3.0
cvxpylayers @ git+https://github.com/cvxgrp/cvxpylayers.git@755d93fef4319bd1bdb8390f9c98ff0ebcf8bdea
diffcp==1.0.22
ecos==2.0.12
scs==3.2.3
```
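A minimal sketch for collecting the same version information with only the standard library (`importlib.metadata`), which avoids scanning the full `pip freeze` output. The package list is an assumption about what is relevant to this issue; `installed_versions` is an illustrative helper.

```python
from importlib import metadata

# Distribution names assumed relevant to this issue.
PACKAGES = ["cvxpy", "cvxpylayers", "diffcp", "ecos", "scs", "jax", "torch"]


def installed_versions(packages=PACKAGES):
    """Map each distribution name to its installed version, or None if absent."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


if __name__ == "__main__":
    for name, version in installed_versions().items():
        print(f"{name}=={version}" if version else f"{name}: not installed")
```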
I ran your code and am also still getting the wrong derivatives. I thought the problem might come from diffcp's LSQR derivatives, but passing `'mode': 'dense'` in `solver_args` doesn't seem to help. Maybe something else is going on with the exponential cone derivatives? I'm not sure. JAX may have changed the behavior of some operations we rely on since we developed the JAX version, but I just ran the unit tests that check the derivatives in other settings, and most of them seem to match. It would be good to check whether this also happens with the same example in PyTorch (or just using cvxpy directly).
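For context on where exponential cones come in: the entropy term is rewritten with the exponential cone before the problem reaches SCS and diffcp. A standard reformulation is shown below, assuming the SCS/cvxpy convention for $K_{\exp}$; the exact variable ordering and scaling that cvxpy emits may differ.

```math
\begin{aligned}
K_{\exp} &= \operatorname{cl}\,\{ (r, s, t) : s > 0,\; s\, e^{r/s} \le t \} \\
u_i \le -x_i \log x_i &\iff (u_i,\, x_i,\, 1) \in K_{\exp} \\
\text{maximize } \; v x_1 + b x_2 + \epsilon\,(u_1 + u_2) \quad &\text{s.t.} \quad x_1 + x_2 = 1,\;\; (u_i, x_i, 1) \in K_{\exp},\; i = 1, 2
\end{aligned}
```

Derivatives of the layer therefore involve diffcp's derivatives of projections onto $K_{\exp}$ (and its dual). In this example $x_2^\star$ is tiny for small $b$ (about $e^{-60} \approx 9 \times 10^{-27}$ at $b = 0$, $v = 0.6$, $\epsilon = 0.01$), so the cone point $(u_2, x_2, 1)$ is very close to $(0, 0, 1)$, on the $s = 0$ face of $K_{\exp}$.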
I get similar results with PyTorch. I'm not used to using the cvxpy/diffcp interface directly for derivatives, but I can give that a shot too.
PyTorch version:
```python
import cvxpy as cp
import torch
from cvxpylayers.torch import CvxpyLayer
import matplotlib.pyplot as plt

if __name__ == '__main__':
    # define the problem, which is
    # maximize v*x1 + b*x2 + smooth_coeff*sum(entr(x1, x2))
    # s.t. sum(x) == 1
    v = cp.Parameter(1)
    b = cp.Parameter(1)
    x = cp.Variable(2, nonneg=True)
    constraints = [
        cp.sum(x) == 1.0,
    ]
    smooth_coeff = 0.01
    objective = cp.Maximize(v * x[0] + b * x[1] + smooth_coeff * cp.sum(cp.entr(x)))
    problem = cp.Problem(objective, constraints)
    layer = CvxpyLayer(problem, parameters=[v, b], variables=[x])

    # we'll look at the output in coordinate 0 only for testing
    def cvx_solution(v_arr, b_arr):
        result, = layer(v_arr, b_arr, solver_args={"solve_method": "SCS"})
        return result[0]

    # optimal solution to this problem is known to be softmax( v / smooth_coeff, b / smooth_coeff )
    def analytic_solution(v_arr, b_arr):
        # dim=0 made explicit to avoid PyTorch's implicit-dim deprecation warning
        return torch.nn.functional.softmax(torch.cat([v_arr, b_arr]) / smooth_coeff, dim=0)[0]

    # test on fixed value of v
    v_arr = torch.tensor([0.6])
    # vary b from 0 to 1
    b_values = torch.linspace(0, 1, 100)
    softmax_values = []
    softmax_b_grads = []
    cvx_values = []
    cvx_b_grads = []
    for b_val in b_values:
        # fresh leaf tensor for each backward pass, so .grad never accumulates
        arr = torch.tensor([b_val.item()], requires_grad=True)
        value = cvx_solution(v_arr, arr)
        value.backward()
        cvx_values.append(value.item())
        cvx_b_grads.append(arr.grad.item())

        arr = torch.tensor([b_val.item()], requires_grad=True)
        softmax_value = analytic_solution(v_arr, arr)
        softmax_value.backward()
        softmax_values.append(softmax_value.item())
        softmax_b_grads.append(arr.grad.item())

    fig, (ax1, ax2, ax4) = plt.subplots(3, 1)
    ax1.plot(b_values, softmax_values)
    ax1.plot(b_values, cvx_values)
    ax1.set_xlabel("b")
    ax1.set_title("softmax solution = cvx_solution (both are overlapping on this plot)")
    ax2.plot(b_values, softmax_b_grads)
    ax2.set_xlabel("b")
    ax2.set_title("d softmax(v, b) / db")
    ax4.plot(b_values, cvx_b_grads)
    ax4.set_xlabel("b")
    ax4.set_title("d cvx_solution(v, b) / db")
    fig.tight_layout()
    plt.show()
```
It seems I also get these results using cvxpy directly (although I'm not 100% sure I'm using it right). Should I also open up an issue on that repo?
```python
import cvxpy as cp
import numpy as np
import torch
import matplotlib.pyplot as plt

if __name__ == '__main__':
    # define the problem, which is
    # maximize v*x1 + b*x2 + smooth_coeff*sum(entr(x1, x2))
    # s.t. sum(x) == 1
    v = cp.Parameter(1)
    b = cp.Parameter(1)
    x = cp.Variable(2, nonneg=True)
    constraints = [
        cp.sum(x) == 1.0,
    ]
    smooth_coeff = 0.01
    objective = cp.Maximize(v * x[0] + b * x[1] + smooth_coeff * cp.sum(cp.entr(x)))
    # objective = cp.Maximize(v * x[0] + b * x[1] - smooth_coeff * cp.sum(x**2))
    problem = cp.Problem(objective, constraints)
    print(problem.is_dpp())  # requires_grad=True needs a DPP problem

    def analytic_solution(v_arr, b_arr):
        # call pytorch softmax
        return torch.nn.functional.softmax(torch.cat([v_arr, b_arr]) / smooth_coeff, dim=0)[0]

    # we'll look at the output in coordinate 0 only for testing
    v.value = [0.6]
    softmax_values = []
    softmax_b_grads = []
    cvx_values = []
    cvx_b_grads = []
    b_values = torch.linspace(0, 1, 100)
    for b_val in b_values:
        v.value = [0.6]
        b.value = [b_val.item()]
        problem.solve(requires_grad=True, solver="SCS")
        cvx_values.append(x.value[0])
        # upstream gradient d(loss)/dx = [1, 0], so b.gradient should be d x[0] / d b
        x.gradient = np.array([1.0, 0.0])
        problem.backward()
        cvx_b_grads.append(np.asarray(b.gradient).item())
        arr = torch.tensor([b_val.item()], requires_grad=True)
        softmax_value = analytic_solution(torch.tensor([0.6]), arr)
        softmax_value.backward()
        softmax_values.append(softmax_value.item())
        softmax_b_grads.append(arr.grad.item())

    fig, (ax1, ax2, ax4) = plt.subplots(3, 1)
    ax1.plot(b_values, softmax_values)
    ax1.plot(b_values, cvx_values)
    ax1.set_xlabel("b")
    ax1.set_title("softmax solution = cvx_solution (both are overlapping on this plot)")
    ax2.plot(b_values, softmax_b_grads)
    ax2.set_xlabel("b")
    ax2.set_title("d softmax(v, b) / db")
    ax4.plot(b_values, cvx_b_grads)
    ax4.set_xlabel("b")
    ax4.set_title("d cvx_solution(v, b) / db")
    fig.tight_layout()
    plt.show()
```
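On whether the direct cvxpy usage is right: as far as I understand cvxpy's `Problem.backward()`, the value assigned to `x.gradient` acts as the upstream gradient $\partial \ell / \partial x$ of some scalar loss $\ell$, and `b.gradient` then holds the vector-Jacobian product $(\partial \ell/\partial x)^\top \, \partial x^\star/\partial b$. With the seed $[1, 0]$ that is exactly $\partial x_1^\star / \partial b$. The NumPy sketch below is independent of cvxpy (`softmax_jacobian` is an illustrative helper using the closed-form softmax Jacobian) and shows that this seed selects the first row of the Jacobian:

```python
import numpy as np


def softmax_jacobian(theta, eps):
    """Return (J, x) with x = softmax(theta / eps) and J = (diag(x) - x x^T) / eps."""
    z = np.asarray(theta, dtype=float) / eps
    z -= z.max()  # numerical stability
    x = np.exp(z) / np.exp(z).sum()
    return (np.diag(x) - np.outer(x, x)) / eps, x


theta = np.array([0.6, 0.61])  # (v, b)
eps = 0.01
J, x = softmax_jacobian(theta, eps)

seed = np.array([1.0, 0.0])  # plays the role of x.gradient
vjp = seed @ J               # plays the role of (v.gradient, b.gradient)

print("d x1/d v =", vjp[0], "  d x1/d b =", vjp[1])
print("closed form d x1/d b =", -x[0] * x[1] / eps)
```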
> Should I also open up an issue on that repo?

I suspect the issue is that diffcp (https://github.com/cvxgrp/diffcp) is not giving the correct derivatives, so it may make sense to isolate the problem to just the cone program derivatives and track/cross-post the issue there.
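One way to get a reference derivative that does not go through diffcp at all is to differentiate the KKT conditions of the original (non-conic) problem with the implicit function theorem. This is a different route from diffcp's (diffcp differentiates the conic program), so it serves only as a cross-check. A NumPy sketch, with `solve_entropy_problem` and `implicit_jacobian` as illustrative names:

```python
import numpy as np


def solve_entropy_problem(theta, eps):
    """x* = softmax(theta / eps), the maximizer of theta^T x + eps*sum(entr(x)) s.t. sum(x) = 1."""
    z = np.asarray(theta, dtype=float) / eps
    z -= z.max()  # numerical stability
    e = np.exp(z)
    return e / e.sum()


def implicit_jacobian(theta, eps):
    """d x*/d theta obtained by differentiating the KKT system
        G = theta - eps*(log x + 1) - nu*1 = 0,    1^T x - 1 = 0
    with respect to theta and solving the resulting linear system."""
    x = solve_entropy_problem(theta, eps)
    n = x.size
    K = np.zeros((n + 1, n + 1))
    K[:n, :n] = np.diag(eps / x)  # -dG/dx
    K[:n, n] = 1.0                # -dG/dnu
    K[n, :n] = 1.0                # d(1^T x - 1)/dx
    rhs = np.vstack([np.eye(n), np.zeros((1, n))])  # dG/dtheta, then 0 for the constraint
    sol = np.linalg.solve(K, rhs)
    return sol[:n]                # rows index x_i, columns index theta_j


theta = np.array([0.6, 0.61])  # (v, b)
eps = 0.01
J = implicit_jacobian(theta, eps)
x = solve_entropy_problem(theta, eps)
print("d x1 / d b (implicit):   ", J[0, 1])
print("d x1 / d b (closed form):", -x[0] * x[1] / eps)
```

For $b$ far from $v$ one entry of $x^\star$ underflows toward zero, $\epsilon / x_i$ blows up, and the linear system becomes badly conditioned, so this sketch is most trustworthy near $b \approx v$.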
This is likely fixed by pull request cvxgrp/diffcp#59. I won't have time to check thoroughly for the next couple of weeks, but if someone else installs the updated versions, runs the code here, and sees that the problem is gone, please feel free to close the issue.