PermutationEstimator runs infinitely when gap = 0
pinealan opened this issue · 3 comments
This is a bug report.
Using GPR (GaussianProcessRegressor) with certain datasets leads the PermutationEstimator to run indefinitely (or at least long enough to feel infinite), repeatedly emitting a divide-by-zero warning.
Here's an excerpt of the code where the bug happened:

```python
import sage
from sklearn.gaussian_process import GaussianProcessRegressor

...

gpr = GaussianProcessRegressor(...)
gpr.fit(X_train, y_train)

imputer = sage.MarginalImputer(gpr, X_train)
estimator = sage.PermutationEstimator(imputer, "mse")
importance = estimator(X_, y_, bar=False)
```
and here's the error that keeps being thrown, filling up stderr:

```
sage/permutation_estimator.py:126: RuntimeWarning: divide by zero encountered in double_scalars
  ratio = std / gap
```
Hi Alan, thanks for pointing this out.
The `gap` variable tracks the difference between the largest and smallest SAGE values (see here), and it's used to decide whether the confidence intervals on each value are narrow enough to appear converged. It's very surprising to see `gap = 0` in practice... Even if the true SAGE values were all equal (quite unusual), we might expect to see differences due to sampling noise.
The most likely situation I can think of where this would happen is if the model didn't actually depend on any of the features and all the SAGE values were equal to zero. Is there any chance that's what's happening here?
Or is there anything else you can tell me about your use case that might be unusual? If not, maybe I should just get some toy data and try this out for myself.
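For anyone following along, here is a minimal sketch (not the library's actual code) of why the loop never terminates in this scenario: convergence is judged by comparing the standard error of the estimates to the spread (`gap`) between the largest and smallest SAGE values, and when the model ignores every feature, all the values are equal and the gap is zero.

```python
import numpy as np

# All SAGE values equal (e.g. a model that ignores every feature).
values = np.zeros(5)
gap = values.max() - values.min()   # 0.0 when all values are identical
std = np.float64(1e-3)              # small but nonzero sampling error

# The convergence ratio blows up to inf, so it can never drop below
# any finite threshold and the loop keeps running.
with np.errstate(divide="ignore"):  # suppress the RuntimeWarning here
    ratio = std / gap               # inf
```

This matches the observed behavior: the `std / gap` division at `permutation_estimator.py:126` warns on every iteration while the convergence criterion is never satisfied.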
I'm using sage for an auto-ML-ish application where arbitrary datasets may be fed into the system, so what you suggested is very possible.
If model independence from the input data is indeed what's happening, is it possible for sage to detect it and terminate the loop, perhaps by throwing an exception? I can work on a PR if you can give me some pointers on how you think this might work.
I've updated how we detect convergence, so I believe this won't be a problem anymore with either `PermutationEstimator` or `IteratedEstimator`. The cleanest way I could think of to deal with this is to enforce a (very low) minimum value for the `gap` variable, so that we always detect convergence when `std` is zero. (It's a bit trickier for `KernelEstimator`; I couldn't come up with a simple fix for that case.)
I'm going to close this issue, but let me know if anything else comes up.