ProG (Prompt Graph) is a library built upon PyTorch that makes it easy to conduct single-task or multi-task prompting for pre-trained Graph Neural Networks (GNNs). The idea is derived from the paper: Xiangguo Sun, Hong Cheng, Jia Li, et al. All in One: Multi-Task Prompting for Graph Neural Networks. KDD 2023, for which the authors also released the raw code. This repository is a polished version of that raw code with extensive changes and updates. Package dependencies:
- PyTorch 1.13.1
- torchmetrics 0.11.4
- torch_geometric 2.2.0
The following code presents a simple example of how to pre-train a GNN model via GraphCL. An integrated function, pretrain(), is also available in no_meta_demo.py.
from ProG.utils import mkdir, load_data4pretrain
from ProG import PreTrain
mkdir('./pre_trained_gnn/')
pretext = 'GraphCL' # 'GraphCL', 'SimGRACE'
gnn_type = 'TransformerConv' # 'GAT', 'GCN'
dataname, num_parts, batch_size = 'CiteSeer', 200, 10
print("load data...")
graph_list, input_dim, hid_dim = load_data4pretrain(dataname, num_parts)
print("create PreTrain instance...")
pt = PreTrain(pretext, gnn_type, input_dim, hid_dim, gln=2)
print("pre-training...")
pt.train(dataname, graph_list, batch_size=batch_size,
         aug1='dropN', aug2="permE", aug_ratio=None,
         lr=0.01, decay=0.0001, epochs=100)
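The `aug1` and `aug2` arguments above select the two graph augmentations that produce the contrastive views. Assuming `dropN` and `permE` denote GraphCL-style node dropping and edge perturbation, the idea can be sketched in plain Python on an edge list. The helper names `drop_nodes` and `perturb_edges` are hypothetical and not part of ProG; the library's own implementation may differ in detail.

```python
import random


def drop_nodes(num_nodes, edges, ratio, rng):
    """Node dropping (assumed meaning of 'dropN'): remove a random
    fraction of nodes together with every edge that touches them."""
    num_drop = int(num_nodes * ratio)
    dropped = set(rng.sample(range(num_nodes), num_drop))
    kept = [(u, v) for u, v in edges if u not in dropped and v not in dropped]
    return dropped, kept


def perturb_edges(num_nodes, edges, ratio, rng):
    """Edge perturbation (assumed meaning of 'permE'): replace a random
    fraction of edges with randomly drawn node pairs.
    Self-loops and duplicates are not filtered in this sketch."""
    num_perm = int(len(edges) * ratio)
    removed = set(rng.sample(range(len(edges)), num_perm))
    kept = [e for i, e in enumerate(edges) if i not in removed]
    added = [(rng.randrange(num_nodes), rng.randrange(num_nodes))
             for _ in range(num_perm)]
    return kept + added


# Two augmented views of the same 5-node ring graph.
rng = random.Random(0)
edges = [(0, 1), (1, 2), (2, 3), (3, 4), (4, 0)]
dropped, view1 = drop_nodes(5, edges, 0.2, rng)
view2 = perturb_edges(5, edges, 0.4, rng)
```

A contrastive objective then pulls the embeddings of the two views of one graph together and pushes views of different graphs apart.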
from ProG.prompt import GNN, LightPrompt
from torch import nn, optim
import torch
# load pre-trained GNN
gnn = GNN(100, hid_dim=100, out_dim=100, gcn_layer_num=2, gnn_type="TransformerConv")
pre_train_path = './pre_trained_gnn/{}.GraphCL.{}.pth'.format("CiteSeer", "TransformerConv")
gnn.load_state_dict(torch.load(pre_train_path))
print("successfully load pre-trained weights for gnn! @ {}".format(pre_train_path))
for p in gnn.parameters():
    p.requires_grad = False
# prompt with hand-crafted answering template (no answering head tuning)
PG = LightPrompt(token_dim=100, token_num_per_group=100, group_num=6, inner_prune=0.01)
opi = optim.Adam(filter(lambda p: p.requires_grad, PG.parameters()),
                 lr=0.001, weight_decay=0.00001)
lossfn = nn.CrossEntropyLoss(reduction='mean')
The above code is also integrated into this project as the function model_create(dataname, gnn_type, num_class, task_type).
from ProG.data import multi_class_NIG
import torch
train, test, _, _ = multi_class_NIG(dataname, num_class)
# model_create is the project function described above
gnn, PG, opi, lossfn, _, _ = model_create(dataname, gnn_type, num_class, task_type)
prompt_epoch = 200
# training stage
PG.train()
emb0 = gnn(train.x, train.edge_index, train.batch)
for j in range(prompt_epoch):
    pg_batch = PG.inner_structure_update()
    pg_emb = gnn(pg_batch.x, pg_batch.edge_index, pg_batch.batch)
    dot = torch.mm(emb0, torch.transpose(pg_emb, 0, 1))
    sim = torch.softmax(dot, dim=1)
    train_loss = lossfn(sim, train.y)
    print('{}/{} training loss: {:.8f}'.format(j, prompt_epoch, train_loss.item()))
    opi.zero_grad()
    train_loss.backward()
    opi.step()
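The loop above scores each training embedding against the prompt embeddings and treats the result as class scores. A minimal plain-Python sketch of that scoring step follows, assuming each row of `pg_emb` corresponds to one class. The names `similarity_scores` and `cross_entropy` are hypothetical helpers for illustration. The sketch passes raw dot products to the loss, since PyTorch's `nn.CrossEntropyLoss` expects unnormalized scores and applies log-softmax internally.

```python
import math


def similarity_scores(node_embs, group_embs):
    """Dot product of every node embedding with every prompt-group
    embedding; the result has shape [num_nodes, num_groups]."""
    return [[sum(a * b for a, b in zip(n, g)) for g in group_embs]
            for n in node_embs]


def cross_entropy(scores, labels):
    """Mean cross-entropy over rows of unnormalized scores,
    computed with a numerically stable log-sum-exp."""
    total = 0.0
    for row, y in zip(scores, labels):
        m = max(row)
        log_z = m + math.log(sum(math.exp(s - m) for s in row))
        total += log_z - row[y]
    return total / len(labels)


# Toy data: two samples, two prompt groups (one per class, by assumption).
node_embs = [[1.0, 0.0], [0.0, 1.0]]
group_embs = [[2.0, 0.0], [0.0, 2.0]]
labels = [0, 1]

scores = similarity_scores(node_embs, group_embs)
loss = cross_entropy(scores, labels)
# The predicted class is the index of the most similar prompt group.
preds = [max(range(len(row)), key=row.__getitem__) for row in scores]
```

At test time, the same argmax over similarity scores gives the predicted class without any separate answering head.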
For more detailed usage examples, such as prompting with answer tuning and prompting with meta-learning, please check the demos in:
- no_meta_demo.py
- meta_demo.py
Multi-class node classification (100-shot)
                      |      CiteSeer     |
                      |  ACC  | Macro-F1  |
==========================================|
reported in the paper | 80.50 |   80.05   |
(Prompt)              |                   |
------------------------------------------|
this version code     | 81.00 |   81.23   |
(Prompt)              |   (run one time)  |
==========================================|
reported in the paper | 80.00 |   80.05   |
(Prompt w/o h)        |                   |
------------------------------------------|
this version code     | 79.78 |   80.01   |
(Prompt w/o h)        |   (run one time)  |
==========================================|
Note:
- The comparison uses the same pre-trained weights (.pth file). The final results may vary depending on the pre-training state.
- The above table is copied from the project history: https://github.com/sheldonresearch/ProG/blob/main/History.md#13-jul-2023
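For readers unfamiliar with the two columns in the table, the following plain-Python sketch shows how accuracy and macro-averaged F1 are defined. It is only an illustration of the metrics; the project lists torchmetrics among its dependencies and may compute them differently. The labels below are made up and are not taken from CiteSeer.

```python
def accuracy(y_true, y_pred):
    """Fraction of predictions that match the true label."""
    return sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)


def macro_f1(y_true, y_pred, num_classes):
    """Unweighted mean of the per-class F1 scores."""
    f1_scores = []
    for c in range(num_classes):
        tp = sum(t == c and p == c for t, p in zip(y_true, y_pred))
        fp = sum(t != c and p == c for t, p in zip(y_true, y_pred))
        fn = sum(t == c and p != c for t, p in zip(y_true, y_pred))
        denom = 2 * tp + fp + fn
        # A class with no true and no predicted samples contributes 0.
        f1_scores.append(2 * tp / denom if denom else 0.0)
    return sum(f1_scores) / num_classes


# Made-up labels for a 3-class problem.
y_true = [0, 0, 1, 1, 2, 2]
y_pred = [0, 1, 1, 1, 2, 0]
acc = accuracy(y_true, y_pred)
f1 = macro_f1(y_true, y_pred, num_classes=3)
```

Because macro-F1 weights every class equally, it can differ noticeably from accuracy when classes are imbalanced.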
BibTeX:
@inproceedings{sun2023all,
title={All in One: Multi-Task Prompting for Graph Neural Networks},
author={Sun, Xiangguo and Cheng, Hong and Li, Jia and Liu, Bo and Guan, Jihong},
booktitle={Proceedings of the 26th ACM SIGKDD international conference on knowledge discovery \& data mining (KDD'23)},
year={2023}
}
- For more information or further discussion, contact: Website
- Email: xiangguosun at cuhk dot edu dot hk