This repository is the official implementation of the AISTATS 2022 paper CF-GNNExplainer: Counterfactual Explanations for Graph Neural Networks.
To install requirements:
conda env create --file environment.yml
📋 This will create a conda environment called pytorch-geo. Activate it with conda activate pytorch-geo before running the commands below.
To train the original GNN models for the BA-shapes dataset in the paper, cd into src and run this command:
python train.py --dataset=syn1
📋 For the Tree-Cycles dataset, the dataset argument should be "syn4". For the Tree-Grid dataset, it should be "syn5". All hyperparameters are set through the script's default argument values, and all models use the same hyperparameters.
To train CF-GNNExplainer for each dataset, run the following commands:
python main_explain.py --dataset=syn1 --lr=0.01 --beta=0.5 --n_momentum=0.9 --optimizer=SGD
python main_explain.py --dataset=syn4 --lr=0.1 --beta=0.5 --optimizer=SGD
python main_explain.py --dataset=syn5 --lr=0.1 --beta=0.5 --optimizer=SGD
📋 This will create another folder in the main directory called 'results', where the results files will be stored.
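The per-dataset commands above can be collected into one small driver. This is only a convenience sketch, not a script shipped with the repository: the names `EXPLAIN_ARGS`, `build_commands`, and `run_all` are hypothetical, the flag values are copied from the commands listed above, and it assumes it is run from inside src with the conda environment active.

```python
import subprocess

# Explainer flags copied from the commands listed above, keyed by dataset code.
EXPLAIN_ARGS = {
    "syn1": ["--lr=0.01", "--beta=0.5", "--n_momentum=0.9", "--optimizer=SGD"],  # BA-Shapes
    "syn4": ["--lr=0.1", "--beta=0.5", "--optimizer=SGD"],  # Tree-Cycles
    "syn5": ["--lr=0.1", "--beta=0.5", "--optimizer=SGD"],  # Tree-Grid
}


def build_commands(dataset):
    """Return the train and explain command lines (as argv lists) for one dataset code."""
    train = ["python", "train.py", f"--dataset={dataset}"]
    explain = ["python", "main_explain.py", f"--dataset={dataset}", *EXPLAIN_ARGS[dataset]]
    return [train, explain]


def run_all(dry_run=True):
    """Print every command; execute them only when dry_run is False."""
    for dataset in EXPLAIN_ARGS:
        for cmd in build_commands(dataset):
            print(" ".join(cmd))
            if not dry_run:
                # check=True stops the loop at the first failing command.
                subprocess.run(cmd, check=True)


if __name__ == "__main__":
    run_all(dry_run=True)
```

With `dry_run=True` the script only prints the commands, which makes it easy to check them against the list above before launching any training. Since pretrained models are provided, the `train.py` entries can be dropped from `build_commands` if retraining is not needed.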
To evaluate the CF examples, run the following command:
python evaluate.py --path=../results/<NAME OF RESULTS FILE>
📋 This will print out the values for each metric.
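To give a feel for what the printed metrics mean, here is a rough sketch of how three of them could be computed from a list of counterfactual examples. It is not the code in evaluate.py, and the definitions are assumptions based on our reading of the paper: fidelity as the fraction of nodes whose prediction did not change (lower is better for counterfactuals), size as the number of removed edges, and sparsity as one minus the fraction of subgraph edges removed. The function name `summarize` and the input format are hypothetical, and the numbers in the usage example are toy values, not results. Accuracy is left out because it requires the ground-truth motif edges.

```python
def summarize(examples, num_nodes):
    """Summarize counterfactual (CF) examples.

    examples:  list of (edges_removed, edges_in_subgraph) pairs, one per node
               for which a CF example was found (assumed input format).
    num_nodes: total number of nodes that were explained.
    """
    if num_nodes <= 0:
        raise ValueError("num_nodes must be positive")
    if not examples:
        # No CF example found: every prediction is unchanged.
        return {"fidelity": 1.0, "size": None, "sparsity": None}

    # Assumed definition: share of nodes whose prediction did NOT flip.
    fidelity = 1 - len(examples) / num_nodes
    # Assumed definition: mean number of removed edges per CF example.
    size = sum(removed for removed, _ in examples) / len(examples)
    # Assumed definition: mean of 1 - (removed edges / edges in the subgraph).
    sparsity = sum(1 - removed / total for removed, total in examples) / len(examples)
    return {"fidelity": fidelity, "size": size, "sparsity": sparsity}


# Toy input: CF examples were found for 3 of 4 nodes.
toy = [(2, 20), (1, 10), (3, 30)]
print(summarize(toy, num_nodes=4))
```

Under these assumed definitions, a good counterfactual explainer has low fidelity, small size, and sparsity close to 1, which matches the direction of the values reported in the results table.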
The pretrained models are available in the models folder.
Our model achieves the following performance:
| Model name | Dataset | Fidelity | Size | Sparsity | Accuracy |
|---|---|---|---|---|---|
| CF-GNNExplainer | Tree-Cycles | 0.21 | 2.09 | 0.90 | 0.94 |
| CF-GNNExplainer | Tree-Grid | 0.07 | 1.47 | 0.94 | 0.96 |
| CF-GNNExplainer | BA-Shapes | 0.39 | 2.39 | 0.99 | 0.96 |