This is a PyTorch implementation of GraphMixup. The code includes the following modules:
- Dataset loaders (Cora, BlogCatalog, and Wiki-CS)
- Various architectures (GCN, SAGE, GAT, and SEM)
- Five compared baselines (Origin, Over-Sampling, Re-weight, SMOTE, and Embed-SMOTE)
- Training paradigms (joint learning, pre-training, and fine-tuning) for node classification on three datasets
- Visualization and evaluation metrics

Required packages:
- networkx==2.5
- numpy==1.19.2
- scikit-learn==0.24.1
- scipy==1.5.2
- torch==1.6.0

The code provides the following functions and layers:
- train() -- Train a new model for the node classification task on the Cora, BlogCatalog, and Wiki-CS datasets
- test() -- Test the learned model for the node classification task on the Cora, BlogCatalog, and Wiki-CS datasets
- save_model() -- Save the pre-trained model
- load_model() -- Load model for fine-tuning
- load_cora() -- Load Cora Dataset
- load_BlogCatalog() -- Load BlogCatalog Dataset
- load_wiki_cs() -- Load Wiki-CS Dataset
- GraphConvolution() -- GCN Layer
- SageConv() -- SAGE Layer
- SemanticLayer() -- Semantic Feature Layer
- GraphAttentionLayer() -- GAT Layer
- PairwiseDistance() -- Perform self-supervised Local-Path Prediction
- DistanceCluster() -- Perform self-supervised Global-Path Prediction
- src_upsample() -- Perform interpolation in the input space
- src_smote() -- Perform interpolation in the embedding space
- mixup() -- Perform mixup in the semantic relation space
- GNN_env() -- Calculate rewards, perform actions, and update states
- isTerminal() -- Determine whether the termination conditions have been met
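
The interpolation functions listed above (src_upsample() and src_smote()) are described as interpolating between samples, in the input space and the embedding space respectively. The following is a minimal NumPy sketch of one SMOTE-style interpolation step, written under stated assumptions rather than taken from the repository: `smote_interpolate`, its arguments, and the choice of the single nearest same-class neighbor are all hypothetical.

```python
import numpy as np


def smote_interpolate(features, labels, minority_class, num_new, seed=0):
    """Create num_new synthetic samples for minority_class (hypothetical helper).

    Each synthetic sample lies on the segment between a randomly chosen
    minority sample and its nearest same-class neighbor.
    """
    rng = np.random.default_rng(seed)
    minority = features[labels == minority_class]
    if len(minority) < 2:
        raise ValueError("need at least two minority samples to interpolate")

    # Pairwise Euclidean distances within the minority class.
    diff = minority[:, None, :] - minority[None, :, :]
    dist = np.linalg.norm(diff, axis=-1)
    np.fill_diagonal(dist, np.inf)  # a sample is not its own neighbor
    nearest = dist.argmin(axis=1)

    # Pick random anchor samples and random interpolation coefficients.
    anchors = rng.integers(0, len(minority), size=num_new)
    lam = rng.random((num_new, 1))  # coefficients in [0, 1)
    return minority[anchors] + lam * (minority[nearest[anchors]] - minority[anchors])


if __name__ == "__main__":
    x = np.array([[0.0, 0.0], [1.0, 0.0], [5.0, 5.0], [6.0, 5.0], [5.0, 6.0]])
    y = np.array([1, 1, 0, 0, 0])
    new_x = smote_interpolate(x, y, minority_class=1, num_new=3)
    print(new_x.shape)  # (3, 2)
```

Passing raw node features as `features` corresponds to interpolation in the input space; passing hidden representations from an encoder would correspond to interpolation in the embedding space. How the synthetic nodes are connected to the rest of the graph is outside the scope of this sketch.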

Install the required dependency packages. Then, to get results on a specific dataset, first run pre-training with suitable hyperparameters:
python --dataset data_name --setting pre-train
where data_name is one of the three datasets (Cora, BlogCatalog, and Wiki-CS). The pre-trained model is saved to the corresponding checkpoint folder under ./checkpoint for evaluation.
- To fine-tune the pre-trained model, run
python --dataset data_name --setting fine-tune --load model_path
where model_path is the path to the saved pre-trained model.
- We provide five compared baselines in this code. They can be selected via the '--setting' argument:
  - Origin: Vanilla backbone models with '--setting raw'
  - Over-Sampling: Repeat nodes in the minority classes with '--setting over-sampling'
  - Re-weight: Give samples from minority classes a larger weight when calculating the loss with '--setting re-weight'
  - SMOTE: Interpolation in the input space with '--setting smote'
  - Embed-SMOTE: Perform SMOTE in the intermediate embedding space with '--setting embed_smote'

For example, to run Embed-SMOTE:
python --dataset cora --setting embed_smote
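
The Re-weight baseline above gives minority-class samples a larger weight in the loss. One common way to choose such weights is inverse class frequency; the plain-Python sketch below illustrates that idea. It is an assumption made for illustration and not necessarily the scheme used by '--setting re-weight'; `inverse_frequency_weights` and `weighted_mean_loss` are hypothetical helpers.

```python
from collections import Counter


def inverse_frequency_weights(labels):
    """Return {class: weight} with weight = N / (num_classes * class_count)."""
    counts = Counter(labels)
    total = len(labels)
    num_classes = len(counts)
    return {cls: total / (num_classes * n) for cls, n in counts.items()}


def weighted_mean_loss(per_sample_losses, labels, weights):
    """Weighted average of per-sample losses, normalized by the total weight."""
    w = [weights[y] for y in labels]
    return sum(wi * li for wi, li in zip(w, per_sample_losses)) / sum(w)


labels = [0, 0, 0, 0, 0, 0, 1, 1]  # class 1 is the minority
weights = inverse_frequency_weights(labels)
print(weights)  # class 1 gets weight 2.0, class 0 gets about 0.67
```

In a PyTorch training loop, weights like these could be converted to a tensor ordered by class index and passed as the `weight` argument of `torch.nn.CrossEntropyLoss`, which applies the same weighted averaging.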
If you find this project useful for your research, please cite it using the following BibTeX entry:
title={Graphmixup: Improving class-imbalanced node classification on graphs by self-supervised context prediction},
author={Wu, Lirong and Lin, Haitao and Gao, Zhangyang and Tan, Cheng and Li, Stan and others},
journal={arXiv preprint arXiv:2106.11133},