mathurinm/celer

Using warm start raises IndexError with MultiTaskLasso

PABannier opened this issue

Hi @mathurinm!

I ran into an issue with Celer's MultiTaskLasso (MTL), which I'm using as the basis for a reweighted MTL.
Everything works fine when using MultiTaskLasso with fit_intercept=False. However, as soon as I also set warm_start=True and call fit a second time, the following error is raised:

```
IndexError: too many indices for array: array is 2-dimensional, but 3 were indexed.
```
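The message itself is NumPy's generic error for indexing a 2-D array with three indices. One plausible reading, which is an assumption and not verified against celer's source, is that on the warm-started fit a 2-D coefficient matrix (such as the fitted `coef_`, shaped `(n_tasks, n_features)` in scikit-learn's convention) reaches code that expects a 3-D array of path coefficients with one slice per alpha. The NumPy-only sketch below reproduces the same kind of error; the helper `first_alpha_slice` and the `[:, :, 0]` indexing are hypothetical and only illustrate the shape mismatch.

```python
import numpy as np

n_features, n_tasks = 15, 5


def first_alpha_slice(coefs):
    """Hypothetical helper: coefficients for the first alpha of a path.

    Assumes a 3-D array of shape (n_tasks, n_features, n_alphas).
    """
    return coefs[:, :, 0]


# A 3-D path array holding a single alpha: indexing with three indices works.
path_coefs = np.zeros((n_tasks, n_features, 1))
print(first_alpha_slice(path_coefs).shape)  # (5, 15)

# A 2-D matrix shaped like scikit-learn's MultiTaskLasso.coef_ (n_tasks, n_features).
coef = np.zeros((n_tasks, n_features))
error_message = None
try:
    first_alpha_slice(coef)
except IndexError as exc:
    # On recent NumPy versions the message matches the one in the report.
    error_message = str(exc)
print(error_message)
```

If this reading is right, the fix would be on the warm-start code path (adding or removing the alpha axis), not in the user code.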

Here is a snippet of code to reproduce the error:

```python
import numpy as np
from celer import MultiTaskLasso

n_samples = 10
n_features = 15
n_tasks = 5


def generate_data():
    X = np.random.randn(n_samples, n_features)
    Y = np.random.randn(n_samples, n_tasks)
    return X, Y


clf = MultiTaskLasso(alpha=0.3, fit_intercept=False, warm_start=True)

# First fit: there are no previous coefficients to reuse yet.
X, Y = generate_data()
clf.fit(X, Y)

# Second fit on new data: warm_start=True reuses the previous coefficients.
X_bis, Y_bis = generate_data()
clf.fit(X_bis, Y_bis)
```
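For comparison, a warm start should only change the initial point: because the multi-task Lasso objective (1 / (2 * n_samples)) * ||Y - XW||_Fro^2 + alpha * ||W||_21 is convex, the second fit should reach the same objective value whether or not it starts from the previous solution. The sketch below illustrates this with a small pure-NumPy proximal-gradient (ISTA) solver. It is not celer's solver; `mtl_ista`, `block_soft_threshold` and `mtl_objective` are hypothetical names, and `W` is stored as `(n_features, n_tasks)`, i.e. transposed relative to scikit-learn's `coef_`.

```python
import numpy as np


def mtl_objective(X, Y, W, alpha):
    """Multi-task Lasso objective; W has shape (n_features, n_tasks)."""
    n_samples = X.shape[0]
    datafit = ((Y - X @ W) ** 2).sum() / (2 * n_samples)
    penalty = alpha * np.linalg.norm(W, axis=1).sum()  # L2,1 norm over feature rows
    return datafit + penalty


def block_soft_threshold(W, threshold):
    """Proximal operator of threshold * ||W||_21: shrink each row towards zero."""
    norms = np.linalg.norm(W, axis=1, keepdims=True)
    with np.errstate(divide="ignore"):
        # Zero rows give 1 - inf = -inf, which is clipped to 0 (row stays zero).
        scale = np.maximum(0.0, 1.0 - threshold / norms)
    return W * scale


def mtl_ista(X, Y, alpha, coef_init=None, n_iter=20000):
    """Hypothetical ISTA solver; coef_init of shape (n_features, n_tasks) warm-starts it."""
    n_samples, n_features = X.shape
    n_tasks = Y.shape[1]
    if coef_init is None:
        W = np.zeros((n_features, n_tasks))
    else:
        W = np.array(coef_init, dtype=float)  # copy, so the caller's array is untouched
    # Step size 1 / L, where L = ||X||_2^2 / n_samples bounds the datafit Hessian.
    step = n_samples / np.linalg.norm(X, ord=2) ** 2
    for _ in range(n_iter):
        grad = X.T @ (X @ W - Y) / n_samples
        W = block_soft_threshold(W - step * grad, step * alpha)
    return W


rng = np.random.default_rng(0)
X, Y = rng.standard_normal((10, 15)), rng.standard_normal((10, 5))
X_bis, Y_bis = rng.standard_normal((10, 15)), rng.standard_normal((10, 5))

W_first = mtl_ista(X, Y, alpha=0.3)
W_warm = mtl_ista(X_bis, Y_bis, alpha=0.3, coef_init=W_first)  # warm start
W_cold = mtl_ista(X_bis, Y_bis, alpha=0.3)  # cold start from zeros

obj_warm = mtl_objective(X_bis, Y_bis, W_warm, alpha=0.3)
obj_cold = mtl_objective(X_bis, Y_bis, W_cold, alpha=0.3)
print(obj_warm, obj_cold)
```

Warm starting mainly saves iterations when consecutive problems are close, for example between successive reweighting steps, which is presumably why it matters for a reweighted MTL.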

Anyway, thanks for your work on Celer!