bgreenwell/fastshap

Example of Shap "summary" plot?

jcpsantiago opened this issue · 12 comments

I really like the summary plot created by the shap python package. Do you have an example snippet for creating it starting from fastshap::explain e.g. for XGBoost?

For reference, this is what I mean: the summary plot shown at https://github.com/slundberg/shap.

This might be the answer I'm looking for:

# 1. Add \code{type = "beeswarm"}.

@jcpsantiago I’ll make sure this gets added in the next update.

Here's a snippet from the code I use to create a similar chart. Could be useful.

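library(dplyr)
library(tidyr)
library(ggplot2)

# Assumed inputs:
#   shap           - data frame of SHAP values, one column per feature
#                    (e.g. from fastshap::explain())
#   baked_training - the matching feature values plus the is_fraud outcome

# Names of the 10 features with the largest mean absolute SHAP value
top_10_feat <- shap %>%
  pivot_longer(everything()) %>%
  group_by(name) %>%
  summarise(v = mean(abs(value))) %>%
  arrange(desc(v)) %>%
  head(10) %>%
  pull(name)

# Long format: one row per observation-feature pair. Rows line up because
# both tables are pivoted with the same column order.
df <- shap %>%
  rename_with(~ paste0(.x, "_shap")) %>%
  pivot_longer(everything(), names_to = "shap_keys", values_to = "shap_values") %>%
  bind_cols(baked_training %>%
              select(is_fraud, all_of(names(shap))) %>%
              pivot_longer(2:ncol(.))
            ) %>%
  select(name, value, shap_values, is_fraud) %>%
  filter(!is.na(value) & !is.na(shap_values)) %>%
  group_by(name) %>%
  # Min-max scale each feature to [0, 1] for the colour bar; unlike
  # value / max(value), this also handles negative feature values
  mutate(scaled = scales::rescale(value)) %>%
  ungroup() %>%
  # Order features by their largest absolute SHAP value
  mutate(name = forcats::fct_reorder(
    as.factor(name),
    .x = shap_values, .fun = function(x) max(abs(x))
  ))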

df %>%
  filter(name %in% top_10_feat) %>%
  ggplot(aes(x = name, y = shap_values, color = scaled)) +
  ggforce::geom_sina(alpha = 0.3) +
  coord_flip() +
  scale_colour_viridis_c() +
  hrbrthemes::theme_ipsum_ps() +
  guides(color = guide_colourbar(
    title = "Scaled feature value",
    barwidth = 20, barheight = 0.5, title.position = "top"
  )) +
  theme(legend.position = "bottom") +
  labs(title = "SHAP values for top 10 features", x = "", y = "")
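For anyone who wants the same logic without the tidyverse, here is a minimal base R sketch of the data preparation a beeswarm plot needs: rank features by mean absolute SHAP value, min-max scale each feature's values for the colour scale, and stack the result into long format. The helper names top_features(), scale_01() and beeswarm_data() are hypothetical (not part of fastshap), and the toy inputs are made up purely for illustration.

```r
# Hypothetical helpers (not part of fastshap) showing the data preparation
# behind a SHAP beeswarm plot in base R.

# Names of the n features with the largest mean absolute SHAP value.
top_features <- function(shap, n = 10) {
  importance <- colMeans(abs(as.matrix(shap)))
  names(sort(importance, decreasing = TRUE))[seq_len(min(n, length(importance)))]
}

# Min-max scale a numeric vector to [0, 1]; a constant vector maps to 0.5.
scale_01 <- function(x) {
  rng <- range(x, na.rm = TRUE)
  if (rng[1] == rng[2]) return(rep(0.5, length(x)))
  (x - rng[1]) / (rng[2] - rng[1])
}

# Long format: one row per observation-feature pair for the top n features,
# ordered from most to least important.
beeswarm_data <- function(shap, X, n = 10) {
  shap <- as.data.frame(shap)
  X <- as.data.frame(X)
  feats <- top_features(shap, n)
  do.call(rbind, lapply(feats, function(f) {
    data.frame(feature = f, shap = shap[[f]], scaled = scale_01(X[[f]]),
               stringsAsFactors = FALSE)
  }))
}

# Toy inputs: made-up SHAP values and feature values for three features.
shap_toy <- data.frame(a = c(0.1, -0.2, 0.3), b = c(1, -1, 0.5), c = c(0, 0.05, -0.05))
X_toy <- data.frame(a = c(1, 2, 3), b = c(10, 20, 30), c = c(5, 5, 5))
bd <- beeswarm_data(shap_toy, X_toy, n = 2)
```

The resulting data frame can go straight into ggplot2, with feature on one axis, shap on the other and scaled mapped to colour, e.g. via ggforce::geom_sina().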

Is there an update on this yet? It would be really great to have this kind of plot in fastshap :)

@jotech I've been using the snippet I shared ☝️ in our model cards for the weekly deployments. It's not a single function, but it works :)

Another option is to use reticulate to call the Python shap plotting functions directly. Also, judging from issues on the shap repository, matplotlib 3.2.2 seems to be necessary.

Here is a minimal example, taken from the docs:

# Load required packages
library(fastshap)
library(xgboost)
# Load the Boston housing data (requires the pdp package)
# install.packages("pdp")

data(boston, package = "pdp")
X <- data.matrix(subset(boston, select = -cmedv))  # matrix of feature values

# Fit a gradient boosted regression tree ensemble; hyperparameters were tuned 
# using `autoxgb::autoxgb()`
set.seed(859)  # for reproducibility
bst <- xgboost(X, label = boston$cmedv, nrounds = 338, max_depth = 3, eta = 0.1,
               verbose = 0)

# Compute exact explanations for all rows
ex <- explain(bst, exact = TRUE, X = X)

Next, use reticulate to call shap's plotting functions:

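library(reticulate)
shap <- import("shap")

# Column names are not carried over when the R matrix is converted to a
# NumPy array, so pass them explicitly to get labelled features
shap$dependence_plot("rank(1)", data.matrix(ex), X)
shap$summary_plot(data.matrix(ex), X, feature_names = colnames(X))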

Naming the feature directly threw an error for me, so the rank(1) form was necessary; rank(2) and rank(3) also work. Rendering the plot repeatedly produced buggy visualizations, at least in my experience.
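If you need to know which column a rank(k) string points at, a small base R helper can compute it, assuming shap's convention that ranks are ordered by mean absolute SHAP value and start at rank(0) for the most important feature. rank_to_feature() is a hypothetical helper, not part of fastshap or shap, and the toy matrix is made up.

```r
# Hypothetical helper: the feature name that shap's "rank(k)" string refers to,
# assuming zero-based ranks ordered by mean absolute SHAP value.
rank_to_feature <- function(shap_values, k) {
  importance <- colMeans(abs(as.matrix(shap_values)))
  names(importance)[order(importance, decreasing = TRUE)][k + 1]
}

# Toy SHAP matrix with named columns (values are made up).
shap_toy <- matrix(c(0.1, -0.2, 0.3,
                     1, -1, 0.5,
                     0, 0.05, -0.05),
                   ncol = 3, dimnames = list(NULL, c("a", "b", "c")))

rank_to_feature(shap_toy, 0)  # most important feature
rank_to_feature(shap_toy, 1)  # second most important feature
```

With the Boston example, rank_to_feature(data.matrix(ex), 1) would return the column name behind the rank(1) dependence plot.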

I've implemented a ggplot2-based beeswarm plot for fastshap's autoplot. You can try it out with the following code:

remotes::install_github("kapsner/fastshap", ref = "feat_beeswarm_plot")

library(doParallel)
library(fastshap)
library(ggplot2)
library(ranger)

boston <- pdp::boston
boston$chas <- as.integer(boston$chas) - 1
X <- data.matrix(subset(boston, select = -cmedv))

# Train a random forest
set.seed(944)  # for reproducibility
rfo <- ranger(cmedv ~ ., data = boston)
 
# Prediction wrapper
pfun <- function(object, newdata) {
  predict(object, data = newdata)$predictions
}

# Compute approximate Shapley values
set.seed(945)
system.time(
  shap <- fastshap::explain(rfo, X = X, nsim = 50, pred_wrapper = pfun)
)

p <- autoplot(object = shap, type = "beeswarm", X = boston)
p

(Screenshot: the resulting beeswarm plot)

This is really awesome @kapsner, I’ll take a look!

FYI, the just-released R package shapviz also provides a ggplot2-based beeswarm plot and other R-native visualizations for Shapley values.

@kapsner, does this support fastshap too? I like this idea and have considered moving all the plotting functions, which have a lot more dependencies, into a separate package, away from the core functionality.

Yes, the shapviz documentation also shows how to create plots from fastshap objects.
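A minimal sketch of that route, assuming recent versions of fastshap and shapviz are installed: shapviz() accepts the object returned by fastshap::explain() together with the feature values, and sv_importance(kind = "beeswarm") draws a summary-style plot. The linear model on mtcars and the wrapper pfun are illustrative choices only, to keep the example self-contained.

```r
# Sketch: beeswarm plot of fastshap output via shapviz
# (assumes both packages are installed).
library(fastshap)
library(shapviz)

# Small linear model on a built-in dataset, purely for illustration.
X <- mtcars[, c("wt", "hp", "disp", "qsec")]
fit <- lm(mpg ~ wt + hp + disp + qsec, data = mtcars)

# Prediction wrapper in the form fastshap expects: function(object, newdata).
pfun <- function(object, newdata) {
  predict(object, newdata = newdata)
}

# Approximate Shapley values for every row of X.
set.seed(101)
ex <- fastshap::explain(fit, X = X, nsim = 20, pred_wrapper = pfun)

# Build a shapviz object from the fastshap result plus the feature values.
sv <- shapviz(ex, X = X)

# Beeswarm (summary) plot, similar to shap's Python summary_plot.
p <- sv_importance(sv, kind = "beeswarm")
p
```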

Closing as this is supported in the shapviz project.