Interpretability Methods

This repository contains implementations of common attribution-based interpretability methods. See the example notebook!

The following methods are currently implemented:

Each method performs batched computation and can be computed with or without SmoothGrad. The methods are implemented using Captum and public repositories (e.g., LIME) and are largely inspired by the Google PAIR Saliency implementation.

Set Up

Clone this repository, then run:

```shell
# Install the requirements
pip install -r requirements.txt

# Install the package locally
pip install -e /path/to/interpretability_methods
```

Usage

See the example notebook for more examples.

Each saliency method (e.g., VanillaGradients) extends the base class SaliencyMethod. Each method is instantiated with a model and, optionally, other method-specific parameters. A SaliencyMethod object has two public methods: get_saliency and get_saliency_smoothed, which applies SmoothGrad.
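To illustrate the pattern described above (a base class whose smoothed method reuses the plain one), here is a minimal, self-contained NumPy sketch. It is not the repository's implementation, which builds on PyTorch and Captum: `SaliencyMethodSketch` and the toy `QuadraticGradients` subclass are hypothetical names, and the noise scale (a fraction `stdev_spread` of each example's value range, as in the SmoothGrad paper's noise-level convention) and default arguments are assumptions.

```python
import abc

import numpy as np


class SaliencyMethodSketch(abc.ABC):
    """Hypothetical stand-in for a SaliencyMethod-style base class."""

    def __init__(self, model):
        # The model is stored once; subclasses decide how to attribute it.
        self.model = model

    @abc.abstractmethod
    def get_saliency(self, input_batch):
        """Return attributions with the same shape as input_batch."""

    def get_saliency_smoothed(self, input_batch, stdev_spread=0.15,
                              nsamples=25, seed=0):
        # SmoothGrad: average plain attributions over noisy copies of the input.
        rng = np.random.default_rng(seed)
        # Noise stdev is a fraction of each example's (max - min) value range.
        axes = tuple(range(1, input_batch.ndim))
        value_range = (input_batch.max(axis=axes, keepdims=True)
                       - input_batch.min(axis=axes, keepdims=True))
        stdev = stdev_spread * value_range  # shape (batch, 1, 1, 1)
        total = np.zeros_like(input_batch, dtype=float)
        for _ in range(nsamples):
            noise = rng.normal(size=input_batch.shape) * stdev
            total += self.get_saliency(input_batch + noise)
        return total / nsamples


class QuadraticGradients(SaliencyMethodSketch):
    """Toy subclass: the 'model' is f(x) = sum(w * x**2), so df/dx = 2 * w * x."""

    def get_saliency(self, input_batch):
        # Analytic gradient, computed for the whole batch at once.
        return 2.0 * self.model * input_batch


w = np.full((1, 3, 4, 4), 0.5)  # toy per-element weights, broadcast over the batch
x = np.random.default_rng(1).random((2, 3, 4, 4))  # batch of 2 "images"
method = QuadraticGradients(w)
plain = method.get_saliency(x)
smoothed = method.get_saliency_smoothed(x, nsamples=50)
print(plain.shape, smoothed.shape)  # (2, 3, 4, 4) (2, 3, 4, 4)
```

Because this toy gradient is linear in the input and the noise has zero mean, the smoothed result converges to the plain gradient as `nsamples` grows; for real networks, whose gradients are nonlinear and noisy, the averaging is what yields less noisy saliency maps.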

Usage example:

```python
# Compute Vanilla Gradients with respect to the predicted class.
from interpretability_methods.vanilla_gradients import VanillaGradients
from interpretability_methods.util import visualize_saliency

model = ...  # assuming a PyTorch model
input_batch = ...  # assuming a 4D input batch (batch, channels, height, width)
vanilla_gradients_method = VanillaGradients(model)
# Attributions of shape (batch, channels, height, width)
vanilla_gradients = vanilla_gradients_method.get_saliency(input_batch)
visualize_saliency(vanilla_gradients)  # outputs a greyscale saliency image
```
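The README describes `visualize_saliency` only as producing a greyscale image. A common way to get there, used for example by the PAIR Saliency library's greyscale visualization, is to sum absolute attributions over the channel axis and rescale to [0, 1], clipping at a high percentile so a few extreme pixels do not wash out the map. The sketch below assumes that approach; `to_greyscale` is a hypothetical helper, not this repository's function, and it works on NumPy arrays (a PyTorch tensor would first need `.detach().cpu().numpy()`).

```python
import numpy as np


def to_greyscale(attributions, percentile=99):
    """Map (batch, channels, height, width) attributions to (batch, height, width) in [0, 1]."""
    # Collapse the channel axis by summing absolute attribution values.
    image_2d = np.abs(attributions).sum(axis=1)
    flat = image_2d.reshape(image_2d.shape[0], -1)
    # Per-example lower bound and a percentile-based upper bound that ignores outliers.
    vmin = flat.min(axis=1)[:, None, None]
    vmax = np.percentile(flat, percentile, axis=1)[:, None, None]
    # Guard against constant maps, where vmax == vmin.
    scale = np.where(vmax > vmin, vmax - vmin, 1.0)
    return np.clip((image_2d - vmin) / scale, 0.0, 1.0)


attributions = np.random.default_rng(0).normal(size=(2, 3, 8, 8))
grey = to_greyscale(attributions)
print(grey.shape)  # (2, 8, 8)
```

Each `grey[i]` can then be displayed with matplotlib, e.g. `plt.imshow(grey[i], cmap="gray")`.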