
Source code for the GAtt method in "Revisiting Attention Weights as Interpretations of Message-Passing Neural Networks".

Primary language: Jupyter Notebook. License: MIT.

Source code for GAtt

This repository is the official implementation of "Revisiting Attention Weights as Interpretations of Message-Passing Neural Networks". In the paper, we show that GAtt provides a better way to calculate edge attribution scores from the attention weights of attention-based GNNs!

We provide a Dockerfile for running the code locally.

Clone the repo, then build the Docker image with:

docker build --tag <your:tag> <directory_to_cloned_repo>

If you don't want Docker...

The code has been tested with...

  • Python 3.10.13
  • PyTorch 2.0.1+cu117
  • PyTorch Geometric 2.3.1

which should be enough to run the demos.
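To compare your own environment against the tested versions listed above, a small standard-library script like the following can help. It is a sketch and not part of this repo. The helper names installed_versions and report are hypothetical, and the script assumes the packages are installed under the pip distribution names torch and torch_geometric.

```python
import sys
from importlib.metadata import PackageNotFoundError, version

# Versions listed in this README as tested; other versions may also work.
TESTED = {"torch": "2.0.1+cu117", "torch_geometric": "2.3.1"}
TESTED_PYTHON = "3.10.13"


def installed_versions(packages):
    """Map each distribution name to its installed version, or None if it is absent."""
    found = {}
    for name in packages:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found


def report():
    """Return one human-readable line per component: installed vs. tested version."""
    lines = [f"python: {sys.version.split()[0]} (tested: {TESTED_PYTHON})"]
    for name, got in installed_versions(TESTED).items():
        if got is None:
            status = "missing"
        elif got == TESTED[name]:
            status = "ok"
        else:
            status = "differs"
        lines.append(f"{name}: {got} (tested: {TESTED[name]}) {status}")
    return lines


if __name__ == "__main__":
    print("\n".join(report()))
```

A version marked "differs" is not necessarily a problem; it only flags that your setup is outside the combination the authors report testing.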

Provided in this repo are...

  1. Source code for GAtt
  2. Demos
      • Demo on the Cora dataset showing how to use the get_gatt and get_gatt_batch functions
      • Demo on the BAShapes dataset (generated with torch_geometric.datasets.ExplainerDataset): visualizations of GAtt and comparison with AvgAtt
      • Demo on the Infection dataset (generated with the code from the original authors' repo): visualizations of GAtt and comparison with AvgAtt
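The BAShapes and Infection demos explain predictions for individual target nodes, so only part of the graph can matter for a given target. In an L-layer message-passing network, a message sent along edge src -> dst at layer l can only affect the target if dst is within L - l hops of it, assuming each layer also carries a node's own state forward (GAT layers in PyTorch Geometric add self-loops by default). The sketch below lists those edges per layer. It is a hypothetical helper for inspecting a target node's computation graph, not part of this repo; the names hops_to_target and relevant_edges_per_layer are illustrative, and edges are (source, target) pairs as in PyG's edge_index.

```python
from collections import deque


def hops_to_target(edges, target):
    """Directed hop distance from every node to `target`, following messages src -> dst."""
    incoming = {}
    for src, dst in edges:
        incoming.setdefault(dst, []).append(src)
    dist = {target: 0}
    queue = deque([target])
    # Breadth-first search backwards along message flow, starting at the target.
    while queue:
        node = queue.popleft()
        for src in incoming.get(node, []):
            if src not in dist:
                dist[src] = dist[node] + 1
                queue.append(src)
    return dist


def relevant_edges_per_layer(edges, target, num_layers):
    """For each layer l (1-indexed), the edges whose layer-l message can reach `target`.

    Assumption: every layer also passes a node's own state forward (self-loops),
    so a message into dst at layer l reaches the target iff dst is within
    num_layers - l hops of it.
    """
    dist = hops_to_target(edges, target)
    return {
        layer: {(s, d) for s, d in edges if d in dist and dist[d] <= num_layers - layer}
        for layer in range(1, num_layers + 1)
    }


if __name__ == "__main__":
    # Undirected path graph 0 - 1 - 2 - 3, stored as directed edges in both directions.
    path = [(0, 1), (1, 0), (1, 2), (2, 1), (2, 3), (3, 2)]
    print(relevant_edges_per_layer(path, target=3, num_layers=2))
```

On this path graph with a 2-layer model and target node 3, only edge 2 -> 3 matters in the last layer, while edges 1 -> 2, 3 -> 2 and 2 -> 3 matter in the first layer.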

Results for the Infection dataset

This is one of the results in the demo notebooks:

Figure (left to right)

  • Ground truth explanation (blue edges) for the target node (orange node)
  • Edge attribution from GAtt
  • Edge attribution from AvgAtt (averaging over the layers)
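AvgAtt, the baseline in the rightmost panel, gives each edge the mean of its attention weights across the GAT layers. A minimal pure-Python sketch follows, assuming the per-layer attention weights have already been extracted (for example, with attention heads averaged) into one dict per layer keyed by (source, target) edge. The function name avg_att is hypothetical and not this repo's API; GAtt itself is computed differently, as described in the paper.

```python
def avg_att(layer_attention):
    """AvgAtt baseline: average each edge's attention weight over all layers.

    `layer_attention` is a list with one dict per layer, mapping (src, dst) -> weight.
    Assumption: an edge missing from a layer contributes 0 for that layer
    (in a GAT, every edge normally receives a weight in every layer).
    """
    num_layers = len(layer_attention)
    edges = set().union(*layer_attention)  # all edges seen in any layer
    return {
        edge: sum(layer.get(edge, 0.0) for layer in layer_attention) / num_layers
        for edge in edges
    }


if __name__ == "__main__":
    layers = [
        {(0, 1): 0.25, (1, 1): 0.75},  # layer 1
        {(0, 1): 0.5, (1, 1): 0.5},    # layer 2
    ]
    print(avg_att(layers))  # (0, 1) -> 0.375, (1, 1) -> 0.625
```

Because a plain average weights every layer equally, it ignores how strongly each layer's attention actually feeds into the target node's output, which is the gap GAtt is meant to address.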

The figures show that the edge attribution scores from GAtt are more closely aligned with the ground-truth explanation edges than those obtained by simply averaging the attention weights over the GAT layers.
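The README does not state which metric the notebooks use to quantify this alignment. One common choice is ROC-AUC between the edge attribution scores and a binary ground-truth mask (1 for explanation edges). The sketch below computes it in pure Python via the pairwise (Mann-Whitney) formulation. It is illustrative only; the name roc_auc is hypothetical and the quadratic loop is meant for small per-node edge sets.

```python
def roc_auc(scores, labels):
    """ROC-AUC of edge scores against binary ground-truth labels (1 = explanation edge).

    Equals the probability that a randomly chosen ground-truth edge scores higher
    than a randomly chosen non-ground-truth edge, with ties counted as 1/2.
    """
    pos = [s for s, y in zip(scores, labels) if y]
    neg = [s for s, y in zip(scores, labels) if not y]
    if not pos or not neg:
        raise ValueError("need at least one ground-truth and one non-ground-truth edge")
    wins = 0.0
    for p in pos:
        for n in neg:
            if p > n:
                wins += 1.0
            elif p == n:
                wins += 0.5
    return wins / (len(pos) * len(neg))


if __name__ == "__main__":
    # Four edges: the first and third are ground-truth explanation edges.
    print(roc_auc([0.9, 0.8, 0.3, 0.1], [1, 0, 1, 0]))  # 0.75
```

An AUC of 1.0 means every ground-truth edge outranks every other edge, and 0.5 is what random scores would achieve on average.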