shap/shap

BUG: Unexpected Interaction Plot Instead of Summary Plot in Multiclass SHAP Summary with XGBoost

Closed this issue · 7 comments

Issue Description

When attempting to use SHAP with an XGBoost multiclass classification model to generate summary plots, the output unexpectedly appears as an interaction plot rather than the anticipated summary plot. This issue occurs when trying to visualize the SHAP values for all classes simultaneously.

Minimal Reproducible Example

import xgboost as xgb
import shap
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.datasets import make_classification
from sklearn.metrics import accuracy_score

# Generate synthetic data
X, y = make_classification(n_samples=500, n_features=20, n_informative=4, n_classes=6, random_state=42)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.25, random_state=42)

# Train an XGBoost model for multiclass classification
model = xgb.XGBClassifier(objective="multi:softprob", random_state=42)
model.fit(X_train, y_train)

# Create a SHAP TreeExplainer
explainer = shap.TreeExplainer(model)

# Calculate SHAP values for the test set
shap_values = explainer.shap_values(X_test)

# Attempt to plot summary for all classes
shap.summary_plot(shap_values, X_test, plot_type="bar")

Traceback

No response

Expected Behavior

The expected outcome is a summary plot that shows the feature importance for all classes in a clear and aggregated manner.

Bug report checklist

  • I have checked that this issue has not already been reported.
  • I have confirmed this bug exists on the latest release of shap.
  • I have confirmed this bug exists on the master branch of shap.
  • I'd be interested in making a PR to fix this bug

Installed Versions

SHAP version: 0.45.0
Python version: 3.10.12
XGBoost version: 2.0.3
Operating System: Google Colab Pro

It is not XGBoost-specific, as I have the same problem with SHAP values derived from CatBoost and LightGBM models. It is related to shap.summary_plot.

I have encountered the same issue - with multiclass output, the summary_plot function generates an interaction plot when a summary bar plot is expected.

I manually fixed this issue by going into the source code and changing the data type of the TreeExplainer output from a NumPy array back to a list.

Here is what I did in detail: I went to https://github.com/shap/shap/blob/master/shap/explainers/_tree.py and commented out lines 515-516. After that, I successfully generated the summary plot with multiclass output.

This error is due to a change in version 0.45.0: the output was changed from a list to a NumPy array, as can be seen in lines 410-411 of https://github.com/shap/shap/blob/master/shap/explainers/_tree.py, so I reverted this change to fix the issue.
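To make the format change concrete, here is a minimal plain-NumPy sketch (no shap installation required) of what the two return types look like; the array dimensions are arbitrary placeholders, not anything shap-specific:

```python
import numpy as np

# Illustrative dimensions only: 100 samples, 5 features, 3 classes
n_samples, n_features, n_classes = 100, 5, 3

# shap >= 0.45.0: TreeExplainer.shap_values returns one 3-D array
# of shape (n_samples, n_features, n_classes)
shap_values_new = np.random.rand(n_samples, n_features, n_classes)

# shap <= 0.44.1: the same values arrived as a list of n_classes
# 2-D arrays, each of shape (n_samples, n_features)
shap_values_old = [shap_values_new[:, :, i] for i in range(n_classes)]

print(type(shap_values_new).__name__, shap_values_new.shape)
# ndarray (100, 5, 3)
print(type(shap_values_old).__name__, len(shap_values_old), shap_values_old[0].shape)
# list 3 (100, 5)
```

The legacy summary plot dispatches on the list form, which is why the 3-D array from 0.45.0 gets routed to the interaction-plot code path.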

Well spotted! I think for the majority of cases, a shortcut with a C++ implementation of Tree SHAP is used, so these two lines need to be commented out as well (they apply the same data transformation as the lines you pointed to):

https://github.com/shap/shap/blob/86d8bc58a42e9e11901ad506f5c27f55fa4f0349/shap/explainers/_tree.py#L478C1-L479C49

Commenting these lines out most likely has some side effects, but without these lines the SHAP summary plot indeed works for multi-class classification models. Thanks!

I encountered the same problem, and switching back to version 0.44.1 resolved it for me.

Below is a short, self-contained snippet to demonstrate the issue:

import shap
import pandas as pd
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier

# Create a synthetic dataset
X, y = make_classification(n_samples=100, n_features=5, n_informative=3, n_redundant=1, n_clusters_per_class=1, n_classes=3, random_state=42)
features = [f"Feature {i}" for i in range(X.shape[1])]
X = pd.DataFrame(X, columns=features)

# Train a RandomForest model
model = RandomForestClassifier(n_estimators=50, random_state=42)
model.fit(X, y)

# Create the SHAP Explainer
explainer = shap.TreeExplainer(model)
shap_values = explainer.shap_values(X)

# Plot SHAP values for each class
shap.summary_plot(shap_values, X, plot_type="bar", class_names=['Class 0', 'Class 1', 'Class 2'])

Here are the screenshots for both versions:

(two screenshots omitted: the plot produced by each version)

@Omranic switching back to version 0.44.1 was the solution I went for myself. Thank you guys for responding to this issue!

I agree with @mengwang-mw that the issue was the change from lists to NumPy arrays introduced in this PR. The legacy summary plot breaks because it still expects a list. The real fix is to have this line check for a 3-dimensional NumPy array instead. I've got a PR to address this here.

To fix it in the interim without changing library code, you can simply convert the returned SHAP values back to a list. E.g.:

# Split the (n_samples, n_features, n_classes) array into a list of per-class 2-D arrays
shap_value_summary = explainer.shap_values(feature_train)
ensured_list_shap_values = [shap_value_summary[:, :, i] for i in range(shap_value_summary.shape[2])]
shap.summary_plot(ensured_list_shap_values)

PR #3836 addresses the stated problem for Explanation objects.

For those encountering this issue, please 1) update to the latest version (>= v0.47) once it is released, or use the latest version on the master branch, and 2) use the new Explanation API and pass that object into shap.summary_plot instead.

Demonstrated below:

# Create the SHAP Explainer
explainer = shap.TreeExplainer(model)
explanation = explainer(X)  # instead of explainer.shap_values(X) <<<<<<<<

# Plot SHAP values for each class
shap.summary_plot(explanation, plot_type="bar")