Implementation of the paper Measuring the Mixing of Contextual Information in the Transformer
The Transformer architecture aggregates input information through the self-attention mechanism, but there is no clear understanding of how this information is mixed across the entire model. Additionally, recent works have demonstrated that attention weights alone are not enough to describe the flow of information. In this paper, we consider the whole attention block (multi-head attention, residual connection, and layer normalization) and define a metric to measure token-to-token interactions within each layer, considering the characteristics of the representation space. Then, we aggregate layer-wise interpretations to provide input attribution scores for model predictions. Experimentally, we show that our method, ALTI (Aggregation of Layer-wise Token-to-token Interactions), provides faithful explanations and outperforms similar aggregation methods.
Clone this repository to $CONTRIB_ROOT:
!git clone https://github.com/javiferran/transformer_contributions.git ${CONTRIB_ROOT}
pip install -r ${CONTRIB_ROOT}/requirements.txt
In our paper we use the BERT, DistilBERT and RoBERTa models from Hugging Face's transformers library, but the method can be easily extended to other models.
We compare our method with:
- Attention Rollout (Abnar and Zuidema, 2020)
- Attention Rollout + (Kobayashi et al., 2021)
- Gradient-based methods: Gradient Saliency, Integrated Gradients and Gradient x Input
We use the Captum implementation of the gradient-based methods.
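For readers unfamiliar with the first baseline, the following is a minimal sketch of the idea behind Attention Rollout. It is not the code used in this repository. It assumes each layer's attention has already been averaged over heads into a row-stochastic matrix, and it models the residual connection by mixing that matrix equally with the identity. The function names `matmul` and `attention_rollout` are illustrative only.

```python
def matmul(a, b):
    """Multiply two square matrices given as lists of rows."""
    n = len(a)
    return [
        [sum(a[i][k] * b[k][j] for k in range(n)) for j in range(n)]
        for i in range(n)
    ]


def attention_rollout(layer_attentions):
    """Aggregate head-averaged attention matrices, ordered first layer first.

    Row i of each input matrix holds the attention of token i over all tokens.
    Row i of the result estimates how much each input token contributes to
    token i after the last layer.
    """
    n = len(layer_attentions[0])
    # Start from the identity: before any layer, each token is only itself.
    rollout = [[1.0 if i == j else 0.0 for j in range(n)] for i in range(n)]
    for attn in layer_attentions:
        # Assumption: the residual connection is modeled as 0.5 * A + 0.5 * I.
        mixed = [
            [0.5 * attn[i][j] + (0.5 if i == j else 0.0) for j in range(n)]
            for i in range(n)
        ]
        # Compose this layer's mixing with everything accumulated so far.
        rollout = matmul(mixed, rollout)
    return rollout


# Toy example with two tokens and two layers (made-up attention values).
toy_attentions = [
    [[0.8, 0.2], [0.4, 0.6]],
    [[0.5, 0.5], [0.1, 0.9]],
]
toy_rollout = attention_rollout(toy_attentions)
```

Because each mixed matrix is row-stochastic, every row of the result still sums to 1 and can be read as a distribution over input tokens.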
To reproduce Table 2 and Figures 6 and 7, first run the following script with the different models and datasets:
# -model: bert/distilbert/roberta
# -dataset: sst2/sva
# -samples: number of samples
python ${CONTRIB_ROOT}/correlations.py \
-model bert \
-dataset sst2 \
-samples 500
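Since the script has to be run once per model and dataset, a small helper can enumerate the combinations. The sketch below only builds and prints the command lines, using the flags and values listed above. The helper name `build_commands` and the fallback to the current directory when `CONTRIB_ROOT` is unset are assumptions made for illustration.

```python
import itertools
import os
import shlex

MODELS = ["bert", "distilbert", "roberta"]
DATASETS = ["sst2", "sva"]


def build_commands(contrib_root, samples=500):
    """Return one argument list per (model, dataset) combination."""
    script = os.path.join(contrib_root, "correlations.py")
    return [
        ["python", script, "-model", model, "-dataset", dataset,
         "-samples", str(samples)]
        for model, dataset in itertools.product(MODELS, DATASETS)
    ]


if __name__ == "__main__":
    # Assumption: fall back to the current directory if CONTRIB_ROOT is unset.
    root = os.environ.get("CONTRIB_ROOT", ".")
    for command in build_commands(root):
        print(shlex.join(command))
```

To actually launch the runs, each argument list can be passed to `subprocess.run(command, check=True)`.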
To analyze model predictions with the proposed method (and other interpretability methods) on the SST2 dataset:
Text_classification.ipynb
To analyze model predictions with the proposed method (and other interpretability methods) on the Subject-Verb Agreement dataset:
SVA.ipynb