How to get SHAP interactions for LightGBM?
pecto2020 opened this issue · 6 comments
Your package is great and very easy to use within the tidymodels framework. I was wondering whether it is possible to calculate SHAP interaction values for LightGBM. I would like to use them in sv_dependence() instead of the heuristic (which is an amazing solution, though). I've seen that this is possible for XGBoost via the interactions = TRUE argument of shapviz(). Is there any solution or workaround for LightGBM?
Unfortunately, LightGBM's built-in TreeSHAP does not provide interaction values. But you could compute them with the {treeshap} package.
I assume adding this to LightGBM itself would involve hacking its C++ code, which I can't help with :/
Oh, hmm...
I tried to use treeshap but got an error: Error in S_inter[, v, color_var]: subscript out of bounds
Here's the code:
library(tidymodels)
library(shapviz)
library(treeshap)
library(lightgbm)
library(datasets)
library(bonsai)
# Use the fifa20 dataset
fifa20 <- fifa20$data %>%
select(-work_rate) %>%
bind_cols(data.frame(target = fifa20$target))
# Split the data
set.seed(123)
split <- initial_split(fifa20)
train <- training(split)
test <- testing(split)
# Recipe
rec <- recipe(target ~ ., data = train)
# Model specification
boost_spec <- boost_tree(
mode = "regression",
trees = 200,
tree_depth = 6
) %>%
set_engine("lightgbm") %>%
set_mode("regression")
# Workflow
workflow <- workflow() %>%
add_recipe(rec) %>%
add_model(boost_spec)
# Fit the model
boost_model <- workflow %>% fit(data = train)
# Create shap object with shapviz
shap_lgbm <- shapviz(extract_fit_engine(boost_model),
as.matrix(test %>% select(-target)),
test %>% select(-target))
# Create unified model representation
unified_lgbm <- treeshap::lightgbm.unify(extract_fit_engine(boost_model), train %>% select(-target))
# Derive interactions
interactions_lgbm <- treeshap::treeshap(unified_lgbm, test %>% select(-target), interactions = T, verbose = 0)
# Plot dependences
shap_lgbm$S_inter <- interactions_lgbm$interactions
sv_dependence(shap_lgbm, v = "overall", interactions = T, color_var = "height_cm")
#> Error in S_inter[, v, color_var]: subscript out of bounds
dim(shap_lgbm$S_inter)
#> [1] 54 54 4570
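Interaction values cannot simply be assigned to a shapviz object: treeshap returns them as a p x p x n array (54 x 54 x 4570 here), while shapviz expects n x p x p, hence the subscript error. So this line is wrong: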
shap_lgbm$S_inter <- interactions_lgbm$interactions
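The error is easy to reproduce with a toy array in base R, assuming the feature x feature x observation layout that treeshap produces (as the dim() output in the reprex shows). The sizes and the third feature name below are made up, and the aperm() step is for illustration only; the fix suggested in this thread is to hand the treeshap object to shapviz() directly.

```r
# Toy stand-in for treeshap output, laid out p x p x n
# (feature x feature x observation). Sizes and names are hypothetical.
p <- 3
n <- 5
feat <- c("overall", "height_cm", "age")
set.seed(1)
inter_treeshap <- array(rnorm(p * p * n), dim = c(p, p, n),
                        dimnames = list(feat, feat, NULL))

# shapviz-style indexing S_inter[, v, color_var] assumes n x p x p,
# so the third index "height_cm" hits the unnamed observation axis.
res <- tryCatch(inter_treeshap[, "overall", "height_cm"],
                error = function(e) conditionMessage(e))
print(res)

# Moving the observation axis to the front yields n x p x p.
inter_nxpxp <- aperm(inter_treeshap, c(3, 1, 2))
print(dim(inter_nxpxp))
# The same lookup now returns one interaction value per observation.
print(inter_nxpxp[, "overall", "height_cm"])
```

aperm(x, c(3, 1, 2)) makes the old third axis the new first one, so inter_nxpxp[i, j, k] equals inter_treeshap[j, k, i].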
Passing the treeshap result directly to shapviz() works, but I would decompose fewer rows and divide the response by 1e6 (or so):
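# shapviz() accepts the treeshap output directly
shap_lgbm <- shapviz(interactions_lgbm)
# Names of the four most important features (kind = "no" returns values, no plot)
top4 <- names(head(sv_importance(shap_lgbm, kind = "no"), 4))
# Interaction plot for the first 1000 rows and the top features
sv_interaction(shap_lgbm[1:1000, top4])
sv_dependence(shap_lgbm, v = "overall", color_var = top4, interactions = TRUE)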
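When combining treeshap output with shapviz, a cheap sanity check is the additivity of SHAP interaction values: for each observation, summing a feature's interaction values over the last dimension should reproduce that feature's SHAP value, up to floating-point error. The helper check_interaction_sums() below is hypothetical (not part of shapviz or treeshap), and the toy array is synthetic, built so that the property holds.

```r
# Hypothetical helper: checks additivity of SHAP interaction values.
# S:       n x p matrix of SHAP values
# S_inter: n x p x p array of SHAP interaction values
check_interaction_sums <- function(S, S_inter, tol = 1e-6) {
  stopifnot(length(dim(S_inter)) == 3L,
            identical(dim(S_inter)[1:2], dim(S)))
  # Sum over the last dimension, giving an n x p matrix
  row_sums <- apply(S_inter, c(1, 2), sum)
  # Compare with relative tolerance, ignoring dimnames
  isTRUE(all.equal(unname(row_sums), unname(S), tolerance = tol))
}

# Synthetic example: symmetric p x p interaction matrix per observation,
# with S derived from it so that additivity holds by construction.
set.seed(2)
n <- 4
p <- 3
S_inter <- array(0, dim = c(n, p, p))
for (i in seq_len(n)) {
  m <- matrix(rnorm(p * p), p, p)
  S_inter[i, , ] <- (m + t(m)) / 2
}
S <- apply(S_inter, c(1, 2), sum)

print(check_interaction_sums(S, S_inter))
print(check_interaction_sums(S + 0.1, S_inter))
```

With the shapviz object from this thread, the call would presumably be check_interaction_sums(get_shap_values(shap_lgbm), get_shap_interactions(shap_lgbm)), since those shapviz accessors return the n x p matrix and the n x p x p array, respectively.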