streamlit-shap is a Streamlit component that provides a wrapper to display SHAP plots in Streamlit.
The library is developed by Snehan Kekre, a member of our in-house staff who also maintains the Streamlit Documentation website.
First, install Streamlit (of course!), then install the streamlit-shap library with pip:
pip install streamlit
pip install streamlit-shap
You will also need a few other prerequisite libraries (e.g. matplotlib, pandas, scikit-learn and xgboost) if you haven't installed them already.
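If you aren't sure which of these are already installed, you can run a quick check like the one below. It is a minimal sketch that uses importlib.util.find_spec from the standard library. missing_modules is a hypothetical helper name, and the list simply mirrors the imports used in this tutorial's code. Note that two import names differ from their pip package names: scikit-learn is imported as sklearn, and streamlit-shap as streamlit_shap.

```python
import importlib.util

# Import names (not pip names) of the packages used in this tutorial.
REQUIRED_MODULES = [
    "streamlit",
    "streamlit_shap",
    "shap",
    "matplotlib",
    "numpy",
    "pandas",
    "sklearn",   # installed via `pip install scikit-learn`
    "xgboost",
]


def missing_modules(modules=REQUIRED_MODULES):
    """Hypothetical helper: return the modules that cannot be found in this environment."""
    # find_spec returns None for a top-level module that is not installed,
    # without actually importing it.
    return [name for name in modules if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    missing = missing_modules()
    if missing:
        print("Missing modules:", ", ".join(missing))
    else:
        print("All prerequisites are installed.")
```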
Here's how to use streamlit-shap:
import streamlit as st
from streamlit_shap import st_shap
import shap

from sklearn.model_selection import train_test_split
import xgboost

import numpy as np
import pandas as pd

st.set_page_config(layout="wide")

@st.experimental_memo
def load_data():
    return shap.datasets.adult()

@st.experimental_memo
def load_model(X, y):
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=7)
    d_train = xgboost.DMatrix(X_train, label=y_train)
    d_test = xgboost.DMatrix(X_test, label=y_test)
    params = {
        "eta": 0.01,
        "objective": "binary:logistic",
        "subsample": 0.5,
        "base_score": np.mean(y_train),
        "eval_metric": "logloss",
        "n_jobs": -1,
    }
    model = xgboost.train(params, d_train, 10, evals=[(d_test, "test")], verbose_eval=100, early_stopping_rounds=20)
    return model

st.title("`streamlit-shap` for displaying SHAP plots in a Streamlit app")

with st.expander('About the app'):
    st.markdown('''[`streamlit-shap`](https://github.com/snehankekre/streamlit-shap) is a Streamlit component that provides a wrapper to display [SHAP](https://github.com/slundberg/shap) plots in [Streamlit](https://streamlit.io/).
The library is developed by our in-house staff [Snehan Kekre](https://github.com/snehankekre) who also maintains the [Streamlit Documentation](https://docs.streamlit.io/) website.
''')

st.header('Input data')
X, y = load_data()
X_display, y_display = shap.datasets.adult(display=True)

with st.expander('About the data'):
    st.write('Adult census data is used as the example dataset.')
with st.expander('X'):
    st.dataframe(X)
with st.expander('y'):
    st.dataframe(y)

st.header('SHAP output')

# train XGBoost model
model = load_model(X, y)

# compute SHAP values
explainer = shap.Explainer(model, X)
shap_values = explainer(X)

with st.expander('Waterfall plot'):
    st_shap(shap.plots.waterfall(shap_values[0]), height=300)
with st.expander('Beeswarm plot'):
    st_shap(shap.plots.beeswarm(shap_values), height=300)

explainer = shap.TreeExplainer(model)
shap_values = explainer.shap_values(X)

with st.expander('Force plot'):
    st.subheader('First data instance')
    st_shap(shap.force_plot(explainer.expected_value, shap_values[0,:], X_display.iloc[0,:]), height=200, width=1000)
    st.subheader('First thousand data instances')
    st_shap(shap.force_plot(explainer.expected_value, shap_values[:1000,:], X_display.iloc[:1000,:]), height=400, width=1000)
The first step in creating a Streamlit app is to import the streamlit library as st. We'll also import the st_shap function from streamlit-shap, along with the other libraries the app needs:
import streamlit as st
from streamlit_shap import st_shap
import shap
from sklearn.model_selection import train_test_split
import xgboost
import numpy as np
import pandas as pd
Next, we'll set the page layout to wide so that the app's contents can span the full page width.
st.set_page_config(layout="wide")
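Then, we'll load the adult census dataset from the shap library. The @st.experimental_memo decorator caches the result so the data isn't reloaded every time the app reruns: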
@st.experimental_memo
def load_data():
    return shap.datasets.adult()
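Streamlit reruns the whole script from top to bottom whenever a user interacts with the app. For that reason, caching load_data's return value avoids fetching the dataset again on every run. Newer Streamlit releases replace st.experimental_memo with st.cache_data.

The underlying idea is memoization. The sketch below is a rough, Streamlit-free analogy that uses the standard library's functools.lru_cache on a hypothetical load_data_stub function. It is not how Streamlit's cache is implemented. For example, Streamlit's memo hands back a fresh copy of the cached value on each call, whereas lru_cache returns the same object.

```python
from functools import lru_cache

CALLS = {"count": 0}  # counts how many times the loader body actually executes


@lru_cache(maxsize=None)
def load_data_stub():
    """Hypothetical stand-in for an expensive loader such as shap.datasets.adult()."""
    CALLS["count"] += 1
    features = tuple(range(3))  # placeholder values, not real data
    labels = (0, 1, 0)
    return features, labels


# Simulate three script reruns: the function body runs only on the first call.
for _ in range(3):
    X, y = load_data_stub()

print(CALLS["count"])  # → 1
```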
Next, we'll define a function called load_model that takes the X, y pair as input, splits the data into training and test sets, constructs a DMatrix for each set and trains an XGBoost model.
@st.experimental_memo
def load_model(X, y):
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=7)
    d_train = xgboost.DMatrix(X_train, label=y_train)
    d_test = xgboost.DMatrix(X_test, label=y_test)
    params = {
        "eta": 0.01,
        "objective": "binary:logistic",
        "subsample": 0.5,
        "base_score": np.mean(y_train),
        "eval_metric": "logloss",
        "n_jobs": -1,
    }
    model = xgboost.train(params, d_train, 10, evals=[(d_test, "test")], verbose_eval=100, early_stopping_rounds=20)
    return model
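In the xgboost.train call above, the third positional argument (num_boost_round) caps training at 10 boosting rounds, and early_stopping_rounds=20 is also set. Early stopping halts training once the validation metric (here logloss on d_test) has gone early_stopping_rounds consecutive rounds without improving. With only 10 rounds available, it therefore cannot trigger; a larger num_boost_round would let it take effect.

The pure-Python sketch below illustrates that rule. It is a simplified model, not XGBoost's implementation: rounds_trained is a hypothetical helper, and the loss values are made up.

```python
def rounds_trained(val_losses, num_boost_round, early_stopping_rounds):
    """Simplified early-stopping rule where a lower validation loss is better.

    val_losses[i] is a hypothetical validation loss after round i + 1.
    Returns how many rounds run before training stops.
    """
    best_loss = float("inf")
    rounds_since_best = 0
    for round_idx in range(num_boost_round):
        loss = val_losses[round_idx]
        if loss < best_loss:
            best_loss = loss
            rounds_since_best = 0
        else:
            rounds_since_best += 1
        # Stop once the loss has failed to improve for `early_stopping_rounds` rounds in a row.
        if rounds_since_best >= early_stopping_rounds:
            return round_idx + 1
    return num_boost_round


# Made-up validation losses that stop improving after round 3.
losses = [0.69, 0.68, 0.67] + [0.70] * 40

print(rounds_trained(losses, num_boost_round=10, early_stopping_rounds=20))  # → 10
print(rounds_trained(losses, num_boost_round=40, early_stopping_rounds=20))  # → 23
```

With 10 rounds, the loop simply runs out before 20 non-improving rounds can accumulate. With 40 rounds, training stops 20 rounds after the best one.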
The title of the Streamlit app is then displayed:
st.title("`streamlit-shap` for displaying SHAP plots in a Streamlit app")
An "About the app" expander box provides details about the app:
with st.expander('About the app'):
    st.markdown('''[`streamlit-shap`](https://github.com/snehankekre/streamlit-shap) is a Streamlit component that provides a wrapper to display [SHAP](https://github.com/slundberg/shap) plots in [Streamlit](https://streamlit.io/).
The library is developed by our in-house staff [Snehan Kekre](https://github.com/snehankekre) who also maintains the [Streamlit Documentation](https://docs.streamlit.io/) website.
''')
Here, we'll display the "Input data" header along with expander boxes for the X and y variables of the input data:
st.header('Input data')
X, y = load_data()
X_display, y_display = shap.datasets.adult(display=True)

with st.expander('About the data'):
    st.write('Adult census data is used as the example dataset.')
with st.expander('X'):
    st.dataframe(X)
with st.expander('y'):
    st.dataframe(y)
Here, we'll display the header text for the forthcoming SHAP output:
st.header('SHAP output')
The XGBoost model is then trained using the load_model function defined above (X and y were already loaded in the previous step):
# train XGBoost model
model = load_model(X, y)
Here, we'll compute the SHAP values, which are then used to create the Waterfall and Beeswarm plots.
# compute SHAP values
explainer = shap.Explainer(model, X)
shap_values = explainer(X)

with st.expander('Waterfall plot'):
    st_shap(shap.plots.waterfall(shap_values[0]), height=300)
with st.expander('Beeswarm plot'):
    st_shap(shap.plots.beeswarm(shap_values), height=300)
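The two plots rest on two properties of SHAP values:

- **Waterfall plot.** The plot for shap_values[0] starts at the base value (the expected model output) and adds one SHAP value per feature until it reaches the model's output for that instance.
- **Beeswarm plot.** By default, the plot orders features by their mean absolute SHAP value.

For a linear model with independent features, SHAP values have a closed form, φ_i = w_i (x_i − E[x_i]), which makes both properties easy to check. The NumPy sketch below uses a made-up linear model and random data as an assumption. XGBoost trees have no such closed form, which is why the app relies on shap.Explainer.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical linear model f(x) = b + w . x on made-up data with 3 features.
X = rng.normal(size=(500, 3))
w = np.array([2.0, -0.5, 0.1])
b = 0.3


def f(X):
    return b + X @ w


# Exact SHAP values for a linear model with independent features:
# phi[j, i] = w_i * (X[j, i] - mean of feature i).
base_value = f(X).mean()
phi = w * (X - X.mean(axis=0))

# Additivity (what a waterfall plot draws): base value + sum of SHAP values = model output.
assert np.allclose(base_value + phi.sum(axis=1), f(X))

# Global importance (the beeswarm plot's default ordering): mean |SHAP value| per feature.
importance = np.abs(phi).mean(axis=0)
order = np.argsort(-importance)
print(order)  # → [0 1 2]
```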
Finally, the Tree SHAP algorithm is used to explain the output of the tree ensemble model via shap.TreeExplainer, and the resulting SHAP values are visualized with shap.force_plot:
explainer = shap.TreeExplainer(model)
shap_values = explainer.shap_values(X)

with st.expander('Force plot'):
    st.subheader('First data instance')
    st_shap(shap.force_plot(explainer.expected_value, shap_values[0,:], X_display.iloc[0,:]), height=200, width=1000)
    st.subheader('First thousand data instances')
    st_shap(shap.force_plot(explainer.expected_value, shap_values[:1000,:], X_display.iloc[:1000,:]), height=400, width=1000)
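The model was trained with "objective": "binary:logistic", and by default shap.TreeExplainer explains its raw margin output (log-odds) rather than the predicted probability. As a result, explainer.expected_value plus a row of shap_values is a log-odds value, and that is the scale the force plot shows by default. Applying the logistic (sigmoid) function converts it to a probability. shap.force_plot also accepts a link argument (e.g. link="logit") to display probabilities instead.

The sketch below assumes made-up numbers standing in for explainer.expected_value and shap_values[0, :]; margin_to_probability is a hypothetical helper.

```python
import numpy as np


def margin_to_probability(expected_value, shap_row):
    """Hypothetical helper: turn a base value plus one row of SHAP values (log-odds) into a probability."""
    log_odds = expected_value + np.sum(shap_row)
    return 1.0 / (1.0 + np.exp(-log_odds))


# Made-up values standing in for explainer.expected_value and shap_values[0, :].
expected_value = -1.2
shap_row = np.array([0.9, 0.5, -0.2])

p = margin_to_probability(expected_value, shap_row)
print(round(float(p), 3))  # → 0.5
```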