/Boosting-multitask-learning-on-graphs

A boosting procedure for multitask learning on graph-structured data

Primary LanguagePythonMIT LicenseMIT

Overview

This repository provided an algorithm as a generic boosting procedure for multitask learning on graphs, which is presented at KDD'23, Long Beach. The algorithm clusters graph learning tasks into multiple groups and trains one graph neural network for each group. In the procedure, we first model higher-order task affinities by sampling random task subsets and evaluating multitask performances. Then, we find related task groupings through clustering task affinity scores.

Data Preparation

Community detection. We provide the datasets for conducting community detection named data.zip under the ./data/ folder used. Unzip the file under the folder, then one can directly load them in the code.

Molecule property prediction. We conduct experiments on graph multi-task learning datasets on molecule graph prediction tasks. Our code directly downloads the datasets inside the script. Please pre-install the ogb and torch-geometric packages.

Section 1: Overlapping Community detection

Use train_multitask.py for the experiments of training a GNN on community detection tasks. Please specify the following key parameters:

  • --dataset specifies the dataset. Please choose among amazon, youtube, dblp, and livejournal.
  • --model specifies the gnn model. We mainly used sign in our experiments.
  • --task_idxes specifies the indexes of tasks that the model is trained on. Use the numbers from 0 up to the number of tasks. Use space in between the indexes.
  • --save_name specifies the filename that saves the training results. Specify a name for the file if one will use the results later.

We show an example below that trains a SIGN model on the Youtube dataset:

python train_multitask.py --dataset youtube --feature_dim 128\
    --model sign --num_layers 3 --hidden_channels 256 --lr 0.01 --dropout 0.1 --mlp_layers 2\
    --evaluator f1_score --sample_method decoupling --batch_size 1000 --epochs 100 --device 2 --runs 3\
    --save_name test --task_idxes 0 1 2 3 4 

Use train_sample_tasks.py for sampling tasks and evaluating MTL performance on the trained models. Please specify the following key parameters.

  • --dataset specifies the dataset. Please choose among amazon, youtube, dblp, and livejournal.
  • --num_samples specifies the number of sampled subsets.
  • --min_task_num specifies the minimum number of tasks in a subset.
  • --max_task_num specifies the maximum number of tasks in a subset.
  • --task_set_name specifies the file name for saving the sampled subsets.
  • --save_name specifies the filename that saves the training results.

We show an example below that samples subsets of tasks on the Youtube dataset:

python train_sample_tasks.py --dataset youtube\
    --model sign --num_layers 3 --hidden_channels 256 --lr 0.01 --dropout 0.1\
    --evaluator f1_score --sample_method decoupling --batch_size 1000 --epochs 100 --device 2 --runs 1\
    --target_tasks none --num_samples 2000 --min_task_num 5 --max_task_num 5\
    --task_set_name sample_youtube --save_name sample_youtube

Lastly, we conduct clustering on task affinity scores and generate task indexes for each task group. See an example to generate task groupings in /notebooks/run_task_grouping.py.


Section 2: Molecule property prediction

Use train_multitask.py and change the --dataset to alchemy_full, QM9, or molpcba. The other parameters follow the ones used in community detection.

We show an example below to launch experiments on the alchemy, QM9, and ogb-molpcba datasets.

python train_multitask.py --dataset alchemy_full --model gine\
        --criterion regression --evaluator mae --hidden_channels 64 \
        --epochs 200 --downsample 1.0\
        --device 0 --runs 3 \
        --save_name test --task_idx 0 1 2 3 4

python train_multitask.py --dataset QM9 --model gine\
        --criterion regression --evaluator mae --hidden_channels 64 \
        --epochs 200 --downsample 1.0\
        --device 0 --runs 3 \
        --save_name test --task_idx 0 1 2 3 4

python train_multitask.py --dataset molpcba --model gine\
    --criterion multilabel --evaluator precision --hidden_channels 300 \
    --epochs 100 --downsample 1.0 --batch_size 32\
    --device 1 --runs 3 --mnt_mode max --eval_separate\
    --save_name test --task_idx 0 1 2 3 4

Use train_sample_tasks.py and change the --dataset to alchemy_full, QM9, or molpcba. The other parameters follow the ones used in community detection. For example:

python train_sample_tasks.py --dataset alchemy_full\
    --epochs 20 --downsample 0.2 --device 3\
    --num_samples 200 --min_task_num 4 --max_task_num 4\
    --task_set_name sample_alchemy --save_name sample_alchemy

Requirements

Please install the requirements before launching the experiments:

pip install -r requirements.txt

We list the key packages used in our code:

  • python>=3.6
  • torch>=1.10.0
  • torch-geometric>=2.0.3
  • pytorch-lightning>1.5.10
  • torchmetrics>=0.8.2
  • ogb>=1.3.4

Citation

If you find this repository useful or happen to use it in a research paper, please cite our work with the following bib information.

@article{li2023boosting,
  title={Boosting Multitask Learning on Graphs through Higher-Order Task Affinities},
  author={Li, Dongyue and Ju, Haotian and Sharma, Aneesh and Zhang, Hongyang R},
  journal={SIGKDD Conference on Knowledge Discovery and Data Mining},
  year={2023}
}

Acknowledgment

Thanks to the authors of the following repositories for providing their implementation publicly available, which greatly helps us develop this code.