An interpretability library for PyTorch
Explore the docs »
View Examples
·
Report Bug
·
Request Feature
This project is in active development, so the README might be outdated and may not list everything that is currently implemented. See the dev branch for more information.
toumei is a little side project of mine that tries to combine state-of-the-art interpretability and model editing methods into a pythonic library. The goal is to compile useful methods into a coherent toolchain and make complex methods accessible through an intuitive syntax.
I think interpretability methods have become quite powerful, and therefore useful, over the last couple of years, which motivated me to provide a library that makes them available for broader use.
The following methods are either already implemented or planned:
- Feature Visualization (1)
- various image parameterization methods (Pixel, FFT, CPPN, etc.)
- transformation robustness and total variation regularization (see the sketch after this list)
- custom objective building
- joint optimization
- activation difference optimization (e.g. for style transfer)
- Causal Tracing and Rank-One model editing (1)
- causal tracing for Hugging Face-style transformer objects
- rank-one model editing (WIP)
- Unified Feature Attribution (1)
- LIME (WIP)
- DeepLift (planned)
- SHAP methods (planned)
- Circuit detection using feature attribution (research idea)
- Model Modularity
- Spectral Clustering for model graphs
- Implement the network modularity metric
- Measuring modularity of MLPs
- Measuring modularity of CNNs
- Investigate (randomly) modularly varying goals in modern deep learning architectures (research project)
- Engineering a baseline model with high modularity
- Comparing the baseline model against other models using the modularity metric
- Measure the impact of modularly varying goals on model modularity
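In case the term is unfamiliar: total variation regularization (mentioned in the feature visualization item above) penalizes large differences between neighbouring pixels, so optimized images come out smooth rather than noisy. The snippet below is a minimal, library-independent sketch of the idea in plain PyTorch; the function name total_variation and the example usage are illustrative and not part of toumei's API.
import torch

def total_variation(img: torch.Tensor) -> torch.Tensor:
    # summed absolute differences between neighbouring pixels along
    # height and width (img is expected to have shape N x C x H x W)
    dh = (img[:, :, 1:, :] - img[:, :, :-1, :]).abs().sum()
    dw = (img[:, :, :, 1:] - img[:, :, :, :-1]).abs().sum()
    return dh + dw

# during feature visualization this value is typically added to the
# objective as a small weighted penalty to favour smoother images
img = torch.rand(1, 3, 224, 224)
print(total_variation(img))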
I am planning to add new things as I learn about them in the future, so this project basically mirrors my progress in the field of AI Interpretability.
toumei cannot be installed using pip. To use toumei, whether by running the experiments or by adding it to your own project, please follow the guide below.
Make sure the following libraries are installed, or install them using:
pip install torch torchvision tqdm matplotlib transformers seaborn scikit-learn networkx
- Clone the repo
git clone https://github.com/LuanAdemi/toumei.git
- Run the experiments
cd toumei/experiments
python <experiment>.py
- Move the library to your project
cd ..
cp -r toumei <path_to_your_project>
In order to perform feature visualization on a convolutional model, we are going to need two things: an image parameterization method and an objective.
These are located in the toumei.cnns package.
import torch
import torchvision.transforms as T
# import toumei
import toumei.cnns.objectives as obj
import toumei.cnns.parameterization as param
Next, we are going to import a model we can perform feature visualization on
from toumei.models import Inception5h
# the model we want to analyze
model = Inception5h(pretrained=True)
To counter noise in the optimization process, we are going to define a transform function used to perform transformation robustness regularization
# compose the image transformation for regularization through transformations robustness
transform = T.Compose([
T.Pad(12),
T.RandomRotation((-10, 11)),
T.Lambda(lambda x: x*255 - 117) # inception needs this
])
We are now able to define our feature visualization pipeline using an image parameterization method (here FFT) and our objective (visualizing unit mixed3a:74)
# define a feature visualization pipeline
fv = obj.Pipeline(
# the image generator object
param.Transform(param.FFTImage(1, 3, 224, 224), transform),
# the objective function
obj.Channel("mixed3a:74")
)
Finally, we are going to optimize our pipeline and plot the results
# attach the pipeline to the model
fv.attach(model)
# send the objective to the gpu
fv.to(torch.device("cuda" if torch.cuda.is_available() else "cpu"))
# optimize the objective
fv.optimize()
# plot the results
fv.plot()
If we want to locate factual knowledge in GPT-like models, we can use causal tracing. toumei implements this in the toumei.transformers.rome package.
from toumei.transformers.rome.tracing import CausalTracer
This imports everything we need to perform causal tracing. Using Hugging Face's transformers library, we can easily obtain a model to perform causal tracing on
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
# load gpt2 from huggingface
model = AutoModelForCausalLM.from_pretrained("gpt2-xl", torch_dtype=torch.float16)
tokenizer = AutoTokenizer.from_pretrained("gpt2-xl")
model.to(torch.device("cuda" if torch.cuda.is_available() else "cpu"))
After defining a prompt and specifying the subject, we can create a CausalTracer object and trace the model using the prompt
# specify a prompt and its subject for causal tracing
prompt = "Karlsruhe Institute of Technology is located in the country of"
subject = "Karlsruhe Institute of Technology"
# perform causal tracing
tracer = CausalTracer(model, tokenizer)
tracer.trace(prompt, subject, verbose=True)
toumei implements the modularity metric derived from graph theory for common deep learning architectures such as MLPs and CNNs.
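For reference, the metric in question is assumed here to be the standard Newman-Girvan modularity of a partition of a weighted graph:

$$Q = \frac{1}{2m} \sum_{ij} \left[ A_{ij} - \frac{k_i k_j}{2m} \right] \delta(c_i, c_j)$$

where A_ij is the (weighted) adjacency matrix, k_i is the weighted degree of node i, m is the total edge weight, and δ(c_i, c_j) is 1 if nodes i and j are assigned to the same community and 0 otherwise. A high Q means the graph decomposes into clusters that are densely connected internally but only sparsely connected to each other.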
We start by converting our model into a graph that we can run graph algorithms on. toumei provides different wrappers for the supported architectures, which can be imported from the toumei.misc package.
from toumei.misc import MLPGraph, CNNGraph
Next, we import and initialize some models whose modularity we want to measure.
from toumei.models import SimpleMLP, SimpleCNN
# create the models
mlp = SimpleMLP(4, 4)
cnn = SimpleCNN(1, 10)
Wrapping these models with the imported classes builds the corresponding weighted graphs
# create graph from the model
mlp_graph = MLPGraph(mlp)
cnn_graph = CNNGraph(cnn)
These wrappers allow us to perform all sorts of graph algorithms on the resulting graphs. We can get the modularity of a graph by performing spectral clustering on it, which partitions the graph into communities.
This is all done internally by calling
# calculate the modularity
print(mlp_graph.get_model_modularity())
print(cnn_graph.get_model_modularity())
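For intuition, here is a minimal, self-contained sketch of that recipe using networkx and scikit-learn (both listed in the requirements above): partition the nodes of a graph via spectral clustering on its adjacency matrix, then score the partition with Newman modularity. It uses a toy graph in place of a real model graph, and the actual internals of get_model_modularity may differ from this.
import networkx as nx
import numpy as np
from networkx.algorithms.community import modularity
from sklearn.cluster import SpectralClustering

# toy graph standing in for a model graph
G = nx.karate_club_graph()
adjacency = nx.to_numpy_array(G)

# partition the nodes via spectral clustering on the adjacency matrix
labels = SpectralClustering(n_clusters=2, affinity="precomputed",
                            random_state=0).fit_predict(adjacency)
communities = [set(map(int, np.flatnonzero(labels == c))) for c in np.unique(labels)]

# Newman modularity of the resulting partition
print(modularity(G, communities))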
See the experiments folder for more examples.
You are more than welcome to contribute to this project or to propose new interpretability methods I can add. Just open an issue or pull request, as you would on any other GitHub repo.
Distributed under the GPL-3.0 License. See LICENSE.txt
for more information.
Luan Ademi - luan.ademi@student.kit.edu
Project Link: https://github.com/LuanAdemi/toumei
The following section lists resources I recommend or used myself while building this project.
- Interpretability in ML: A Broad Overview
- Transparency and AGI safety
- A transparency and interpretability tech tree
- A Unified Approach to Interpreting Model Predictions
- Visualizing the Impact of Feature Attribution Baselines
- Repo: slundberg/shap