`explain` fails when variables are removed as part of `recipes::recipe` preprocessing
mevers opened this issue · 2 comments
Reprex:
```r
library(tidyverse)
library(tidymodels)
library(fastshap)

# Sample data: `mtcars` with 50% of the entries in `qsec` replaced with `NA`s
set.seed(2022)
data <- mtcars %>%
  select(-c(vs, am)) %>%
  mutate(qsec = replace(
    qsec, sample.int(nrow(mtcars), size = nrow(mtcars) / 2), NA_real_))

recipe <- recipe(mpg ~ ., data = data) %>%
  # Remove variables with more than 30% of missing data; this will be `qsec`
  step_filter_missing(all_numeric_predictors(), threshold = 0.3)
# We can confirm that `qsec` has been removed after pre-processing
# recipe %>% prep() %>% bake(new_data = NULL)

# Define & fit model to data
spec <- linear_reg() %>% set_engine("glm")
fitted_model <- workflow() %>%
  add_recipe(recipe) %>%
  add_model(spec) %>%
  fit(data = data)

# FastSHAP
fshap <- explain(
  fitted_model,
  X = recipe %>% prep() %>% bake(new_data = NULL) %>% select(-c(mpg)),
  pred_wrapper = function(model, newdata) predict(model, newdata)$.pred,
  shap_only = FALSE)
```
This throws an error:

```
Error in { :
  task 1 failed - "The following required columns are missing: 'qsec'."
```
This is because `fitted_model` retains a reference to `qsec`, even though the variable was removed during pre-processing in `recipe`.
Question: What is the canonical way to supply `X` here? I could reference `data` directly:
```r
fshap <- explain(
  fitted_model,
  X = data %>% select(-c(mpg)),
  pred_wrapper = function(model, newdata) predict(model, newdata)$.pred,
  shap_only = FALSE)
```
but (1) this doesn't seem to be very tidymodels-canonical, and (2) this then includes `qsec` in the SHAP analysis (which it shouldn't). A fix for the latter would be to use the `feature_names` argument to exclude `qsec`, but this seems unnecessarily complicated.
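For reference, the `feature_names` workaround can be sketched as follows. This is a minimal sketch with several assumptions. It assumes fastshap's `explain()` computes SHAP values only for the columns named in `feature_names`. It also assumes that `explain()` calls `pred_wrapper(object, newdata)` with data frames that have the same columns as `X`. The workflow receives the raw predictors, including `qsec`, so its `predict()` method can still run the recipe. SHAP values are requested only for the predictors that survive `step_filter_missing()`. `X_raw`, `kept` and `pred_fun` are illustrative names. `fastshap::explain` is namespaced because dplyr also exports an `explain()` generic.

```r
library(tidymodels)
library(fastshap)

# Same sample data, recipe and workflow as in the reprex
set.seed(2022)
data <- mtcars %>%
  select(-c(vs, am)) %>%
  mutate(qsec = replace(
    qsec, sample.int(nrow(mtcars), size = nrow(mtcars) / 2), NA_real_))
rec <- recipe(mpg ~ ., data = data) %>%
  step_filter_missing(all_numeric_predictors(), threshold = 0.3)
fitted_model <- workflow() %>%
  add_recipe(rec) %>%
  add_model(linear_reg() %>% set_engine("glm")) %>%
  fit(data = data)

# Raw predictors, including `qsec`: the input the workflow's predict() expects
X_raw <- data %>% select(-mpg) %>% as.data.frame()

# Predictors that remain after step_filter_missing(), i.e. everything except `qsec`
kept <- rec %>% prep() %>% bake(new_data = NULL) %>% select(-mpg) %>% names()

# Assumption: fastshap passes `newdata` with the same columns as `X_raw`
pred_fun <- function(object, newdata) predict(object, new_data = newdata)$.pred

# Assumption: `feature_names` restricts the SHAP computation to the retained predictors
shap <- fastshap::explain(
  fitted_model,
  feature_names = kept,
  X = X_raw,
  pred_wrapper = pred_fun,
  nsim = 10)
```

Here `qsec` is still present in the data passed to the model. It receives no SHAP value because it is not listed in `feature_names`.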
What is the `fastshap`-way to provide `X` via a `recipe`?
Hi @mevers, I might be missing something here, but it seems to fail because your prediction wrapper fails:
```r
#
# Test out prediction wrapper
#
X <- recipe %>% prep() %>% bake(new_data = NULL) %>% select(-c(mpg))
predict(fitted_model, new_data = X)
# Error in `validate_column_names()`:
# ! The following required columns are missing: 'qsec'.
# Run `rlang::last_error()` to see where the error occurred.
```
There's no `qsec` term in the underlying fit. The missing-column error comes from the preprocessing check in `validate_column_names()`, so maybe cross-list this question with the hardhat and/or workflows repos?
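The reply's point can be illustrated by explaining the model inside the workflow rather than the workflow itself. This is a minimal sketch, not a fix confirmed in this thread. It assumes that `extract_fit_parsnip()` (from workflows) returns the parsnip fit trained on the baked predictors. It also assumes that `predict()` on that fit accepts baked data directly, so the raw-column check that rejects `X` above is bypassed. SHAP values are then computed with respect to the baked predictors, so `qsec` never enters the analysis. `parsnip_fit` and `pred_fun` are illustrative names.

```r
library(tidymodels)
library(fastshap)

# Same sample data, recipe and workflow as in the reprex
set.seed(2022)
data <- mtcars %>%
  select(-c(vs, am)) %>%
  mutate(qsec = replace(
    qsec, sample.int(nrow(mtcars), size = nrow(mtcars) / 2), NA_real_))
rec <- recipe(mpg ~ ., data = data) %>%
  step_filter_missing(all_numeric_predictors(), threshold = 0.3)
fitted_model <- workflow() %>%
  add_recipe(rec) %>%
  add_model(linear_reg() %>% set_engine("glm")) %>%
  fit(data = data)

# Baked predictors (no `qsec`), as in the original attempt
X <- rec %>% prep() %>% bake(new_data = NULL) %>% select(-mpg) %>% as.data.frame()

# Assumption: the parsnip model inside the workflow was trained on the baked
# predictors, so it can be fed `X` without re-running the recipe
parsnip_fit <- extract_fit_parsnip(fitted_model)
pred_fun <- function(object, newdata) predict(object, new_data = newdata)$.pred

# Sanity check: this wrapper runs on the baked data
head(pred_fun(parsnip_fit, X))

# `fastshap::explain` is namespaced because dplyr also exports `explain()`
shap <- fastshap::explain(parsnip_fit, X = X, pred_wrapper = pred_fun, nsim = 10)
```

One consequence of this design is that the explanation describes the model's behaviour on the preprocessed features rather than on the raw inputs. Here the two coincide, because `step_filter_missing()` only drops a column.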