susheels/adgcl

How to reproduce the results of baselines on OGBG?

ha-lins opened this issue · 2 comments

Hi @susheels

Thanks for your great work. I have a minor request: could you please release the code of the baselines (e.g., GraphCL) for OGBG? I think it's a bit difficult to adapt test_minimax_ogbg.py directly. It would be really helpful if you could release them. Thanks a lot!

Hi, actually all the relevant code for implementing, say, GraphCL is already present in the unsupervised package. You can think of the published adgcl codebase as a library for testing out all the baselines.

All you need is a driver train/eval script that calls the right learning algorithm (be it GraphCL or InfoGraph). I have included the file for running the GraphCL baseline on OGBG below for your reference. I will also add these example scripts to the main codebase once I get some time.
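
For example, assuming the script below is saved as gsimclr_ogbg.py (the filename here is just for illustration), it could be invoked as:

	python gsimclr_ogbg.py --dataset ogbg-molhiv --aug_type dnodes+subgraph --aug_ratio 0.2 --epochs 20 --seed 0

All of these flags are defined in the argument parser at the bottom of the script.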

A few things to note:

  1. The unsupervised.learning package has a module gsimclr which implements the GSimCLR class; that's the learning algorithm behind GraphCL (a minimal sketch of its contrastive loss follows this list).
  2. Because GraphCL performs stochastic, non-learnable data augmentation, we need to provide it ourselves. This is again done in the unsupervised package, via the MyAugTransformer class in the utils module (see the sketch after the script below).
  3. With these two main ingredients you can train GraphCL; be sure to check out the arguments/config.
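
As a rough guide, here is a minimal sketch of the NT-Xent-style contrastive objective that GSimCLR.calc_loss computes between the two views; the actual implementation in unsupervised.learning may differ in details such as the temperature value and how self-similarities are handled:

import torch
import torch.nn.functional as F

def nt_xent_loss(x, x_aug, temperature=0.2):
	# x, x_aug: [batch_size, emb_dim] embeddings of the original and augmented views.
	x = F.normalize(x, dim=1)
	x_aug = F.normalize(x_aug, dim=1)
	# Cosine similarity of every original/augmented pair, scaled by temperature.
	sim = torch.mm(x, x_aug.t()) / temperature
	# The diagonal holds the positive pairs; all other entries act as negatives.
	labels = torch.arange(x.size(0), device=x.device)
	return F.cross_entropy(sim, labels)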

Hope this helps!


import argparse
import logging
import torch
import numpy as np
import random

from unsupervised.encoder import MoleculeEncoder
from unsupervised.embedding_evaluation import EmbeddingEvaluation
from unsupervised.learning import GSimCLR
from unsupervised.utils import MyAugTransformer


from sklearn.linear_model import LogisticRegression, Ridge

from ogb.graphproppred import PygGraphPropPredDataset
from torch_geometric.data import DataLoader  # note: moved to torch_geometric.loader in PyG >= 2.0
from ogb.graphproppred import Evaluator

def setup_seed(seed):
	# Fix all random seeds (torch, CUDA, numpy, python) for reproducibility.
	torch.manual_seed(seed)
	torch.cuda.manual_seed_all(seed)
	torch.backends.cudnn.deterministic = True
	np.random.seed(seed)
	random.seed(seed)

def run(args):
	if args.log_file_name is not None:
		logging.basicConfig(filename=args.log_file_name, filemode='w', level=logging.INFO, format='%(asctime)s - %(message)s', datefmt='%d-%b-%y %H:%M:%S')
	else:
		logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(message)s', datefmt='%d-%b-%y %H:%M:%S')
	device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
	logging.info("Using Device: %s" % device)
	logging.info("Seed: %d" % args.seed)
	logging.info(args)

	setup_seed(args.seed)

	evaluator = Evaluator(name=args.dataset)
	# Download and process the OGB dataset under './original_datasets/'.
	my_transform = MyAugTransformer(type_code=args.aug_type, aug_ratio=args.aug_ratio)

	dataset = PygGraphPropPredDataset(name=args.dataset, root='./original_datasets/', transform=my_transform)
	# Separate, un-augmented copy of the dataset for downstream evaluation.
	dataset_eval = PygGraphPropPredDataset(name=args.dataset, root='./original_datasets/')
	split_idx = dataset_eval.get_idx_split()
	train_loader = DataLoader(dataset_eval[split_idx["train"]], batch_size=128, shuffle=True)
	valid_loader = DataLoader(dataset_eval[split_idx["valid"]], batch_size=128, shuffle=False)
	test_loader = DataLoader(dataset_eval[split_idx["test"]], batch_size=128, shuffle=False)

	# Loader over the augmented dataset for contrastive pre-training.
	dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True)
	model = GSimCLR(MoleculeEncoder(emb_dim=args.emb_dim, num_gc_layers=args.num_gc_layers, drop_ratio=args.drop_ratio, pooling_type=args.pooling_type),
	                proj_hidden_dim=args.emb_dim).to(device)
	# print(model)
	optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

	# Downstream evaluation: fit a linear model on the frozen graph embeddings.
	if 'classification' in dataset.task_type:
		ee = EmbeddingEvaluation(LogisticRegression(dual=False, fit_intercept=True, max_iter=5000),
		                         evaluator, dataset.task_type, dataset.num_tasks, device, params_dict=None,
		                         param_search=True)
	elif 'regression' in dataset.task_type:
		# Note: Ridge's `normalize` argument was removed in scikit-learn 1.2;
		# on newer versions, standardize the features explicitly instead.
		ee = EmbeddingEvaluation(Ridge(fit_intercept=True, normalize=True, copy_X=True, max_iter=5000),
		                         evaluator, dataset.task_type, dataset.num_tasks, device, params_dict=None,
		                         param_search=True)
	else:
		raise NotImplementedError

	model.eval()
	train_score, val_score, test_score = ee.embedding_evaluation(model.encoder, train_loader, valid_loader, test_loader)
	logging.info("Before training Embedding Eval Ridge Scores: Train: {} Val: {} Test: {}".format(train_score, val_score, test_score))

	valid_curve = []
	test_curve = []
	train_curve = []

	for epoch in range(1, args.epochs + 1):
		loss_all = 0
		model.train()
		for data in dataloader:
			# Each element is an (original, augmented) pair produced by MyAugTransformer.
			batch, batch_aug = data
			optimizer.zero_grad()

			batch = batch.to(device)

			x = model(batch.batch, batch.x, batch.edge_index, batch.edge_attr)

			batch_aug = batch_aug.to(device)

			x_aug = model(batch_aug.batch, batch_aug.x, batch_aug.edge_index, batch_aug.edge_attr)

			# Contrastive loss between the embeddings of the two views.
			loss = model.calc_loss(x, x_aug)
			loss_all += loss.item() * batch.num_graphs
			loss.backward()
			optimizer.step()

		# loss_all is weighted by num_graphs, so normalize by the dataset size
		# to get the average per-graph loss (not by the number of batches).
		fin_loss = loss_all / len(dataloader.dataset)

		logging.info('Epoch {}, Unsupervised Loss {}'.format(epoch, fin_loss))

		model.eval()

		train_score, val_score, test_score = ee.embedding_evaluation(model.encoder, train_loader, valid_loader,
		                                                             test_loader)

		logging.info(" Train: {} Val: {} Test: {}".format(train_score, val_score, test_score))

		train_curve.append(train_score)
		valid_curve.append(val_score)
		test_curve.append(test_score)

	if 'classification' in dataset.task_type:
		# Classification metrics (e.g., ROC-AUC) are higher-is-better.
		best_val_epoch = np.argmax(np.array(valid_curve))
		best_train = max(train_curve)
	else:
		# Regression metrics (e.g., RMSE) are lower-is-better.
		best_val_epoch = np.argmin(np.array(valid_curve))
		best_train = min(train_curve)

	logging.info('Finished Training!')
	logging.info('Best Epoch: {}'.format(best_val_epoch))
	logging.info('Best Train Score: {}'.format(best_train))
	logging.info('Best Validation Score: {}'.format(valid_curve[best_val_epoch]))
	logging.info('Final Test Score: {}'.format(test_curve[best_val_epoch]))

	return valid_curve[best_val_epoch], test_curve[best_val_epoch]

def arg_parse():
	parser = argparse.ArgumentParser(description='GSimCLR on ogbg-mol* datasets')

	parser.add_argument('--dataset', type=str, default='ogbg-molesol',
	                    help='Dataset')

	parser.add_argument('--lr', type=float, default=0.001,
	                    help='Learning rate.')
	parser.add_argument('--num_gc_layers', type=int, default=5,
	                    help='Number of GNN layers before pooling')
	parser.add_argument('--pooling_type', type=str, default='standard',
	                    help='GNN Pooling Type Standard/Layerwise')
	parser.add_argument('--emb_dim', type=int, default=300,
	                    help='embedding dimension')
	parser.add_argument('--batch_size', type=int, default=32,
	                    help='batch size')

	parser.add_argument('--drop_ratio', type=float, default=0.0,
	                    help='Dropout Ratio / Probability')

	parser.add_argument('--epochs', type=int, default=20,
	                    help='Train Epochs')

	parser.add_argument('--aug_type', type=str, default='dnodes+subgraph',
	                    help='Augmentation Type')

	parser.add_argument('--aug_ratio', type=float, default=0.2,
	                    help='Augmentation Drop Nodes/Subgraph/Edges Ratio')

	parser.add_argument('--seed', type=int, default=0)

	parser.add_argument('--log_file_name', type=str, default=None)

	return parser.parse_args()

if __name__ == '__main__':
	args = arg_parse()
	run(args)
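
For intuition about item 2 above, here is a minimal, hypothetical sketch of a node-dropping transform in the spirit of MyAugTransformer. The real class in unsupervised.utils supports several augmentation types (the aug_type codes above) and its exact interface may differ; the key point is that it returns an (original, augmented) pair, which is what the training loop unpacks:

import torch
from torch_geometric.utils import subgraph

class DropNodesTransform:
	# Hypothetical sketch: drops a random fraction of nodes and returns
	# the (original, augmented) pair, like MyAugTransformer does.
	def __init__(self, aug_ratio=0.2):
		self.aug_ratio = aug_ratio

	def __call__(self, data):
		num_nodes = data.num_nodes
		num_keep = max(1, int(num_nodes * (1 - self.aug_ratio)))
		keep = torch.randperm(num_nodes)[:num_keep]
		aug = data.clone()
		# Induced subgraph on the kept nodes, with node indices relabelled.
		aug.edge_index, aug.edge_attr = subgraph(
			keep, data.edge_index, data.edge_attr,
			relabel_nodes=True, num_nodes=num_nodes)
		aug.x = data.x[keep]
		return data, aug

PyG's default collate handles such pairs element-wise, which is why the training loop above can unpack each batch as (batch, batch_aug).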

Thank you for your helpful reply. Your code style is very elegant and clear. I wish you a Happy New Year in advance! : )