ModelOriented/DALEXtra

Variable importance plot lists only original variables, not one-hot-encoded variables, when permutation importance calculated with `model_parts()` and `explain_tidymodels()` functions?

kransom14 opened this issue · 1 comments

I am using `explain_tidymodels()` to compute variable importance. My workflow includes a recipe with a `step_dummy()` step. I'm trying to understand why the variable importance calculated with `model_parts()` is reported for the original variables rather than for the one-hot-encoded variables when this step is included. Is the permutation importance aggregated at some point for each group of one-hot-encoded variables that belong together? I didn't see this explained in the documentation. Reprex below. Please advise. Thank you.

library("DALEXtra")
library("tidymodels")
library("recipes")

# example with no dummy variables
data <- titanic_imputed

data$survived <- as.factor(data$survived)

rec <- recipe(survived ~ ., data = data) %>%
  step_normalize(fare)

model <- decision_tree(tree_depth = 25) %>%
  set_engine("rpart") %>%
  set_mode("classification")

wflow <- workflow() %>%
  add_recipe(rec) %>%
  add_model(model)

model_fitted <- wflow %>%
  fit(data = data)

# as.numeric() on a factor returns the level codes (1/2), so subtract 1 to recover 0/1
explainTest <- explain_tidymodels(model_fitted, data = data, y = as.numeric(data$survived) - 1)
explainModelParts <- model_parts(explainTest, type = "variable_importance")
plot(explainModelParts)


# example with dummy variables
data <- titanic_imputed

data$survived <- as.factor(data$survived)

rec <- recipe(survived ~ ., data = data) %>%
  step_dummy(gender, class, embarked, one_hot = TRUE) %>% # one hot encode the categorical variables
  step_normalize(fare)

model <- decision_tree(tree_depth = 25) %>%
  set_engine("rpart") %>%
  set_mode("classification")

wflow <- workflow() %>%
  add_recipe(rec) %>%
  add_model(model)

model_fitted <- wflow %>%
  fit(data = data)

# as.numeric() on a factor returns the level codes (1/2), so subtract 1 to recover 0/1
explainModel <- explain_tidymodels(model_fitted, data = data, y = as.numeric(data$survived) - 1)

vipData <- model_parts(explainModel, type = "variable_importance")
plot(vipData) # this plot shows original variable names and does not include the one hot encoded variables

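The importance is calculated on the features before preprocessing. The explainer wraps the whole workflow, so `model_parts()` permutes the columns of the `data` passed to `explain_tidymodels()`, and the recipe (including `step_dummy()`) is applied only afterwards, at prediction time. With permutation importance in particular, permuting individual dummy columns would be nonsensical, because two dummies from the same variable could suddenly both be 1 within a single observation.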
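
To illustrate that last point, here is a minimal base R sketch. It is not DALEX's or DALEXtra's actual implementation. The toy `class` column and the `one_hot()` helper are hypothetical stand-ins for the Titanic data and for `step_dummy(one_hot = TRUE)`. It compares permuting a single dummy column with permuting the original factor before encoding.

```r
# Toy data: one categorical predictor with three levels (hypothetical example).
set.seed(1)
raw <- data.frame(
  class = factor(sample(c("1st", "2nd", "3rd"), 200, replace = TRUE))
)

# Hypothetical helper: one-hot encode with base R.
# "- 1" drops the intercept so that every level gets its own column.
one_hot <- function(df) model.matrix(~ class - 1, data = df)

encoded <- one_hot(raw)

# Before any permutation, each row has exactly one active dummy.
valid_before <- all(rowSums(encoded) == 1)

# Permute a single dummy column, as a dummy-level permutation importance would.
permuted_dummy <- encoded
permuted_dummy[, "class1st"] <- sample(permuted_dummy[, "class1st"])

# Rows can now have zero or two active dummies, which no real observation has.
invalid_rows <- sum(rowSums(permuted_dummy) != 1)

# Permute the original factor and encode afterwards, which mirrors what
# happens when the recipe lives inside the explained workflow.
permuted_raw <- raw
permuted_raw$class <- sample(permuted_raw$class)
valid_after <- all(rowSums(one_hot(permuted_raw)) == 1)

cat("valid before permutation:", valid_before, "\n")
cat("invalid rows after permuting one dummy:", invalid_rows, "\n")
cat("valid after permuting the raw factor:", valid_after, "\n")
```

Permuting the raw factor keeps every encoded row valid, which is consistent with the importance being reported per original variable.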