PyTorch implementation of "Causal Attention for Interpretable and Generalizable Graph Classification"
Yongduo Sui, Xiang Wang, Jiancan Wu, Min Lin, Xiangnan He, Tat-Seng Chua
In KDD 2022.
In this work, we take a causal look at GNN modeling for graph classification. Under our causal assumption, the shortcut feature serves as a confounder between the causal feature and the prediction. It tricks the classifier into learning spurious correlations that help prediction on in-distribution (ID) test data but cause performance drops on out-of-distribution (OOD) test data. To endow the classifier with better generalization, we propose the Causal Attention Learning (CAL) strategy, which discovers the causal patterns and mitigates the confounding effect of shortcuts. Specifically, we employ attention modules to estimate the causal and shortcut features of the input graph. We then parameterize the backdoor adjustment of causal theory by combining each causal feature with various shortcut features. This encourages a stable relationship between the causal estimation and the prediction, regardless of changes in the shortcut parts and distributions.
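The two steps described above (attention-based estimation of causal and shortcut features, then pairing each causal feature with other shortcut features) can be illustrated with a minimal NumPy sketch. It is not the code in this repository. It assumes node embeddings have already been produced by some GNN encoder, and it uses a single hypothetical linear scorer `w_att` with a per-node two-way softmax as the attention module. It pools with an attention-weighted sum and implements the intervention by adding the shortcut feature of a randomly permuted graph from the same batch. The paper's model also attends over edges and runs GNN layers on the masked graphs, and this sketch omits both.

```python
import numpy as np


def softmax(x, axis=-1):
    # Numerically stable softmax along `axis`.
    z = x - x.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)


def split_causal_shortcut(node_h, w_att):
    """Soft-split one graph's node embeddings into causal and shortcut parts.

    node_h: (n_nodes, d) node embeddings, assumed to come from a GNN encoder.
    w_att:  (d, 2) hypothetical attention weights; column 0 scores "causal",
            column 1 scores "shortcut".
    Returns graph-level features h_c, h_t and the per-node causal attention.
    """
    att = softmax(node_h @ w_att, axis=1)    # (n_nodes, 2); each row sums to 1
    alpha_c, alpha_t = att[:, 0], att[:, 1]  # complementary: alpha_t = 1 - alpha_c
    h_c = (alpha_c[:, None] * node_h).sum(axis=0)  # attention-weighted readout
    h_t = (alpha_t[:, None] * node_h).sum(axis=0)
    return h_c, h_t, alpha_c


def backdoor_mix(h_c_batch, h_t_batch, rng):
    """Pair each graph's causal feature with the shortcut feature of a
    randomly permuted graph in the batch (possibly itself) and add them."""
    perm = rng.permutation(len(h_t_batch))
    return h_c_batch + h_t_batch[perm], perm


# Toy usage: three random "graphs" with 5, 7 and 4 nodes, 8-dim embeddings.
rng = np.random.default_rng(0)
d = 8
w_att = rng.normal(size=(d, 2))
graphs = [rng.normal(size=(n, d)) for n in (5, 7, 4)]
parts = [split_causal_shortcut(g, w_att) for g in graphs]
h_c = np.stack([p[0] for p in parts])  # (3, d) causal features
h_t = np.stack([p[1] for p in parts])  # (3, d) shortcut features
mixed, perm = backdoor_mix(h_c, h_t, rng)
# A classifier would then be trained so that mixed[i] still predicts graph i's label.
```

Because each node's two attention weights sum to one, the causal and shortcut readouts split the graph's total embedding between them. A classifier trained on `mixed` to predict the label of the graph that supplied `h_c` is pushed to ignore whichever shortcut it was paired with.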
Please set up the environment following the Requirements in this repository. Typically, you will need to run the following commands:
pip install torch==1.4.0
pip install torch-scatter==1.1.0 -f https://pytorch-geometric.com/whl/torch-1.4.0.html
pip install torch-sparse==0.4.4 -f https://pytorch-geometric.com/whl/torch-1.4.0.html
pip install torch-cluster==1.4.5 -f https://pytorch-geometric.com/whl/torch-1.4.0.html
pip install torch-spline-conv==1.1.0 -f https://pytorch-geometric.com/whl/torch-1.4.0.html
pip install torch-geometric==1.1.0
pip install torchvision==0.5.0
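To confirm that the environment matches the versions pinned above, a small stdlib-only check can report any mismatches. This is a hypothetical helper, not part of the repository. It assumes Python 3.8+ (for `importlib.metadata`), compares only the public part of each version by dropping local labels such as `+cu101`, and looks up TorchVision under its PyPI distribution name `torchvision`.

```python
from importlib.metadata import PackageNotFoundError, version  # Python 3.8+

# Versions pinned by the pip commands above (hypothetical helper, not in the repo).
PINNED = {
    "torch": "1.4.0",
    "torch-scatter": "1.1.0",
    "torch-sparse": "0.4.4",
    "torch-cluster": "1.4.5",
    "torch-spline-conv": "1.1.0",
    "torch-geometric": "1.1.0",
    "torchvision": "0.5.0",
}


def installed_version(dist_name):
    """Return the installed version of a distribution, or None if missing.

    Tries both the hyphenated and the underscored spelling, in case the
    metadata lookup does not normalize distribution names.
    """
    for candidate in (dist_name, dist_name.replace("-", "_")):
        try:
            return version(candidate).split("+")[0]  # drop local label, e.g. +cu101
        except PackageNotFoundError:
            continue
    return None


def check_pins(pins):
    """Map each mismatching package to (expected, installed_or_None)."""
    mismatches = {}
    for name, expected in pins.items():
        installed = installed_version(name)
        if installed != expected:
            mismatches[name] = (expected, installed)
    return mismatches


if __name__ == "__main__":
    problems = check_pins(PINNED)
    for name, (expected, installed) in sorted(problems.items()):
        print(f"{name}: expected {expected}, found {installed or 'not installed'}")
    if not problems:
        print("All pinned packages match.")
```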
Example run on the synthetic dataset:
lr=0.002
min=5e-6
b=0.9
model=CausalGCN
python main_syn.py --bias $b --lr $lr --min_lr $min --model $model
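To repeat the synthetic experiment at several bias levels, the command above can be generated in a loop. The sketch below is hypothetical. It reuses only the flags shown (`--bias`, `--lr`, `--min_lr`, `--model`), and its bias grid is illustrative rather than taken from the paper. By default it runs in dry-run mode and prints the commands instead of launching them.

```python
import shlex
import subprocess

# Hyperparameters from the shell example above.
LR, MIN_LR, MODEL = 0.002, 5e-6, "CausalGCN"


def syn_command(bias, lr=LR, min_lr=MIN_LR, model=MODEL):
    """Build the argv for one main_syn.py run, using the flags shown above."""
    return ["python", "main_syn.py", "--bias", str(bias), "--lr", str(lr),
            "--min_lr", str(min_lr), "--model", model]


def sweep(biases, dry_run=True):
    """Print (and, unless dry_run, execute) one run per bias value."""
    cmds = [syn_command(b) for b in biases]
    for cmd in cmds:
        print(" ".join(shlex.quote(arg) for arg in cmd))
        if not dry_run:
            subprocess.run(cmd, check=True)  # stop at the first failing run
    return cmds


cmds = sweep([0.1, 0.5, 0.9])  # hypothetical bias grid; pass dry_run=False to run
```

Note that `str(5e-6)` renders as `5e-06`, which Python's `float()` parses back to the same value.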
Example run on a real-world dataset (MUTAG):
python main_real.py --model CausalGAT --dataset MUTAG
The backbone implementation is based on https://github.com/chentingpc/gfn.