This repository contains implementations of common attribution-based interpretability methods. See the example notebook!
Currently implemented methods:
- Vanilla Gradients (paper 1 | paper 2 | implemented via Captum)
- Input x Gradient (paper | implemented via Captum)
- Integrated Gradients (paper | implemented via Captum)
- SmoothGrad (paper | adapted from Google PAIR Saliency)
- Guided Backprop (paper | implemented via Captum)
- GradCAM (paper | implemented via Captum)
- Gradient SHAP (paper | implemented via Captum)
- Kernel SHAP (paper | implemented via Captum)
- RISE (paper | implemented via the authors' GitHub)
- XRAI (paper | adapted from Google PAIR Saliency)
- LIME (paper | implemented via the authors' GitHub)
- SIS (paper 1 | paper 2 | implemented via the authors' GitHub)
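To make the gradient-based entries above more concrete, here is a minimal, dependency-free sketch of Vanilla Gradients, Input x Gradient, and Integrated Gradients. It uses a toy differentiable function with a hand-written gradient. It is not how this package or Captum computes attributions, since they rely on PyTorch autograd. The toy function `f`, its gradient `grad_f`, and the all-zeros baseline are assumptions chosen only for illustration. Integrated Gradients assigns feature `i` the value `(x_i - baseline_i)` times the average of `df/dx_i` along the straight path from the baseline to `x`.

```python
import math

# Hypothetical toy "model": f(x) = x0^2 + 3*x1 + sin(x2), with a hand-written gradient.
def f(x):
    return x[0] ** 2 + 3.0 * x[1] + math.sin(x[2])

def grad_f(x):
    return [2.0 * x[0], 3.0, math.cos(x[2])]

def vanilla_gradients(x):
    # Saliency is the gradient of the output with respect to the input.
    return grad_f(x)

def input_x_gradient(x):
    # Element-wise product of the input and its gradient.
    return [xi * gi for xi, gi in zip(x, grad_f(x))]

def integrated_gradients(x, baseline, steps=200):
    # Midpoint Riemann approximation of the path integral of gradients
    # along the straight line from `baseline` to `x`.
    total = [0.0] * len(x)
    for k in range(steps):
        alpha = (k + 0.5) / steps
        point = [b + alpha * (xi - b) for xi, b in zip(x, baseline)]
        for i, g in enumerate(grad_f(point)):
            total[i] += g
    return [(xi - b) * t / steps for xi, b, t in zip(x, baseline, total)]

x = [1.0, 2.0, 0.5]
baseline = [0.0, 0.0, 0.0]
ig = integrated_gradients(x, baseline)
# Completeness: IG attributions sum (approximately) to f(x) - f(baseline).
print(sum(ig), f(x) - f(baseline))
```

Because the integral is approximated numerically, the completeness check only holds up to a small discretization error. Captum's `IntegratedGradients.attribute` exposes a comparable knob through its `n_steps` argument.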
Each method performs batched computation and can be computed with or without SmoothGrad. The methods are implemented using Captum and public repositories (e.g., LIME), and are largely inspired by the Google PAIR Saliency implementation.
Clone this repository. Then:
```bash
# Install the requirements
pip install -r requirements.txt
# Install the package locally
pip install -e /path/to/interpretability_methods
```
See the example notebook for more examples.
Each saliency method (e.g., `VanillaGradients`) extends the base class `SaliencyMethod`. Each method is instantiated with a model and, optionally, other method-specific parameters. A `SaliencyMethod` object has two public methods: `get_saliency` and `get_saliency_smoothed`.
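The class layout described above can be sketched roughly as follows. This is a hypothetical, framework-free illustration, not the package's actual source. The smoothing parameters `stdev_spread` and `nsamples` are modeled on the names used in Google PAIR Saliency. The `seed` argument and the toy `LinearModelGradients` subclass are assumptions. Inputs are flat lists of floats rather than PyTorch tensors so the example runs without dependencies.

```python
import random
from abc import ABC, abstractmethod

class SaliencyMethod(ABC):
    """Hypothetical base class mirroring the interface described in this README."""

    def __init__(self, model):
        self.model = model

    @abstractmethod
    def get_saliency(self, x):
        """Return one attribution value per input feature."""

    def get_saliency_smoothed(self, x, stdev_spread=0.15, nsamples=25, seed=0):
        # SmoothGrad: average get_saliency over noisy copies of the input.
        rng = random.Random(seed)
        # Noise scale is a fraction of the input's value range.
        stdev = stdev_spread * (max(x) - min(x))
        total = [0.0] * len(x)
        for _ in range(nsamples):
            noisy = [xi + rng.gauss(0.0, stdev) for xi in x]
            for i, s in enumerate(self.get_saliency(noisy)):
                total[i] += s
        return [t / nsamples for t in total]

class LinearModelGradients(SaliencyMethod):
    """Hypothetical toy subclass: the 'model' is a weight vector w, with f(x) = w . x."""

    def get_saliency(self, x):
        # The gradient of a linear model is its weight vector, for any x.
        return list(self.model)

method = LinearModelGradients([0.5, -1.0, 2.0])
print(method.get_saliency([1.0, 2.0, 3.0]))
print(method.get_saliency_smoothed([1.0, 2.0, 3.0]))
```

Putting the smoothing loop in the base class means a subclass only has to implement `get_saliency` to get a smoothed variant for free. For a linear model, both calls return the weight vector, because adding noise does not change a constant gradient.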
Usage example:
```python
# Getting Vanilla Gradients with respect to the predicted class.
from interpretability_methods.vanilla_gradients import VanillaGradients
from interpretability_methods.util import visualize_saliency

model = ...  # assuming a PyTorch model
input_batch = ...  # assuming a 4D input batch (batch, channels, height, width)

vanilla_gradients_method = VanillaGradients(model)
vanilla_gradients = vanilla_gradients_method.get_saliency(input_batch)  # attributions of shape (batch, channels, height, width)
visualize_saliency(vanilla_gradients)  # will output a greyscale saliency image
```
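The exact behavior of `visualize_saliency` is defined in `interpretability_methods.util`. The sketch below only illustrates one common way to turn a single `(channels, height, width)` attribution into a greyscale map, similar in spirit to Google PAIR Saliency's grayscale visualization. It sums absolute values over channels, then rescales to `[0, 1]`, clipping at a high percentile so that a few extreme pixels do not wash out the rest. The function name `to_greyscale` and the nearest-rank percentile rule are assumptions. The sketch works on nested lists to stay dependency-free.

```python
import math

def to_greyscale(attr, percentile=99.0):
    """Hypothetical helper: collapse (channels, height, width) attributions to a (height, width) map in [0, 1]."""
    channels, height, width = len(attr), len(attr[0]), len(attr[0][0])
    # Sum absolute attributions across channels for each pixel.
    grey = [
        [sum(abs(attr[c][i][j]) for c in range(channels)) for j in range(width)]
        for i in range(height)
    ]
    values = sorted(v for row in grey for v in row)
    vmin = values[0]
    # Nearest-rank percentile used as the upper clipping value.
    rank = max(1, math.ceil(percentile / 100.0 * len(values)))
    vmax = values[rank - 1]
    scale = vmax - vmin
    if scale == 0:
        return [[0.0] * width for _ in range(height)]
    # Rescale to [0, 1] and clip values above the percentile.
    return [[min(1.0, max(0.0, (v - vmin) / scale)) for v in row] for row in grey]

example = [
    [[0.1, -0.4], [0.0, 2.0]],  # channel 0
    [[0.3, 0.0], [-0.2, 1.0]],  # channel 1
]
print(to_greyscale(example, percentile=100.0))
```

Taking absolute values discards the sign of each attribution, which is why such maps show *how much* a pixel matters rather than in which direction it pushes the prediction.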