Using warm start raises IndexError with MultiTaskLasso
PABannier opened this issue
Hi @mathurinm!
I ran into an issue when using Celer's MultiTaskLasso. I'm using Celer's MultiTaskLasso (MTL) as the basis for a reweighted MTL.
Everything works perfectly fine when using MultiTaskLasso with fit_intercept=False. However, when I also enable warm_start=True, the following error is raised:
IndexError: too many indices for array: array is 2-dimensional, but 3 were indexed
Here is a snippet of code to reproduce the error:
import numpy as np
from celer import MultiTaskLasso
n_samples = 10
n_features = 15
n_tasks = 5
def generate_data():
    # Random design matrix and multi-task target with fixed shapes
    X = np.random.randn(n_samples, n_features)
    Y = np.random.randn(n_samples, n_tasks)
    return X, Y
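clf = MultiTaskLasso(alpha=0.3, fit_intercept=False, warm_start=True)
# First fit on one random dataset
X, Y = generate_data()
clf.fit(X, Y)
# Refit on a new dataset of the same shape, warm-starting from the previous solution
X_bis, Y_bis = generate_data()
clf.fit(X_bis, Y_bis)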
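For context, this message is NumPy's generic error for indexing an array with more indices than it has dimensions. The pure-NumPy sketch below reproduces it on a 2-D array shaped (n_tasks, n_features), like a multi-task coefficient matrix. It is only an illustration: the array and the helper `index_with_three` are hypothetical, and whether Celer's warm-start path actually indexes its stored coefficients with three indices is an assumption, not something confirmed here.

```python
import numpy as np

n_tasks, n_features = 5, 15

# Hypothetical stand-in for a 2-D multi-task coefficient matrix
# of shape (n_tasks, n_features); this is NOT Celer's internal array.
coef = np.zeros((n_tasks, n_features))

# Two indices on a 2-D array are fine.
value = coef[0, 0]


def index_with_three(arr):
    """Index a 2-D array with three indices and return the IndexError message."""
    try:
        arr[0, 0, 0]
    except IndexError as exc:
        return str(exc)
    return None


message = index_with_three(coef)
# On recent NumPy versions this prints the same message quoted in the report.
print(message)
```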
Anyway, thanks for your work on Celer!