bgreenwell/fastshap

Calculating Shapley Values with Data Constraints

JonP-16 opened this issue · 4 comments

I'm trying to calculate Shapley values for a dataset that has constrained regions, i.e., combinations of feature values that don't exist in the data or are impossible in the real world. I've tried to illustrate this phenomenon in the example below, which has two constraints. Right now, my pred_wrapper returns NA for any input row that falls in a constrained region. As a result, the output (shap in the code below) is greatly diminished: only 74 rows are complete, as opposed to the 217 we started with. The effect becomes even more pronounced as the number of simulations (nsim) increases.

I realize this isn't necessarily a bug, given the way Shapley values are calculated, so I'm either looking for advice on how to deal with constraints or submitting a feature request/enhancement.

library(fastshap)  # for fast (approximate) Shapley values
library(ranger)    # for fast random forest algorithm

trn <- gen_friedman(250, seed = 101)
constraint_1_inds = which(trn$x1 < 0.35 & trn$x5 > 0.85)
constraint_2_inds = which(trn$x3 < 0.25 & trn$x9 > 0.75)
constraint_inds   = unique(c(constraint_1_inds, constraint_2_inds))
trn <- trn[-constraint_inds, ]
dim(trn) # lost 33 data points; now at 217 instead of 250

X <- subset(trn, select = -y)  # feature columns only

set.seed(102)
rfo <- ranger(y ~ ., data = trn)

pfun <- function(object, newdata) {
  preds = predict(object, data = newdata)$predictions
  constraint_1_inds = which(newdata$x1 < 0.35 & newdata$x5 > 0.85)
  constraint_2_inds = which(newdata$x3 < 0.25 & newdata$x9 > 0.75)
  constraint_inds   = unique(c(constraint_1_inds, constraint_2_inds))
  preds[constraint_inds] = NA
  return(preds)
}

set.seed(5038)
shap <- explain(rfo, X = X, pred_wrapper = pfun, nsim = 10)

shap[complete.cases(shap), ]  # only 74 complete rows
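To see why so many rows come back incomplete, here is a minimal base-R sketch of a single Monte Carlo draw, assuming fastshap follows the usual permutation scheme of Štrumbelj and Kononenko: for a feature j, a random feature ordering decides which values come from the instance x and which from a random background row w. The helpers hybrid(), make_hybrids() and is_infeasible() are hypothetical illustrations, not fastshap functions. The point is that x and w can each be feasible while the hybrid row built from them is not.

```r
# Hypothetical helper: copy the features listed in from_x from the
# instance x into the background row w.
hybrid <- function(x, w, from_x) {
  h <- w
  h[from_x] <- x[from_x]
  h
}

# Hypothetical sketch of one draw for feature j: order the features at
# random, take j and every feature ordered before it from x, the rest from w.
make_hybrids <- function(x, w, j) {
  perm <- sample.int(length(x))            # random feature ordering
  from_x <- perm[seq_len(match(j, perm))]  # j and the features before it
  with_j <- hybrid(x, w, from_x)           # feature j taken from x
  without_j <- with_j
  without_j[j] <- w[j]                     # feature j taken from w instead
  list(with_j = with_j, without_j = without_j)
}

# The two constraints from this issue (features in order x1, ..., x10).
is_infeasible <- function(v) {
  (v[1] < 0.35 && v[5] > 0.85) || (v[3] < 0.25 && v[9] > 0.75)
}

# Both rows below are feasible on their own...
x <- c(0.10, rep(0.50, 9))                           # x1 low, x5 moderate
w <- c(0.90, 0.50, 0.50, 0.50, 0.90, rep(0.50, 5))   # x1 high, x5 high
is_infeasible(x)                          # FALSE
is_infeasible(w)                          # FALSE
# ...but taking x1 from x and x5 from w lands in a constrained region.
is_infeasible(hybrid(x, w, from_x = 1))   # TRUE
```

If each Shapley estimate averages nsim such prediction differences without dropping NAs, a single infeasible hybrid is enough to make that estimate NA, which would be consistent with larger nsim leaving fewer complete rows.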

Hi @JonP-16, I'm not sure I quite follow. This seems like it would cause issues, since pfun() will potentially return predictions for only a subset of newdata. Would it not be more reasonable to filter newdata first?
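One way to filter newdata first while still returning one prediction per input row is sketched below. predict_feasible() is a hypothetical helper, and the toy predict_fun stands in for a real model. It uses a logical mask rather than negative indices, because `newdata[-integer(0), ]` drops every row when no row is constrained. Note that this only keeps the output aligned with newdata; the infeasible rows still come back as NA.

```r
# Hypothetical helper: call predict_fun only on feasible rows, but always
# return exactly one value per row of newdata (NA for infeasible rows).
predict_feasible <- function(predict_fun, newdata, infeasible) {
  stopifnot(is.logical(infeasible), length(infeasible) == nrow(newdata))
  preds <- rep(NA_real_, nrow(newdata))
  keep <- !infeasible
  if (any(keep)) {
    # drop = FALSE keeps a data frame even when a single row survives
    preds[keep] <- predict_fun(newdata[keep, , drop = FALSE])
  }
  preds
}

nd <- data.frame(x1 = c(0.1, 0.5, 0.2), x5 = c(0.9, 0.9, 0.1))
infeasible <- nd$x1 < 0.35 & nd$x5 > 0.85   # TRUE FALSE FALSE
predict_feasible(function(d) d$x1 + d$x5, nd, infeasible)
# NA 1.4 0.3

# Pitfall with negative indexing: an empty index drops every row.
nrow(nd[-integer(0), ])                     # 0
```

With ranger, predict_fun would be something like `function(d) predict(object, data = d)$predictions` inside the pred_wrapper, with the infeasible mask computed from the two constraints on newdata.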

Hi!

Sorry for the delay on this. Perhaps a better structure to explain the issue would be something like this:

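pfun_1 <- function(object, newdata) {
  preds = rep(NA_real_, nrow(newdata))
  for (i in seq_len(nrow(newdata))) {
    # scalar conditions inside if() call for && and || rather than & and |
    infeasible = (newdata$x1[i] < 0.35 && newdata$x5[i] > 0.85) ||
                 (newdata$x3[i] < 0.25 && newdata$x9[i] > 0.75)
    if (!infeasible) {
      preds[i] = predict(object, data = newdata[i, ])$predictions
    }
  }
  # return only after every row has been visited; a return() inside the
  # loop would exit after the first row and leave the rest NA
  return(preds)
}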

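pfun_2 <- function(object, newdata) {
  # logical mask: TRUE for rows that fall in a constrained region
  infeasible = (newdata$x1 < 0.35 & newdata$x5 > 0.85) |
               (newdata$x3 < 0.25 & newdata$x9 > 0.75)
  # logical indexing keeps every row when nothing is infeasible, whereas
  # newdata[-integer(0), ] would drop every row
  newdata = newdata[!infeasible, , drop = FALSE]

  # whenever rows are dropped, this returns fewer predictions than the
  # caller passed in, so they no longer line up row-for-row
  preds = predict(object, data = newdata)$predictions
  return(preds)
}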

I think pfun_2 is what you're suggesting with your comment about filtering newdata first. When I try to run it, though, I get:

> set.seed(5038)
> shap <- explain(rfo, X = X, pred_wrapper = pfun, nsim = 10)
Error: sample_fraction too small, no observations sampled. Ranger will EXIT now.
Error in predict.ranger.forest(forest, data, predict.all, num.trees, type,  : 
  User interrupt or internal error.
In addition: There were 44 warnings (use warnings() to see them)
Called from: predict.ranger.forest(forest, data, predict.all, num.trees, type, 
    se.method, seed, num.threads, verbose, object$inbag.counts, 
    ...)
Browse[1]> 

When I run pfun_1, the explain function runs without error but returns only a single complete result.

> set.seed(5038)
> shap <- explain(rfo, X = X, pred_wrapper = pfun, nsim = 10)
> shap[complete.cases(shap), ]#
# A tibble: 1 × 10
      x1    x2     x3    x4     x5     x6      x7      x8     x9    x10
   <dbl> <dbl>  <dbl> <dbl>  <dbl>  <dbl>   <dbl>   <dbl>  <dbl>  <dbl>
1 -0.107 -1.60 -0.363  2.10 -0.576 -0.110 -0.0462 -0.0634 0.0209 -0.108

I guess I'm asking whether there is a way to use this package to calculate Shapley values when certain combinations of covariates cannot coexist (i.e., constraints on the data). For example, suppose I'm estimating home value and two of my input variables are 'Type of Home' and 'Number of Floors.' The dataset may contain 'Ranch' as a type of home and '2' as a number of floors, but I cannot make a prediction for a two-story ranch-style house, because that combination of features does not exist in the real world and wouldn't appear in the training data. How would I reflect that constraint within this package's framework?

Thank you, and apologies for the long-winded reply.
Jon

Hi @JonP-16, I see what you're saying. This is one of the unfortunate drawbacks of methods that rely on permutations or background data sets. The only feasible solution I can see (and I'd have to check how theoretically justified it is) would be to use a more specific background data set for the X argument; it is from this data frame that explain() draws random instances. For example, if you were trying to explain predictions for ranch-style homes, you could pass a suitably sized background data set to X consisting only of ranch-style homes. That way you'd be less likely to end up with unrealistic observations (i.e., combinations of features). It's not a perfect solution, but it seems closer to what you're trying to accomplish.
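A minimal sketch of building such a background set is shown below. background_for() is a hypothetical helper, and the homes data frame (home_type, floors, sqft) is made-up illustration data. It keeps only the background rows that match the instance being explained on the chosen stratum columns.

```r
# Hypothetical helper: keep only background rows that match the instance
# on every column in strata_cols (e.g. home type).
background_for <- function(X, instance, strata_cols) {
  matches <- rep(TRUE, nrow(X))
  for (col in strata_cols) {
    matches <- matches & (X[[col]] == instance[[col]])
  }
  X[matches, , drop = FALSE]
}

# Made-up example data
homes <- data.frame(
  home_type = c("Ranch", "Ranch", "Colonial", "Colonial"),
  floors    = c(1, 1, 2, 3),
  sqft      = c(1200, 1500, 2200, 2800)
)
bg <- background_for(homes, homes[1, ], strata_cols = "home_type")
nrow(bg)   # 2: only the ranch-style homes
```

If the instance and every background row share the same home type, any hybrid takes its floor count either from the instance or from a real ranch-style home, so a two-story ranch cannot appear. Constraints that don't involve the stratum columns can still be violated, and a smaller background set makes the estimates noisier. The result would then be passed as X, e.g. `explain(rfo, X = bg, newdata = ..., pred_wrapper = pfun, nsim = 10)`, assuming your fastshap version's explain() accepts a newdata argument for the rows being explained.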

Yeah, your point about the permutation-based approach is what I suspected the issue would be. I imagine it would be a decent-sized effort to add checks for the constraints.

Regarding your point about a specific set for X, that is what I had done back at the opening of this issue:

trn <- gen_friedman(250, seed = 101)
constraint_1_inds = which(trn$x1 < 0.35 & trn$x5 > 0.85)
constraint_2_inds = which(trn$x3 < 0.25 & trn$x9 > 0.75)
constraint_inds   = unique(c(constraint_1_inds, constraint_2_inds))
trn <- trn[-constraint_inds, ]
dim(trn) # lost 33 data points; now at 217 instead of 250

X <- subset(trn, select = -y)  # feature columns only

The constrained regions were removed from the X dataset. But I see your point that this might be remedied by simply having more data. Thanks for following up!