divelab/DIG

The reproduction results on MD17 do not match

xnuohz opened this issue · 1 comment

Hello, thanks for your nice work. I tried to reproduce the results of SphereNet on MD17 by following your tutorial and the parameters listed in Appendix D, Table 8 of the paper, but they do not match:

  • Your model in the tutorial
mae: uracil 0.2453906238079071
  • Trained by myself from scratch
mae: uracil 4.0434064865112305

Below is my config. What is the best combination of parameters from Table 8? I would appreciate any suggestions!
Must batch_size be 1? Training was slow with it, so I changed it to 32.

import argparse
import torch

from dig.threedgraph.dataset import MD17
from dig.threedgraph.method import SphereNet
from dig.threedgraph.evaluation import ThreeDEvaluator
from pipline_on_md17 import run


parser = argparse.ArgumentParser(description='MD17 SphereNet')
parser.add_argument('--device', type=int, default=0)

parser.add_argument('--cutoff', type=float, default=5.0)
parser.add_argument('--num_layers', type=int, default=4)
parser.add_argument('--hidden_channels', type=int, default=128)
parser.add_argument('--out_channels', type=int, default=1)
parser.add_argument('--int_emb_size', type=int, default=64)
parser.add_argument('--basis_emb_size_dist', type=int, default=8)
parser.add_argument('--basis_emb_size_angle', type=int, default=8)
parser.add_argument('--basis_emb_size_torsion', type=int, default=8)
parser.add_argument('--out_emb_channels', type=int, default=256)
parser.add_argument('--num_spherical', type=int, default=3)
parser.add_argument('--num_radial', type=int, default=6)

parser.add_argument('--eval_steps', type=int, default=50) # 50
parser.add_argument('--eval_start', type=int, default=200) # 200
parser.add_argument('--epochs', type=int, default=1000) # 500,[1000],2000
parser.add_argument('--batch_size', type=int, default=32) # [1],2,4,16,32
parser.add_argument('--vt_batch_size', type=int, default=448) # can try 64/128/256 based on the memory of your device
parser.add_argument('--lr', type=float, default=0.0005)
parser.add_argument('--lr_decay_factor', type=float, default=0.5)
parser.add_argument('--lr_decay_step_size', type=int, default=200)

parser.add_argument('--p', type=int, default=100)

parser.add_argument('--save_dir', type=str, default='my_md17_save')
parser.add_argument('--log_dir', type=str, default='my_md17_log')

args = parser.parse_args()
print(args)

device = f'cuda:{args.device}' if torch.cuda.is_available() else 'cpu'
device = torch.device(device)
print('device', device)

model = SphereNet(energy_and_force=True, cutoff=args.cutoff, num_layers=args.num_layers, 
        hidden_channels=args.hidden_channels, out_channels=args.out_channels, int_emb_size=args.int_emb_size, 
        basis_emb_size_dist=args.basis_emb_size_dist, basis_emb_size_angle=args.basis_emb_size_angle, basis_emb_size_torsion=args.basis_emb_size_torsion, out_emb_channels=args.out_emb_channels, 
        num_spherical=args.num_spherical, num_radial=args.num_radial, envelope_exponent=5, 
        num_before_skip=1, num_after_skip=2, num_output_layers=3 
        )
model = model.to(device)

# Define model, loss, and evaluation
loss_func = torch.nn.L1Loss()
evaluation = ThreeDEvaluator()

for data_name in ['aspirin', 'benzene2017', 'ethanol', 'malonaldehyde', 'naphthalene', 'salicylic', 'toluene', 'uracil']:
    if data_name == 'uracil':
        dataset = MD17(root='dataset/', name=data_name)
        split_idx = dataset.get_idx_split(len(dataset.data.y), train_size=1000, valid_size=1000, seed=42)
        train_dataset, valid_dataset, test_dataset = dataset[split_idx['train']], dataset[split_idx['valid']], dataset[split_idx['test']]
        print('train, validation, test:', data_name, len(train_dataset), len(valid_dataset), len(test_dataset))
        
        # Train and evaluate
        run3d = run()
        run3d.run(device, train_dataset, valid_dataset, test_dataset, model, loss_func, evaluation, 
                epochs=args.epochs, batch_size=args.batch_size, vt_batch_size=args.vt_batch_size, 
                lr=args.lr, lr_decay_factor=args.lr_decay_factor, lr_decay_step_size=args.lr_decay_step_size, energy_and_force=True, p=args.p,
                save_dir=args.save_dir, log_dir=args.log_dir, eval_steps=args.eval_steps, eval_start=args.eval_start)

Hi @xnuohz,

Thank you for your interest in our work. Please follow the hyperparameters in the tutorial. Based on my experiments, batch_size, lr, and lr_decay_step_size are the most important hyperparameters, especially batch_size: batch_size=4 or 1 works better than 32. One possible reason is that the training set is very small (most methods use 1000 or 950 training samples), so a smaller batch size may help find a better local optimum.
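As a rough sanity check (not from the original thread, just an illustration using the train_size=1000 split from the script above), the batch size directly controls how many gradient updates the optimizer takes per epoch, which is one way to see why batch_size=1 or 4 behaves so differently from 32 on such a small training set:

```python
import math

train_size = 1000  # matches train_size in get_idx_split above

# Gradient updates per epoch for each candidate batch size
for batch_size in (1, 4, 32):
    updates = math.ceil(train_size / batch_size)
    print(f'batch_size={batch_size}: {updates} updates/epoch')
# batch_size=1 gives 1000 noisy updates per epoch, batch_size=32 only 32,
# so with batch_size=32 the model sees far fewer (and smoother) steps
# in the same number of epochs.
```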

Thanks.