MIA

Mechanistic Interpretability of Attention

This project aims to understand the role of different attention heads in establishing relationships between texts, both as an exercise in mechanistic interpretability and because I think there are certain useful relationships between fragments of text that can be computed this way.

Each attention head in each layer of an LLM is assigned a positive or negative weight indicating how predictive it is of a specific relationship between tokens. The result is a relationship detector that runs on token pairs, in which every head either contributes to or detracts from the relationship being detected.
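As a minimal sketch of how those weights might combine (assuming a plain logistic model; the function and argument names here are illustrative, not taken from this repo):

```python
import numpy as np

def detect_relationship(pair_scores: np.ndarray, head_weights: np.ndarray, bias: float) -> float:
    """Score one token pair from its per-(layer, head) attention values.

    pair_scores:  attention scores for a single token pair, one entry per
                  (layer, head) combination.
    head_weights: learned weight per head; positive weights vote for the
                  relationship, negative weights vote against it.
    """
    logit = float(pair_scores @ head_weights) + bias
    return 1.0 / (1.0 + np.exp(-logit))  # probability the relationship holds
```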

Right now, the project detects only one relationship between tokens: verbatim reproduction of phrases. It trains a logistic regression model on attention scores to create the phrase-matching detector. This is distinct from doing a string search across both texts, because it benefits from a contextual understanding of tokens.
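A rough sketch of that training step using scikit-learn; the feature and label arrays are hypothetical stand-ins for data built as described in the next paragraph:

```python
import numpy as np
from sklearn.linear_model import LogisticRegression

# Hypothetical arrays: one row per token pair, one column per (layer, head)
# attention score; labels are True where the pair is a verbatim phrase match.
X = np.load("attention_features.npy")
y = np.load("phrase_match_labels.npy")

# class_weight="balanced" is one way to handle the fact that nearly all token
# pairs are negatives (an assumption, not necessarily what this repo does).
clf = LogisticRegression(max_iter=1000, class_weight="balanced")
clf.fit(X, y)

# The fitted coefficients are the per-head weights: a positive coefficient
# means that head's attention is evidence for a phrase match.
per_head_weight = clf.coef_.reshape(-1)
```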

Currently, the training data is limited. It's derived from a set of paragraphs by pairing each original paragraph with a reshuffling of its sentences. Token-to-token attention scores from BART (num_heads * num_layers = 192 scores per token pair) are then associated with a boolean that's true for corresponding tokens and false for all others. The model is trained on these inputs and outputs.
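A sketch of how those per-pair scores can be pulled out of BART with Hugging Face transformers. This assumes the facebook/bart-large checkpoint (whose 12 decoder layers x 16 heads give the 192 scores mentioned above) and cross-attention between the shuffled and original texts; the repo's actual extraction may differ:

```python
import torch
from transformers import BartTokenizer, BartModel

tokenizer = BartTokenizer.from_pretrained("facebook/bart-large")
model = BartModel.from_pretrained("facebook/bart-large")
model.eval()

original = "The cat sat on the mat. It purred loudly."
shuffled = "It purred loudly. The cat sat on the mat."

src = tokenizer(original, return_tensors="pt")
tgt = tokenizer(shuffled, return_tensors="pt")

with torch.no_grad():
    outputs = model(
        input_ids=src.input_ids,
        attention_mask=src.attention_mask,
        decoder_input_ids=tgt.input_ids,
        output_attentions=True,
    )

# outputs.cross_attentions is a tuple with one tensor per decoder layer,
# each shaped (batch, num_heads, target_len, source_len).
cross = torch.stack(outputs.cross_attentions)   # (12, 1, 16, tgt_len, src_len)
layers, _, heads, tgt_len, src_len = cross.shape
# Rearrange into one (layers * heads)-dim feature vector per token pair.
features = cross.squeeze(1).permute(2, 3, 0, 1).reshape(tgt_len, src_len, layers * heads)
print(features.shape)  # (tgt_len, src_len, 192)
```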

Next Steps

The major next step is to enrich the training data with more sophisticated examples. I'm interested in seeing how well this generalizes from verbatim matching to texts involving rephrasing.

This repo still needs to be cleaned up a bit, too. The training data is currently generated by uncommenting various function calls in the project, which is not easy to work with.

Setup

```
$ poetry install
$ poetry run python -m spacy download en
$ poetry run python -m mia
```