bgreenwell/fastshap

`explain` fails when variables are removed as part of `recipes::recipe` preprocessing

mevers opened this issue · 2 comments

Reprex:

library(tidyverse)
library(tidymodels)
library(fastshap)

# Sample data: `mtcars` with 50% of the entries in `qsec` replaced with `NA`s
set.seed(2022)
data <- mtcars %>%
    select(-c(vs, am)) %>%
    mutate(qsec = replace(
        qsec, sample.int(nrow(mtcars), size = nrow(mtcars) / 2), NA_real_))

recipe <- recipe(mpg ~ ., data = data) %>%
    # Remove variables with more than 30% of missing data; this will be `qsec`
    step_filter_missing(all_numeric_predictors(), threshold = 0.3)

# We can confirm that `qsec` has been removed after pre-processing
# recipe %>% prep() %>% bake(new_data = NULL)

# Define & fit model to data
spec <- linear_reg() %>% set_engine("glm")
fitted_model <- workflow() %>%
    add_recipe(recipe) %>%
    add_model(spec) %>%
    fit(data = data)

# FastSHAP
fshap <- explain(
    fitted_model,
    X = recipe %>% prep() %>% bake(new_data = NULL) %>% select(-c(mpg)),
    pred_wrapper = function(model, newdata) predict(model, newdata)$.pred,
    shap_only = FALSE)

This throws an error:

Error in { :
task 1 failed - "The following required columns are missing: 'qsec'."

This is because `fitted_model` retains a reference to `qsec`, even though the variable was removed during preprocessing in `recipe`.
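One way to see the mismatch is to compare two sets of names. The first is the predictor columns that the workflow's blueprint checks when you call `predict()`. The second is the terms in the underlying `glm` fit. This sketch assumes a workflows release that provides `extract_mold()` and `extract_fit_engine()`. The `$blueprint$ptypes$predictors` path reflects my understanding of hardhat's blueprint object, not a documented contract.

```r
library(tidymodels)  # attaches dplyr, recipes, parsnip, workflows

# Setup: same data, recipe and workflow as the reprex above
set.seed(2022)
data <- mtcars %>%
    select(-c(vs, am)) %>%
    mutate(qsec = replace(
        qsec, sample.int(nrow(mtcars), size = nrow(mtcars) / 2), NA_real_))
rec <- recipe(mpg ~ ., data = data) %>%
    step_filter_missing(all_numeric_predictors(), threshold = 0.3)
fitted_model <- workflow() %>%
    add_recipe(rec) %>%
    add_model(linear_reg() %>% set_engine("glm")) %>%
    fit(data = data)

# Columns that predict() on the workflow validates: the *raw* recipe inputs
required <- names(extract_mold(fitted_model)$blueprint$ptypes$predictors)
# Terms the fitted glm actually uses, after the recipe dropped qsec
used <- names(coef(extract_fit_engine(fitted_model)))

"qsec" %in% required  # the workflow still asks for qsec at predict time
"qsec" %in% used      # but the model itself never saw it
```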

Question: What is the canonical way to supply `X` here? I could reference `data` directly:

fshap <- explain(
    fitted_model,
    X = data %>% select(-c(mpg)),
    pred_wrapper = function(model, newdata) predict(model, newdata)$.pred,
    shap_only = FALSE)

but (1) this doesn't seem very tidymodels-canonical, and (2) it includes `qsec` in the SHAP analysis, which it shouldn't. I could use the `feature_names` argument to exclude `qsec`, but that seems unnecessarily complicated.
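For reference, the `feature_names` workaround could look like the sketch below. The workflow is explained on the raw predictors, and SHAP values are requested only for the columns that survive the recipe. The sketch assumes fastshap's `explain()` accepts `feature_names` together with `X` and `pred_wrapper`, as in the calls above. The hard-coded `"qsec"` exclusion repeats logic the recipe already contains, and that repetition is what makes this approach awkward.

```r
library(tidymodels)  # attaches dplyr, recipes, parsnip, workflows
library(fastshap)

# Setup: same data, recipe and workflow as the reprex above
set.seed(2022)
data <- mtcars %>%
    select(-c(vs, am)) %>%
    mutate(qsec = replace(
        qsec, sample.int(nrow(mtcars), size = nrow(mtcars) / 2), NA_real_))
rec <- recipe(mpg ~ ., data = data) %>%
    step_filter_missing(all_numeric_predictors(), threshold = 0.3)
fitted_model <- workflow() %>%
    add_recipe(rec) %>%
    add_model(linear_reg() %>% set_engine("glm")) %>%
    fit(data = data)

# Raw predictors (qsec included); the workflow's recipe drops qsec internally
X_raw <- data %>% select(-mpg) %>% as.data.frame()

shap_raw <- explain(
    fitted_model,
    X = X_raw,
    # Manually skip the column the recipe removes
    feature_names = setdiff(names(X_raw), "qsec"),
    pred_wrapper = function(object, newdata) predict(object, new_data = newdata)$.pred)

colnames(shap_raw)
```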

What is the fastshap-way to provide X via a recipe?
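Here is a sketch of one possible answer, assuming the goal is SHAP values on the preprocessed features. Explain the underlying parsnip fit instead of the workflow. Build `X` with the workflow's own trained recipe, so no column names have to be repeated by hand. `extract_fit_parsnip()` and `extract_recipe()` are workflows helpers. The conversion to a plain `data.frame` is only a precaution, because I'm not certain how every fastshap version handles tibbles.

```r
library(tidymodels)  # attaches dplyr, recipes, parsnip, workflows
library(fastshap)

# Setup: same data, recipe and workflow as the reprex above
set.seed(2022)
data <- mtcars %>%
    select(-c(vs, am)) %>%
    mutate(qsec = replace(
        qsec, sample.int(nrow(mtcars), size = nrow(mtcars) / 2), NA_real_))
rec <- recipe(mpg ~ ., data = data) %>%
    step_filter_missing(all_numeric_predictors(), threshold = 0.3)
fitted_model <- workflow() %>%
    add_recipe(rec) %>%
    add_model(linear_reg() %>% set_engine("glm")) %>%
    fit(data = data)

# Preprocessed predictors, baked with the recipe *as trained inside the workflow*
X_baked <- extract_recipe(fitted_model) %>%
    bake(new_data = data) %>%
    select(-mpg) %>%
    as.data.frame()

# The parsnip model_fit expects already-preprocessed predictors, so qsec is not needed
parsnip_fit <- extract_fit_parsnip(fitted_model)

shap_baked <- explain(
    parsnip_fit,
    X = X_baked,
    pred_wrapper = function(object, newdata) predict(object, new_data = newdata)$.pred)
```

Passing `shap_only = FALSE`, as in the original call, should also return the feature values and baseline. With this approach the attributions refer to the recipe's output columns. That distinction matters if the recipe creates new features, such as dummy variables, rather than only dropping columns.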

Hi @mevers, I might be missing something here, but the failure seems to come from your prediction wrapper:

#
# Test out prediction wrapper
#
X <- recipe %>% prep() %>% bake(new_data = NULL) %>% select(-c(mpg))
predict(fitted_model, new_data = X)
# Error in `validate_column_names()`:
#   ! The following required columns are missing: 'qsec'.
# Run `rlang::last_error()` to see where the error occurred.

There's no `qsec` term in the underlying fit, so maybe cross-post this question to the hardhat and/or workflows repos?
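The underlying fit has no `qsec` term, so two pairings should give the same predictions: the workflow applied to raw data, and its parsnip fit applied to baked data. If so, either pairing is a consistent input to `explain()`: raw `X` with the workflow, or baked `X` with `extract_fit_parsnip()`. Here is a minimal check that reuses the reprex setup.

```r
library(tidymodels)  # attaches dplyr, recipes, parsnip, workflows

# Setup: same data, recipe and workflow as the reprex above
set.seed(2022)
data <- mtcars %>%
    select(-c(vs, am)) %>%
    mutate(qsec = replace(
        qsec, sample.int(nrow(mtcars), size = nrow(mtcars) / 2), NA_real_))
rec <- recipe(mpg ~ ., data = data) %>%
    step_filter_missing(all_numeric_predictors(), threshold = 0.3)
fitted_model <- workflow() %>%
    add_recipe(rec) %>%
    add_model(linear_reg() %>% set_engine("glm")) %>%
    fit(data = data)

X_raw <- data %>% select(-mpg)
X_baked <- extract_recipe(fitted_model) %>% bake(new_data = data) %>% select(-mpg)

# Workflow + raw inputs: the recipe is applied inside predict()
p_workflow <- predict(fitted_model, new_data = X_raw)$.pred
# Parsnip fit + baked inputs: preprocessing was done up front
p_parsnip <- predict(extract_fit_parsnip(fitted_model), new_data = X_baked)$.pred

isTRUE(all.equal(p_workflow, p_parsnip))
```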