mathurinm/celer

Floating point error accumulation

Closed this issue · 2 comments

Hello,
The GroupLasso on this example results in a suboptimal value.

Reproduce

Using the following coefficients (loaded from best.txt in the code below) to compare the primal value,

run the code

from celer.datasets import make_correlated_data
from celer import GroupLasso
import numpy as onp
group_size = 300
n_samples = 600
n_groups = 4
n_features = n_groups * group_size

X, y, _ = make_correlated_data(n_samples=n_samples, n_features=n_features, random_state=12)

estimator = GroupLasso(groups=group_size, prune=False, alpha=1.0, tol=1e-20, p0=5, max_iter=200, max_epochs=5000, verbose=12)
estimator.fit(X, y)

Xg = X.reshape(n_samples, n_groups, group_size).swapaxes(0, 1)
assert onp.allclose(Xg[0].flatten(), X[:, :group_size].flatten()), "groups are not correctly reshaped"

def primal_onp(A, lambda_, x, y, n_samples):
    return 1 /(2 *n_samples) * onp.linalg.norm(onp.einsum('ijk,ik', A, x) - y)**2 + lambda_ * onp.sum(onp.linalg.norm(x, ord=2, axis=-1))

celer_primal = primal_onp(Xg, 1.0, estimator.coef_.reshape((n_groups, group_size)), y, n_samples)
print(f"primal value: {celer_primal:.6E}")

best = onp.loadtxt("best.txt")
best_p = primal_onp(Xg, 1.0, best, y, n_samples)
print(f"best primal value: {best_p:.6E}")
print(f"suboptimality: {celer_primal - best_p:.6E}")

to obtain

#########################
##### Computing alpha 1/1
#########################
Iter 0: primal 139.4153891823, gap 1.23e+02, 4 groups in subpb (4 left)
Iter 1: primal 22.8615244139, gap -1.42e-14
Early exit, gap: -1.42e-14 < 2.79e-18
primal value: 2.298344E+01
best primal value: 2.288588E+01
suboptimality: 9.756279E-02

Result

The primal suboptimality is at least 0.097.

Expected

The primal suboptimality should be around the tolerance 1e-20, or at least close to machine precision / the reported dual gap of about 1e-14.

Probable cause

I had the same issue (negative dual gaps) in my own solver, and the source was floating point error accumulation in the dual variable Xw, since we assume that Xw matches the true value X @ w when computing the primal objective.

The example has large groups to increase the number of floating point operations per coordinate update, so that errors appear faster. However, floating point error accumulation should also affect the plain (non-group) Lasso.
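
A minimal sketch of the suspected effect, using synthetic data and plain NumPy rather than celer's internals. The names X, w, Xw and the random update rule are illustrative assumptions, not celer's actual update. It keeps Xw in sync with w through rank-one updates, as a coordinate descent solver typically would, and then measures how far it has drifted from a fresh X @ w.

```python
import numpy as np

rng = np.random.default_rng(0)
n_samples, n_features = 200, 50
X = rng.standard_normal((n_samples, n_features))

w = np.zeros(n_features)
Xw = np.zeros(n_samples)  # maintained incrementally, never recomputed

for _ in range(20_000):
    j = rng.integers(n_features)   # hypothetical coordinate choice
    delta = rng.standard_normal()  # hypothetical coordinate step
    w[j] += delta
    Xw += delta * X[:, j]          # rank-one update of the cached product

# Compare the cached product with a fresh matrix-vector product.
drift = np.max(np.abs(Xw - X @ w))
scale = np.max(np.abs(X @ w))
print(f"max |Xw - X @ w| = {drift:.3e} (relative: {drift / scale:.3e})")
```

If the drift printed here is many orders of magnitude below the observed suboptimality, accumulation alone is unlikely to explain it.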

Tested on

  • celer 0.7.3
  • numpy 1.26.4
  • platform Apple M1

Hi @joelgarde ! Thanks for pointing this out to us, this is so great to see Celer being used in the wild.
Here are the debugging steps:

  • This is not a stability issue: we get a difference of roughly 1e-1 (9.76e-2), which is far too large to be due to numerical errors. Numerical errors would be small, around 1e-14.

  • Using the verbose mode, we see that something is wrong in the primal computation, because it does not match the objective that is displayed inside the celer loop:

#########################
##### Computing alpha 1/1
#########################
Iter 0: primal 139.4153891823, gap 1.23e+02, 4 groups in subpb (4 left)
Iter 1: primal 22.8615244139, gap -7.11e-15
Early exit, gap: -7.11e-15 < 2.79e-18
primal value: 2.298344E+01
best primal value: 2.288588E+01
suboptimality: 9.756279E-02

In [2]: celer_primal
Out[2]: 22.983439744552424

In [3]: best_p
Out[3]: 22.885876951049244

Have a look at this line: Iter 1: primal 22.8615244139 (which is lower than the best_p 🥇)

  • The difference is due to the fact that celer, for compatibility with sklearn, fits an intercept by default. So the residuals are not A @ x - y but A @ x + intercept - y. See the difference here:
In [25]: (X @ estimator.coef_)[:10]
Out[25]: 
array([ 10.91349885, -11.4947074 ,  34.10608126,   3.8771897 ,
       -12.40119896,  21.48192526,  25.52849388,  33.56151975,
        -2.23323574,   9.70506466])

In [26]: estimator.predict(X[:10])
Out[26]: 
array([ 11.40729098, -11.00091528,  34.59987338,   4.37098182,
       -11.90740684,  21.97571739,  26.022286  ,  34.05531187,
        -1.73944362,  10.19885679])
        
In [27]: (X @ estimator.coef_ + estimator.intercept_)[:10]
Out[27]: 
array([ 11.40729098, -11.00091528,  34.59987338,   4.37098182,
       -11.90740684,  21.97571739,  26.022286  ,  34.05531187,
        -1.73944362,  10.19885679])
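
To make the point concrete, here is a sketch of a primal objective that takes the intercept into account. The function name primal_with_intercept and the synthetic data are assumptions for illustration; the layout of Xg and coef follows the reshaping used in the reproduction script.

```python
import numpy as np

def primal_with_intercept(Xg, lambda_, coef, intercept, y):
    """Group Lasso objective with residuals Xg . coef + intercept - y.

    Xg has shape (n_groups, n_samples, group_size),
    coef has shape (n_groups, group_size).
    """
    n_samples = y.shape[0]
    pred = np.einsum("ijk,ik->j", Xg, coef) + intercept
    datafit = np.sum((pred - y) ** 2) / (2 * n_samples)
    penalty = lambda_ * np.sum(np.linalg.norm(coef, axis=-1))
    return datafit + penalty

# Synthetic check: y is fitted exactly only when the intercept is included.
rng = np.random.default_rng(0)
n_samples, n_groups, group_size = 30, 3, 4
X = rng.standard_normal((n_samples, n_groups * group_size))
coef = rng.standard_normal(n_groups * group_size)
intercept = 0.5
y = X @ coef + intercept

Xg = X.reshape(n_samples, n_groups, group_size).swapaxes(0, 1)
coef_g = coef.reshape(n_groups, group_size)

with_b = primal_with_intercept(Xg, 1.0, coef_g, intercept, y)
without_b = primal_with_intercept(Xg, 1.0, coef_g, 0.0, y)
print(f"with intercept: {with_b:.6f}, without: {without_b:.6f}, "
      f"difference: {without_b - with_b:.6f}")
```

When the residuals sum to zero, as they do at an optimum with a fitted intercept, dropping the intercept raises the datafit by intercept**2 / 2. In the outputs above the intercept is about 0.4938, and 0.4938**2 / 2 ≈ 0.1219, which matches the difference between the recomputed primal 22.9834 and the value 22.8615 printed inside the celer loop.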

I'm closing; feel free to reopen if this is not the end of the story.

Great news, I worried for nothing!

Re-tried without intercept and everything's fine up to floating point limits.
Thanks for the quick response.