rust-ml/linfa

MultiLogisticRegression panics on normalized data.


Firstly, let me say I'm very new to data science / ML, so my understanding / terminology may be wrong. Please bear with me, and thanks in advance.

I'm using a relatively small dataset (769 features over 6893 samples) and a small number of categories (8). All of my weights are normalized to the range [0, 1], though most are zero. I'm using the default configuration:

let model = MultiLogisticRegression::default().fit(&dataset).unwrap();

thread 'main' panicked at src/train.rs:121:22:
called `Result::unwrap()` on an `Err` value: ArgMinError(Condition violated: "`MoreThuenteLineSearch`: Search direction must be a descent direction.")

I've observed that if I set alpha to a larger value like 10, I get a result without panicking. I've also noticed that if I limit the number of iterations to a very small number, say, 20, I also get a good result without panicking. Therefore, I think the culprit is overfitting / divergence (uncertain of the proper terminology here).
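To illustrate why a larger alpha or fewer iterations might help, here is a minimal sketch in plain Rust. It is not linfa's solver, and it assumes the data is (nearly) linearly separable, which I have not confirmed for my dataset. The function `fit_weight`, the two-point toy data, and the fixed step size are all made up for illustration. With no penalty the weight keeps growing the longer the loop runs; with an L2 penalty it settles.

```rust
/// Logistic sigmoid.
fn sigmoid(z: f64) -> f64 {
    1.0 / (1.0 + (-z).exp())
}

/// Hypothetical helper: plain gradient descent on a one-weight logistic
/// regression with an L2 penalty of `alpha / 2 * w^2`.
/// The toy data (x = -1 -> class 0, x = 1 -> class 1) is perfectly separable.
fn fit_weight(alpha: f64, iterations: usize) -> f64 {
    let xs: [f64; 2] = [-1.0, 1.0];
    let ys: [f64; 2] = [0.0, 1.0];
    let step: f64 = 0.5; // fixed step size, chosen arbitrarily
    let mut w: f64 = 0.0;
    for _ in 0..iterations {
        // Gradient of the penalty term.
        let mut grad = alpha * w;
        // Gradient of the logistic loss, summed over the samples.
        for (x, y) in xs.iter().zip(ys.iter()) {
            grad += (sigmoid(w * x) - y) * x;
        }
        w -= step * grad;
    }
    w
}
```

In this toy setup, `fit_weight(0.0, n)` returns a larger weight for every increase in `n`, while `fit_weight(1.0, n)` stops changing after a few dozen iterations. That would be consistent with both of my workarounds, but it is only a guess at the cause.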

I will say that I was able to use smartcore's multinomial logistic regression routines while setting alpha = 0 without issue. Notably, they use a Backtracking line search implementation instead of More-Thuente. I don't know if that's relevant or not.
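For reference, this is roughly what I understand a backtracking line search to be. It is a generic textbook sketch with the Armijo sufficient-decrease test, not smartcore's or argmin's actual code; the function name `backtracking_step` and the constants are my own assumptions. The `slope >= 0.0` check corresponds to the "must be a descent direction" condition named in the error above.

```rust
/// Hypothetical helper: backtracking (Armijo) line search from `x` along
/// `direction`. Returns `Some(step)` for the first step that gives sufficient
/// decrease, or `None` if `direction` is not a descent direction or no
/// acceptable step is found within the iteration budget.
fn backtracking_step<F>(f: F, x: &[f64], grad: &[f64], direction: &[f64]) -> Option<f64>
where
    F: Fn(&[f64]) -> f64,
{
    let c: f64 = 1e-4; // sufficient-decrease constant
    let shrink: f64 = 0.5; // factor by which the step is reduced

    // Directional derivative (grad . direction); negative means descent.
    let slope: f64 = grad.iter().zip(direction).map(|(g, d)| g * d).sum();
    if slope >= 0.0 {
        return None;
    }

    let fx = f(x);
    let mut step: f64 = 1.0;
    for _ in 0..50 {
        // Candidate point x + step * direction.
        let trial: Vec<f64> = x
            .iter()
            .zip(direction)
            .map(|(xi, di)| xi + step * di)
            .collect();
        // Armijo condition: accept once the decrease is large enough.
        if f(&trial) <= fx + c * step * slope {
            return Some(step);
        }
        step *= shrink;
    }
    None
}
```

For example, with `f(x) = x0^2 + x1^2`, starting at `[1.0, 1.0]` and moving along the negative gradient `[-2.0, -2.0]`, the full step overshoots and the halved step is accepted. Passing the gradient itself as the direction returns `None`.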

I think this situation is related to this comment on the original MultiLogisticRegression PR regarding divergence. If linfa can address this by using a different line search algorithm with different numerical requirements, great. If changing the default line search algorithm is undesirable, then at least letting users configure which algorithm is used would be greatly appreciated. Most of all, however, I would suggest printing a significantly more helpful error message when this divergence happens. If linfa could catch the error returned by argmin and translate it to something along the lines of "your dataset diverged, please increase alpha or reduce the number of iterations", I imagine it would save a ton of developer troubleshooting hours.
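As a rough sketch of the error translation I have in mind: the types `SolverError` and `FitError` and the function `translate` below are hypothetical stand-ins, not linfa's or argmin's real types, and matching on the message text is only for illustration (a real implementation would presumably match on an error variant instead).

```rust
use std::fmt;

/// Hypothetical stand-in for the raw error returned by the optimizer.
#[derive(Debug)]
struct SolverError(String);

/// Hypothetical user-facing error type.
#[derive(Debug, PartialEq)]
enum FitError {
    /// The line search could not find a descent direction.
    Diverged { hint: &'static str },
    /// Any other solver failure, passed through unchanged.
    Solver(String),
}

impl fmt::Display for FitError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            FitError::Diverged { hint } => {
                write!(f, "optimization failed to make progress: {}", hint)
            }
            FitError::Solver(msg) => write!(f, "solver error: {}", msg),
        }
    }
}

/// Maps the raw optimizer error to a message with actionable advice.
fn translate(err: SolverError) -> FitError {
    if err.0.contains("Search direction must be a descent direction") {
        FitError::Diverged {
            hint: "try increasing alpha or reducing the number of iterations",
        }
    } else {
        FitError::Solver(err.0)
    }
}
```

With something like this, the message from my panic above would come back as `FitError::Diverged` with the hint attached, and any other solver error would pass through untouched.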

Thanks again, and please let me know if there's any other details I should provide here.