karpathy/llama2.c

Interpretability of models

jrudolph opened this issue · 4 comments

I started looking into whether the small models could be a target for interpretability. I'm just putting it out here for lack of a better space to find people wanting to help.

Here is a list of seemingly basic tasks that would be nice to get some insight into:

  • Syntactical level:
    • How are (multi-token) words represented?
      • Given a prefix how is the rest of the word generated?
    • How are different word classes represented?
      • nouns
      • verbs
      • adjectives
      • adverbs
    • How are different word forms represented?
      • singular
      • plural
      • past tense
      • present tense
  • Grammatical
    • How is it determined what kind of word should come next?
    • How are references to previous words represented, and how are they suggested for output?

I built a small GUI to start looking into the stories15M model. One of the first things I found is that layer 5 head 1 (counting from zero) attends to previous occurrences of the word that should come next.

E.g. in "Once upon a time, there was a little girl called Sarah. She was three years old and always loved to eat spaghetti. Every day, she went to the kitchen to have a bowl of" here is how the attention looks:

[image: attention map of layer 5 head 1 for the prompt above]

Another thing I found is that for structurally equivalent prompts the attention looks almost exactly the same (which is what you would expect), at least if the subject has the same species and gender (compare girl vs boy vs dog).
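
For anyone who wants to poke at this without the GUI, here is a minimal sketch (in Python/PyTorch) of plotting a single head's attention map. It assumes you have captured the post-softmax attention scores from one forward pass, e.g. by patching the attention code in the repo's model.py to stash them, into a tensor of shape (n_layers, n_heads, seq_len, seq_len); the random tensor below is only a stand-in.

```python
import torch
import matplotlib.pyplot as plt

def plot_head_attention(scores, layer, head, tokens):
    """scores: (n_layers, n_heads, seq_len, seq_len) post-softmax attention
    weights from one forward pass; rows are query positions, columns are keys."""
    att = scores[layer, head].float().numpy()
    plt.imshow(att, cmap="viridis")
    plt.xticks(range(len(tokens)), tokens, rotation=90, fontsize=6)
    plt.yticks(range(len(tokens)), tokens, fontsize=6)
    plt.title(f"layer {layer}, head {head} (0-indexed)")
    plt.colorbar()
    plt.show()

# Synthetic stand-in; stories15M has 6 layers and 6 heads.
seq = 12
fake_scores = torch.softmax(torch.randn(6, 6, seq, seq), dim=-1)
plot_head_attention(fake_scores, layer=5, head=1, tokens=[f"tok{i}" for i in range(seq)])
```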

The GUI also tries to give an indication of how output tokens are selected by showing a map of how the output token representation is matched against the classifier weights: the first line is the final representation itself, the other lines are the per-dimension logit components (representation × wcls row) before summing; green means a positive contribution to the logit, red a negative one:

[image: map of per-dimension logit contributions for the top completions]

I haven't read too much theory about this, but as I interpret the last step of the model, output tokens are matched by dot product (~ cosine similarity?) with the classification weights. In the above example, interestingly, the top-3 completions are all different classes of words, with the model suggesting "[sp]aghetti", "[y]ummy", and "her". It is pretty interesting how these different classes can be selected from a single classifier.
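
To make that concrete, here is a rough sketch of the per-dimension breakdown, assuming `h` is the final hidden state after the last RMSNorm and `wcls` is the output classifier matrix, so that `logits = wcls @ h` (as in run.c); the random tensors stand in for values captured from a real forward pass.

```python
import torch

def logit_contributions(h, wcls, top_k=3):
    """h: final hidden state, shape (dim,). wcls: classifier, shape (vocab, dim)."""
    logits = wcls @ h                              # (vocab,)
    top = torch.topk(logits, top_k)
    for tok_id, logit in zip(top.indices.tolist(), top.values.tolist()):
        contrib = wcls[tok_id] * h                 # per-dimension terms; they sum to the logit
        pos = contrib.topk(3).indices.tolist()     # dims pushing the logit up ("green")
        neg = (-contrib).topk(3).indices.tolist()  # dims pushing it down ("red")
        print(f"token {tok_id}: logit {logit:.2f}, top +dims {pos}, top -dims {neg}")

# Synthetic stand-in for a real stories15M forward pass (dim=288, vocab=32000).
logit_contributions(torch.randn(288), torch.randn(32000, 288))
```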

There are some prominent dimensions, though I'm not sure it makes sense to read too much into such naive interpretations of these high-dimensional latent spaces: only features that happen to align with one of the axes of the space will show up prominently, while features that point along a combination of axes cannot be seen, because their contribution is split into smaller components across multiple dimensions. So something like PCA would probably make sense to reveal more interesting structure.
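
A rough sketch of that PCA idea, assuming one has collected a matrix `hidden` of residual-stream activations (random numbers here as a stand-in):

```python
import torch

def pca_directions(hidden, k=8):
    """hidden: (n_samples, dim) activations. Returns the top-k principal
    directions and the fraction of variance each one explains."""
    centered = hidden - hidden.mean(dim=0, keepdim=True)
    _, s, vh = torch.linalg.svd(centered, full_matrices=False)
    explained = (s ** 2) / (s ** 2).sum()
    return vh[:k], explained[:k]

hidden = torch.randn(500, 288)   # stand-in for activations captured over many positions
directions, variance = pca_directions(hidden)
print(variance)                  # how much structure the leading directions capture
```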

One thought: assume that head 5/1 selects a reference to a previous word (or noun?). Given that we have multiple heads, the output dimension of that attention head is only 1/6th of the full dimension. How is the model able to recover the exact token representation from that partial information? (Answering my own question: the latent spaces are much bigger (288 dims * fp32 = 9216 bits, or 1536 bits for a single head's 1/6th) than what is needed to represent one of 32000 tokens (~15 bits), so there is lots of room to organize things smartly.) One follow-up reverse-engineering task could be to see how the last-layer FFN takes the information from head 5/1 and generates the classifier from it. Questions: 1. what is the flag that tells the model to select a previously seen token in that step, and 2. how is the referenced token reproduced from the limited information coming out of head 5/1?
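
Spelling out the back-of-the-envelope numbers from the previous paragraph (assuming the stories15M configuration: dim = 288, 6 heads, fp32 activations, a 32000-token vocabulary):

```python
import math

dim, n_heads, vocab = 288, 6, 32000
head_dim = dim // n_heads        # 48 dims come out of a single head
full_bits = dim * 32             # 9216 bits in one fp32 residual vector
head_bits = head_dim * 32        # 1536 bits in a single head's output
token_bits = math.log2(vocab)    # ~15 bits needed to name one of 32000 tokens
print(head_dim, full_bits, head_bits, round(token_bits, 1))
```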

A related observation is that when replacing names (proper nouns) with words unknown to the model, referencing sometimes still works, but it can derail generation in the longer term. One reason could be that the last-layer FFN has to learn a mapping from the smaller representation coming out of head 5/1 to the full representation of tokens, but can only do so reliably when it has seen something similar during training (e.g. has inferred that some word is a proper name and has to be referenced literally).

janimo commented

The TinyStories paper has a section dedicated to interpretability, as it is one of their motivations/findings:

https://arxiv.org/abs/2305.07759

> The TinyStories paper has a section dedicated to interpretability, as it is one of their motivations/findings

Thanks for sharing, interesting, indeed.

> A related observation is that when replacing names (proper nouns) with words unknown to the model, referencing sometimes still works, but it can derail generation in the longer term. One reason could be that the last-layer FFN has to learn a mapping from the smaller representation coming out of head 5/1 to the full representation of tokens, but can only do so reliably when it has seen something similar during training (e.g. has inferred that some word is a proper name and has to be referenced literally).

Some preliminary experiments on how well the tinyllama models can refer to uncommon or unknown names suggest that it indeed does not work in general. The 110M model is sometimes able to recount unknown 2-token names, but the smaller ones often are not.

E.g. compare: [images comparing the known-name and unknown-name cases]

The attention pattern looks quite similar, so it seems the model also makes an effort in the latter cases to infer the correct suffix; however, the output distribution is very flat, with the right token in the latter cases found only at ranks >1000.
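
A minimal sketch of that rank check, assuming `logits` is the next-token logit vector from a forward pass and `target_id` is the id of the token that would correctly continue the name (random logits stand in for a real run):

```python
import torch

def token_rank(logits, target_id):
    """Rank of target_id in the next-token distribution (0 = model's top choice)."""
    order = torch.argsort(logits, descending=True)
    return int((order == target_id).nonzero().item())

logits = torch.randn(32000)      # stand-in for real next-token logits
print(token_rank(logits, target_id=123))
```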

One naive hypothesis about this is that referring to previous names requires both attention to a previous instance of the name and the suffix being a likely completion more generally (e.g. in an n-gram sense).

The paper "Rethinking the Role of Scale for In-Context Learning: An Interpretability-based Case Study at 66 Billion Scale" also investigates how suffixes might be copied once a prefix of a previously occurring word has been seen.