
Code for IJCAI'24 paper: Gradformer: Graph Transformer with Exponential Decay


Gradformer: Graph Transformer with Exponential Decay

Overview

This paper presents Gradformer, a method that integrates the Graph Transformer (GT) with the intrinsic inductive bias of graph structure by applying an exponential decay mask to the attention matrix. The values in the decay mask decrease exponentially as the proximity between nodes in the graph decreases. This design lets Gradformer focus on the graph's local details while retaining its ability to capture information from distant nodes. Furthermore, Gradformer introduces a learnable constraint into the decay mask, allowing different attention heads to learn distinct decay masks. Such a design diversifies the attention heads, enabling more effective assimilation of the diverse structural information within the graph.
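The decay mask described above can be sketched in plain Python. This sketch assumes node proximity is measured by shortest-path (hop) distance d(i, j), so each mask entry is gamma ** d(i, j), and that the mask scales the softmax-normalized attention element-wise. The per-head learnable constraint is modeled by a hypothetical `head_gamma` that squashes a learnable scalar into a fixed interval; the bounds (0.5, 0.95), the handling of unreachable pairs, and all function names are illustrative assumptions, not the repository's code.

```python
import math
from collections import deque


def shortest_path_lengths(adj, source):
    """BFS hop distances from `source` in an unweighted graph (adjacency list)."""
    dist = {source: 0}
    queue = deque([source])
    while queue:
        u = queue.popleft()
        for v in adj[u]:
            if v not in dist:
                dist[v] = dist[u] + 1
                queue.append(v)
    return dist


def head_gamma(raw, gamma_min=0.5, gamma_max=0.95):
    """Hypothetical constraint: map a learnable scalar into (gamma_min, gamma_max) via a sigmoid."""
    return gamma_min + (gamma_max - gamma_min) / (1.0 + math.exp(-raw))


def decay_mask(adj, gamma):
    """mask[i][j] = gamma ** d(i, j); pairs with no path keep 0.0 (assumption of this sketch)."""
    n = len(adj)
    mask = [[0.0] * n for _ in range(n)]
    for i in range(n):
        for j, d in shortest_path_lengths(adj, i).items():
            mask[i][j] = gamma ** d
    return mask


def masked_attention_row(scores, mask_row):
    """Softmax over raw scores, then scale each weight by the decay mask.

    The row is not renormalized afterwards; that choice is an assumption here.
    """
    m = max(scores)
    exp = [math.exp(s - m) for s in scores]
    total = sum(exp)
    return [e / total * w for e, w in zip(exp, mask_row)]


if __name__ == "__main__":
    path = [[1], [0, 2], [1, 3], [2]]  # path graph 0-1-2-3
    # Different raw parameters give each head its own decay rate.
    for raw in (-1.0, 0.0, 1.0):
        g = head_gamma(raw)
        print(f"gamma={g:.3f}", [round(x, 3) for x in decay_mask(path, g)[0]])
```

With uniform raw scores, each attention weight is simply scaled by how far the target node is from the query node, so distant nodes still contribute but with exponentially smaller weight.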

Python environment setup with Conda

conda create -n gradformer python=3.11
conda activate gradformer
conda install pytorch==2.1.0 torchvision==0.16.0 torchaudio==2.1.0 pytorch-cuda=11.8 -c pytorch -c nvidia
conda install pyg -c pyg
pip install rdkit-pypi cython
pip install ogb
pip install configargparse
pip install pyg_lib torch_scatter torch_sparse torch_cluster torch_spline_conv -f https://data.pyg.org/whl/torch-2.1.0+cu118.html
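After running the commands above, a quick standard-library check can confirm that each package is importable before launching training. The module names are the import names of the packages installed above (for example, `rdkit-pypi` provides the `rdkit` module); the helper names are illustrative and not part of the repository.

```python
import importlib.metadata
import importlib.util

# Import names of the packages installed by the setup commands.
REQUIRED_MODULES = [
    "torch", "torch_geometric", "ogb", "rdkit", "configargparse",
    "torch_scatter", "torch_sparse", "torch_cluster", "torch_spline_conv", "pyg_lib",
]


def check_environment(modules=REQUIRED_MODULES):
    """Return {module: found?} using find_spec, without importing the modules."""
    return {name: importlib.util.find_spec(name) is not None for name in modules}


def torch_version():
    """Installed torch distribution version, or None if it is not installed."""
    try:
        return importlib.metadata.version("torch")
    except importlib.metadata.PackageNotFoundError:
        return None


if __name__ == "__main__":
    for name, ok in check_environment().items():
        print(f"{name:18s} {'ok' if ok else 'MISSING'}")
    # The setup commands pin pytorch==2.1.0.
    print("torch version:", torch_version())
```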

Running Gradformer

conda activate gradformer
# Run Gradformer with its tuned hyperparameters on TUDataset (NCI1).
sh ./scripts/run_nci1.sh
# Run Gradformer with its tuned hyperparameters on ogbg-molhiv.
sh ./scripts/run_hiv.sh
# Run Gradformer with its tuned hyperparameters on CLUSTER.
sh ./scripts/run_cluster.sh
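To run several of these configurations in sequence and stop at the first failure, a small Python wrapper can call the same scripts. This is a hypothetical helper, not part of the repository; it assumes it is executed from the repository root with the `gradformer` environment active, and the dataset labels used as dictionary keys are chosen for this sketch.

```python
import subprocess
import sys

# Script paths as listed in this README; the keys are labels chosen for this sketch.
SCRIPTS = {
    "NCI1": "./scripts/run_nci1.sh",
    "ogbg-molhiv": "./scripts/run_hiv.sh",
    "CLUSTER": "./scripts/run_cluster.sh",
}


def build_command(dataset):
    """Return the argv list that runs the tuned script for `dataset`."""
    try:
        return ["sh", SCRIPTS[dataset]]
    except KeyError:
        raise ValueError(
            f"no tuned script listed for {dataset!r}; choose from {sorted(SCRIPTS)}"
        ) from None


def run(datasets):
    """Run each script in order; check=True raises CalledProcessError on the first failure."""
    for name in datasets:
        print("running", name, file=sys.stderr)
        subprocess.run(build_command(name), check=True)


if __name__ == "__main__":
    # With no arguments, run all three listed configurations.
    run(sys.argv[1:] or list(SCRIPTS))
```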

Supported datasets:

  • TUDataset: NCI1, PROTEINS, MUTAG, IMDB-BINARY, COLLAB
  • GNN Benchmarking: ZINC, CLUSTER, PATTERN
  • OGB: ogbg-molhiv

Baselines

Datasets

The datasets listed above are downloaded automatically through PyG's dataset API the first time the code runs.
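The automatic download goes through the standard PyG dataset classes (and OGB's PyG loader for ogbg-molhiv). The sketch below maps each supported name to a loader; the mapping, the `root` directory layout, and constructor arguments such as `subset=True` for ZINC and `split="train"` are assumptions for illustration and may differ from what the Gradformer scripts pass. Imports are deferred so the mapping can be inspected without PyG installed.

```python
import os

TU_NAMES = {"NCI1", "PROTEINS", "MUTAG", "IMDB-BINARY", "COLLAB"}


def dataset_spec(name):
    """Return (loader_name, kwargs) for a supported dataset name (hypothetical mapping)."""
    if name in TU_NAMES:
        return "TUDataset", {"name": name}
    if name == "ZINC":
        return "ZINC", {"subset": True, "split": "train"}  # subset/split are assumptions
    if name in {"CLUSTER", "PATTERN"}:
        return "GNNBenchmarkDataset", {"name": name, "split": "train"}
    if name == "ogbg-molhiv":
        return "PygGraphPropPredDataset", {"name": name}
    raise ValueError(f"unsupported dataset: {name!r}")


def load(name, root="data"):
    """Instantiate the loader; raw files are downloaded into `root` on first use."""
    loader_name, kwargs = dataset_spec(name)
    if loader_name == "PygGraphPropPredDataset":
        from ogb.graphproppred import PygGraphPropPredDataset
        return PygGraphPropPredDataset(root=root, **kwargs)
    import torch_geometric.datasets as pyg_datasets
    loader = getattr(pyg_datasets, loader_name)
    return loader(root=os.path.join(root, loader_name), **kwargs)
```

For example, `load("MUTAG")` would place MUTAG under `data/TUDataset/MUTAG` on its first call and reuse the processed files afterwards.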

Gradformer is built using PyG and GraphGPS.