Source code of NeurIPS 2023 "Provable Training for Graph Contrastive Learning"


Source code of our NeurIPS 2023 Spotlight paper "Provable Training for Graph Contrastive Learning"

Environment Settings

Here we list the most important Python packages we used:

torch == 1.12.1
# PyG
torch-geometric == 2.2.0
torch-cluster == 1.6.0
torch-scatter == 2.1.0
torch-sparse == 0.6.16
torch-spline-conv == 1.2.1
dgl == 1.0.2
PyGCL == 0.1.2

To install PyG, we suggest following the official guide: https://pytorch-geometric.readthedocs.io/en/latest/install/installation.html
To install DGL, please follow https://docs.dgl.ai/install/index.html
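After installing, you may want to check that your environment matches the versions listed above. Below is a minimal sketch, not part of this repository; the helper names `installed_version`, `status` and `check_versions` are hypothetical.

- It reads installed package metadata with the standard library's `importlib.metadata` (Python 3.8+) and does not import torch, PyG or DGL.
- CUDA wheels often carry a local suffix such as "+cu113", which the script strips before comparing.
- It also tries the underscore spelling of hyphenated names, since older Python versions do not normalize them.
- Some CUDA builds may be published under a different distribution name. Those will be reported as missing.

```python
"""Compare installed package versions with the ones listed in this README (sketch)."""
from importlib.metadata import PackageNotFoundError, version

# Versions copied from the package list above.
EXPECTED = {
    "torch": "1.12.1",
    "torch-geometric": "2.2.0",
    "torch-cluster": "1.6.0",
    "torch-scatter": "2.1.0",
    "torch-sparse": "0.6.16",
    "torch-spline-conv": "1.2.1",
    "dgl": "1.0.2",
    "PyGCL": "0.1.2",
}


def installed_version(name):
    """Return the installed version string of a distribution, or None if absent."""
    # Older Python versions do not treat "-" and "_" as equivalent, so try both.
    for candidate in (name, name.replace("-", "_")):
        try:
            return version(candidate)
        except PackageNotFoundError:
            pass
    return None


def status(wanted, installed):
    """Classify one package as "ok", "mismatch" or "missing"."""
    if installed is None:
        return "missing"
    # Drop a local version suffix such as "+cu113" before comparing.
    return "ok" if installed.split("+")[0] == wanted else "mismatch"


def check_versions(expected):
    """Return {name: (expected_version, installed_version_or_None)}."""
    return {name: (wanted, installed_version(name)) for name, wanted in expected.items()}


if __name__ == "__main__":
    for name, (wanted, installed) in check_versions(EXPECTED).items():
        print(f"{name:18s} expected {wanted:8s} installed {installed}  [{status(wanted, installed)}]")
```

Running it only prints a table, so it is safe to use before the first training run.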


First, create the directories for the datasets and for the saved bounds:

mkdir ~/datasets
mkdir ~/datasets/bounds
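If you prefer to create these directories from a Python setup script, a minimal equivalent is sketched below. `make_dirs` is a hypothetical helper, not part of the repository, and it assumes the training scripts expect `~/datasets` as the commands above suggest. Unlike the two plain `mkdir` calls, it does not fail when the directories already exist, which is the same behavior as `mkdir -p ~/datasets/bounds`.

```python
from pathlib import Path


def make_dirs(root=None):
    """Create <root>/datasets and <root>/datasets/bounds (root defaults to $HOME)."""
    root = Path.home() if root is None else Path(root)
    bounds = root / "datasets" / "bounds"
    # parents=True also creates "datasets"; exist_ok=True makes re-runs harmless.
    bounds.mkdir(parents=True, exist_ok=True)
    return bounds.parent, bounds
```

Calling `make_dirs()` with no argument creates `~/datasets` and `~/datasets/bounds` and returns both paths.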

Then, go into the directory of the model you want to run. To change hyperparameters, edit that model's configuration file: "config.yaml" for GRACE, "config/default.yaml" for COSTA, and "param/{dataset_name}.json" for GCA and ProGCL. The commands below run each model in its original form and with POT, using Cora as the example dataset:

# original GRACE
python train.py --dataset Cora --gpu_id 0
# GRACE with POT
python train.py --dataset Cora --gpu_id 0 --use_pot --kappa 0.4

cd GCA  # or: cd ProGCL
# original GCA
python train.py --dataset Cora --param local:cora.json --device cuda:0
# GCA with POT
python train.py --dataset Cora --param local:cora.json --device cuda:0 --use_pot --kappa 0.3

cd COSTA/src
# original COSTA
python main.py --dataset Cora --gpu_id 0
# COSTA with POT
python main.py --dataset Cora --gpu_id 0 --use_pot --kappa 0.2
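To run a baseline and several POT settings in one go, for example to try different values of kappa, the commands above can be generated programmatically. This is a hypothetical helper, not part of the repository:

- `build_command` only reproduces the flags shown in the commands above.
- For GCA and ProGCL it assumes the parameter file is named after the lowercased dataset, as in `local:cora.json`.
- `pot_batch` is appended only when POT is enabled.

```python
"""Generate the command lines shown above (sketch; flags copied from this README)."""
import shlex


def build_command(model, dataset, gpu=0, kappa=None, pot_batch=None):
    """Return an argv list: kappa=None gives the original model, otherwise POT."""
    if model == "GRACE":
        argv = ["python", "train.py", "--dataset", dataset, "--gpu_id", str(gpu)]
    elif model in ("GCA", "ProGCL"):
        # Assumption: the param file is the lowercased dataset name, as in local:cora.json.
        argv = ["python", "train.py", "--dataset", dataset,
                "--param", f"local:{dataset.lower()}.json", "--device", f"cuda:{gpu}"]
    elif model == "COSTA":
        argv = ["python", "main.py", "--dataset", dataset, "--gpu_id", str(gpu)]
    else:
        raise ValueError(f"unknown model: {model}")
    if kappa is not None:
        argv += ["--use_pot", "--kappa", str(kappa)]
        if pot_batch is not None:  # minibatching only applies to the POT variant
            argv += ["--pot_batch", str(pot_batch)]
    return argv


if __name__ == "__main__":
    # Print a baseline run followed by POT runs; the kappa values are only examples.
    for kappa in (None, 0.2, 0.3, 0.4):
        print(shlex.join(build_command("GRACE", "Cora", kappa=kappa)))
```

Each argv list can be passed to `subprocess.run(argv, check=True)` from the corresponding model directory (e.g. `COSTA/src` for COSTA). Keeping the arguments as a list rather than one string avoids shell-quoting issues.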

Results are appended to "res/{dataset_name}_base_temp.csv" (original model) and "res/{dataset_name}_pot_temp.csv" (with POT), respectively. You can also set the "save_file" parameter to choose a different output file.

We use minibatches to reduce memory usage; you can change the batch size in the code. To also use minibatches for POT, set "pot_batch". Values of 256, 512, or 1024 usually work:

# GCA with POT on BlogCatalog, using pot_batch
python train.py --dataset BlogCatalog --gpu_id 0 --use_pot --kappa 0.3 --pot_batch 1024
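Each run appends a row to "res/{dataset_name}_base_temp.csv" or "res/{dataset_name}_pot_temp.csv", so the latest baseline and POT results can be pulled out with a few lines of Python. This is a minimal sketch with a hypothetical helper, `last_rows`, and it makes the following assumptions:

- The "res" directory is relative to the directory you ran the model from.
- Nothing is assumed about the CSV columns, so it simply returns the last non-empty row of each file, or None if the file does not exist yet.
- It does not look at files written via "save_file".

```python
"""Show the most recent baseline and POT result rows for a dataset (sketch)."""
import csv
from pathlib import Path


def last_rows(dataset, res_dir="res"):
    """Return {"base": row_or_None, "pot": row_or_None} from the *_temp.csv files."""
    out = {}
    for tag in ("base", "pot"):
        path = Path(res_dir) / f"{dataset}_{tag}_temp.csv"
        row = None
        if path.exists():
            with path.open(newline="") as f:
                for r in csv.reader(f):
                    if r:  # skip blank lines; rows are appended, so the last one is newest
                        row = r
        out[tag] = row
    return out


if __name__ == "__main__":
    for tag, row in last_rows("Cora").items():
        print(tag, row)
```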