How to reproduce the results of baselines on OGBG?
ha-lins opened this issue · 2 comments
Hi @susheels
Thanks for your great work. I have a minor request: could you please release the code for the baselines (e.g., GraphCL) on OGBG? I think it's a bit difficult to adapt test_minimax_ogbg.py
directly. It would be really helpful if you could release them. Thanks a lot!
Hi, actually all the relevant code for say implementing GraphCL is already present in the unsupervised package. You can actually think of the published adgcl codebase as a library for testing out all the baselines.
All you need is a driver train/eval script calling the right learning algorithm (be it GraphCL, InfoGraph). I have included the file for running GraphCL baseline on ogbg for your reference. I will also add these example scripts into the main codebase once I get some time.
A few things to note:
- The unsupervised.learning package has a module gsimclr which implements the GSimCLR class - that's the learning algorithm behind GraphCL.
- Because GraphCL performs stochastic, non-learnable data augmentation, we need to provide it. This is done, again in the unsupervised package, by the MyAugTransformer class in the utils module.
- With these two main ingredients you can train GraphCL; be sure to check out the arguments/config.
Hope this helps!
import argparse
import logging
import torch
import numpy as np
import random
from unsupervised.encoder import MoleculeEncoder
from unsupervised.embedding_evaluation import EmbeddingEvaluation
from unsupervised.learning import GSimCLR
from unsupervised.utils import MyAugTransformer
from sklearn.linear_model import LogisticRegression, Ridge
from ogb.graphproppred import PygGraphPropPredDataset
from torch_geometric.data import DataLoader
from ogb.graphproppred import Evaluator
def setup_seed(seed):
    """Seed every random number generator this script relies on.

    Seeds PyTorch (CPU and all CUDA devices), NumPy and Python's ``random``
    module with the same value, and switches cuDNN to deterministic mode.

    Args:
        seed: Integer seed applied to all generators.
    """
    seeders = (
        torch.manual_seed,
        torch.cuda.manual_seed_all,
        np.random.seed,
        random.seed,
    )
    for apply_seed in seeders:
        apply_seed(seed)
    torch.backends.cudnn.deterministic = True
def _train_one_epoch(model, dataloader, optimizer, device):
    """Run one contrastive training pass and return the mean loss per graph.

    Each item yielded by ``dataloader`` is unpacked into an original batch and
    its augmented counterpart; both are encoded by ``model`` and compared with
    ``model.calc_loss``.

    Args:
        model: Module exposing ``__call__(batch, x, edge_index, edge_attr)``
            and ``calc_loss(x, x_aug)``.
        dataloader: Iterable yielding ``(batch, batch_aug)`` pairs.
        optimizer: Optimizer stepping the parameters of ``model``.
        device: Device the batches are moved to before the forward pass.

    Returns:
        The batch losses weighted by ``num_graphs`` and divided by the total
        number of graphs seen during the pass.
    """
    model.train()
    loss_sum = 0.0
    graphs_seen = 0
    for batch, batch_aug in dataloader:
        optimizer.zero_grad()
        batch = batch.to(device)
        batch_aug = batch_aug.to(device)
        x = model(batch.batch, batch.x, batch.edge_index, batch.edge_attr)
        x_aug = model(batch_aug.batch, batch_aug.x, batch_aug.edge_index, batch_aug.edge_attr)
        loss = model.calc_loss(x, x_aug)
        loss.backward()
        optimizer.step()
        # Weight by graph count so a smaller final batch does not skew the mean.
        loss_sum += loss.item() * batch.num_graphs
        graphs_seen += batch.num_graphs
    # Normalise by graphs, not batches: the sum above is weighted per graph.
    return loss_sum / max(graphs_seen, 1)


def run(args):
    """Train GSimCLR (GraphCL) on an OGB graph dataset and evaluate its embeddings.

    The encoder is trained on augmented graph pairs; before training and after
    every epoch the encoder embeddings are scored on the train/valid/test
    split through ``EmbeddingEvaluation``. The reported epoch is the one with
    the highest validation score for classification tasks and the lowest one
    otherwise.

    Args:
        args: Namespace providing ``log_file_name``, ``seed``, ``dataset``,
            ``aug_type``, ``aug_ratio``, ``batch_size``, ``emb_dim``,
            ``num_gc_layers``, ``drop_ratio``, ``pooling_type``, ``lr`` and
            ``epochs``.

    Returns:
        Tuple ``(validation_score, test_score)`` taken at the selected epoch.

    Raises:
        ValueError: If ``args.epochs`` is smaller than 1, since no epoch could
            be selected.
        NotImplementedError: If the dataset task type is neither a
            classification nor a regression task.
    """
    # Fail fast: with no epochs the score curves stay empty and the final
    # argmax/argmin would fail with an opaque error after all the setup work.
    if args.epochs < 1:
        raise ValueError("epochs must be at least 1, got {}".format(args.epochs))

    if args.log_file_name is not None:
        logging.basicConfig(filename=args.log_file_name, filemode='w', level=logging.INFO, format='%(asctime)s - %(message)s', datefmt='%d-%b-%y %H:%M:%S')
    else:
        logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(message)s', datefmt='%d-%b-%y %H:%M:%S')

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    logging.info("Using Device: %s", device)
    logging.info("Seed: %d", args.seed)
    logging.info(args)
    setup_seed(args.seed)

    evaluator = Evaluator(name=args.dataset)

    # Datasets are stored under './original_datasets/'. The augmented view is
    # used for contrastive training, the untransformed one for evaluation.
    my_transform = MyAugTransformer(type_code=args.aug_type, aug_ratio=args.aug_ratio)
    dataset = PygGraphPropPredDataset(name=args.dataset, root='./original_datasets/', transform=my_transform)
    dataset_eval = PygGraphPropPredDataset(name=args.dataset, root='./original_datasets/')
    split_idx = dataset_eval.get_idx_split()

    train_loader = DataLoader(dataset_eval[split_idx["train"]], batch_size=128, shuffle=True)
    valid_loader = DataLoader(dataset_eval[split_idx["valid"]], batch_size=128, shuffle=False)
    test_loader = DataLoader(dataset_eval[split_idx["test"]], batch_size=128, shuffle=False)

    # Unsupervised training iterates over the whole dataset, not just the train split.
    dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True)

    model = GSimCLR(MoleculeEncoder(emb_dim=args.emb_dim, num_gc_layers=args.num_gc_layers, drop_ratio=args.drop_ratio, pooling_type=args.pooling_type),
                    proj_hidden_dim=args.emb_dim).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

    is_classification = 'classification' in dataset.task_type
    if is_classification:
        probe = LogisticRegression(dual=False, fit_intercept=True, max_iter=5000)
    elif 'regression' in dataset.task_type:
        # NOTE(review): `normalize` was removed from Ridge in scikit-learn 1.2;
        # confirm the pinned scikit-learn version still accepts it.
        probe = Ridge(fit_intercept=True, normalize=True, copy_X=True, max_iter=5000)
    else:
        raise NotImplementedError
    ee = EmbeddingEvaluation(probe, evaluator, dataset.task_type, dataset.num_tasks, device, params_dict=None,
                             param_search=True)

    model.eval()
    train_score, val_score, test_score = ee.embedding_evaluation(model.encoder, train_loader, valid_loader, test_loader)
    logging.info("Before training Embedding Eval Ridge Scores: Train: {} Val: {} Test: {}".format(train_score, val_score, test_score))

    valid_curve = []
    test_curve = []
    train_curve = []

    for epoch in range(1, args.epochs + 1):
        fin_loss = _train_one_epoch(model, dataloader, optimizer, device)
        logging.info('Epoch {}, Unsupervised Loss {}'.format(epoch, fin_loss))

        model.eval()
        train_score, val_score, test_score = ee.embedding_evaluation(model.encoder, train_loader, valid_loader,
                                                                     test_loader)
        logging.info(" Train: {} Val: {} Test: {}".format(train_score, val_score, test_score))

        train_curve.append(train_score)
        valid_curve.append(val_score)
        test_curve.append(test_score)

    # Classification scores are maximised; anything else is treated as an
    # error measure and minimised.
    if is_classification:
        best_val_epoch = np.argmax(np.array(valid_curve))
        best_train = max(train_curve)
    else:
        best_val_epoch = np.argmin(np.array(valid_curve))
        best_train = min(train_curve)

    logging.info('FinishedTraining!')
    # best_val_epoch is a zero-based index into the curves, whereas the
    # per-epoch log lines above count from 1.
    logging.info('BestEpoch: {}'.format(best_val_epoch))
    logging.info('BestTrainScore: {}'.format(best_train))
    logging.info('BestValidationScore: {}'.format(valid_curve[best_val_epoch]))
    logging.info('FinalTestScore: {}'.format(test_curve[best_val_epoch]))

    return valid_curve[best_val_epoch], test_curve[best_val_epoch]
def arg_parse():
    """Parse the command-line options of the GSimCLR OGB script.

    Returns:
        The ``argparse.Namespace`` produced from ``sys.argv`` with the
        attributes ``dataset``, ``lr``, ``num_gc_layers``, ``pooling_type``,
        ``emb_dim``, ``batch_size``, ``drop_ratio``, ``epochs``, ``aug_type``,
        ``aug_ratio``, ``seed`` and ``log_file_name``.
    """
    # (flag, add_argument keyword arguments), in the order shown by --help.
    option_specs = (
        ('--dataset', dict(type=str, default='ogbg-molesol', help='Dataset')),
        ('--lr', dict(type=float, default=0.001, help='Learning rate.')),
        ('--num_gc_layers', dict(type=int, default=5,
                                 help='Number of GNN layers before pooling')),
        ('--pooling_type', dict(type=str, default='standard',
                                help='GNN Pooling Type Standard/Layerwise')),
        ('--emb_dim', dict(type=int, default=300, help='embedding dimension')),
        ('--batch_size', dict(type=int, default=32, help='batch size')),
        ('--drop_ratio', dict(type=float, default=0.0,
                              help='Dropout Ratio / Probability')),
        ('--epochs', dict(type=int, default=20, help='Train Epochs')),
        ('--aug_type', dict(type=str, default='dnodes+subgraph',
                            help='Augmentation Type')),
        ('--aug_ratio', dict(type=float, default=0.2,
                             help='Augmentation Drop Nodes/Subgraph/Edges Ratio')),
        ('--seed', dict(type=int, default=0)),
        ('--log_file_name', dict(type=str, default=None)),
    )

    cli = argparse.ArgumentParser(description='Gsimclr ogbg-mol*')
    for flag, spec in option_specs:
        cli.add_argument(flag, **spec)
    return cli.parse_args()
if __name__ == '__main__':
    # Script entry point: parse the CLI options and launch the experiment.
    run(arg_parse())
Thank you for your helpful reply. Your code style is very elegant and clear. I wish you a Happy New Year in advance! : )