The reproduction results on MD17 do not match
xnuohz opened this issue · 1 comment
Hello, thanks for your nice work. I tried to reproduce the results of SphereNet on MD17 by following your tutorial and the parameters listed in Appendix D, Table 8 of the paper, but the results do not match:
- Your model in the tutorial: MAE (uracil) = 0.2453906238079071
- Trained by myself from scratch: MAE (uracil) = 4.0434064865112305
Below is my config. What is the best combination of parameters in Table 8? I would appreciate any suggestions!
Does `batch_size` have to be 1? Training with it was slow, so I changed it to 32.
```python
import argparse

import torch

from dig.threedgraph.dataset import MD17
from dig.threedgraph.method import SphereNet
from dig.threedgraph.evaluation import ThreeDEvaluator
from pipline_on_md17 import run
parser = argparse.ArgumentParser(description='MD17 SphereNet')
parser.add_argument('--device', type=int, default=0)
parser.add_argument('--cutoff', type=float, default=5.0)
parser.add_argument('--num_layers', type=int, default=4)
parser.add_argument('--hidden_channels', type=int, default=128)
parser.add_argument('--out_channels', type=int, default=1)
parser.add_argument('--int_emb_size', type=int, default=64)
parser.add_argument('--basis_emb_size_dist', type=int, default=8)
parser.add_argument('--basis_emb_size_angle', type=int, default=8)
parser.add_argument('--basis_emb_size_torsion', type=int, default=8)
parser.add_argument('--out_emb_channels', type=int, default=256)
parser.add_argument('--num_spherical', type=int, default=3)
parser.add_argument('--num_radial', type=int, default=6)
parser.add_argument('--eval_steps', type=int, default=50) # 50
parser.add_argument('--eval_start', type=int, default=200) # 200
parser.add_argument('--epochs', type=int, default=1000) # 500,[1000],2000
parser.add_argument('--batch_size', type=int, default=32) # [1],2,4,16,32
parser.add_argument('--vt_batch_size', type=int, default=448) # can try 64/128/256 based on the memory of your device
parser.add_argument('--lr', type=float, default=0.0005)
parser.add_argument('--lr_decay_factor', type=float, default=0.5)
parser.add_argument('--lr_decay_step_size', type=int, default=200)
parser.add_argument('--p', type=int, default=100)
parser.add_argument('--save_dir', type=str, default='my_md17_save')
parser.add_argument('--log_dir', type=str, default='my_md17_log')
args = parser.parse_args()
print(args)
device = f'cuda:{args.device}' if torch.cuda.is_available() else 'cpu'
device = torch.device(device)
print('device',device)
model = SphereNet(energy_and_force=True, cutoff=args.cutoff, num_layers=args.num_layers,
                  hidden_channels=args.hidden_channels, out_channels=args.out_channels,
                  int_emb_size=args.int_emb_size,
                  basis_emb_size_dist=args.basis_emb_size_dist,
                  basis_emb_size_angle=args.basis_emb_size_angle,
                  basis_emb_size_torsion=args.basis_emb_size_torsion,
                  out_emb_channels=args.out_emb_channels,
                  num_spherical=args.num_spherical, num_radial=args.num_radial,
                  envelope_exponent=5,
                  num_before_skip=1, num_after_skip=2, num_output_layers=3)
model = model.to(device)
# Define loss and evaluation
loss_func = torch.nn.L1Loss()
evaluation = ThreeDEvaluator()
for data_name in ['aspirin', 'benzene2017', 'ethanol', 'malonaldehyde', 'naphthalene', 'salicylic', 'toluene', 'uracil']:
    if data_name == 'uracil':
        dataset = MD17(root='dataset/', name=data_name)
        split_idx = dataset.get_idx_split(len(dataset.data.y), train_size=1000, valid_size=1000, seed=42)
        train_dataset, valid_dataset, test_dataset = dataset[split_idx['train']], dataset[split_idx['valid']], dataset[split_idx['test']]
        print('train, validation, test:', data_name, len(train_dataset), len(valid_dataset), len(test_dataset))

        # Train and evaluate
        run3d = run()
        run3d.run(device, train_dataset, valid_dataset, test_dataset, model, loss_func, evaluation,
                  epochs=args.epochs, batch_size=args.batch_size, vt_batch_size=args.vt_batch_size,
                  lr=args.lr, lr_decay_factor=args.lr_decay_factor, lr_decay_step_size=args.lr_decay_step_size,
                  energy_and_force=True, p=args.p,
                  save_dir=args.save_dir, log_dir=args.log_dir, eval_steps=args.eval_steps, eval_start=args.eval_start)
```
Hi @xnuohz,
Thank you for your interest in our work. Please follow the hyperparameters in the tutorial. Based on my experiments, batch_size, lr, and lr_decay_step_size are the most important hyperparameters, especially batch_size: batch_size=4 or 1 works better than 32. One possible reason is that the training set is very small (most methods use 1000 or 950 training samples), so a smaller batch size may help to find a better local optimum.
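For concreteness, here is a minimal sketch of those settings expressed against the argparse flags in your script (batch_size=1 follows the suggestion above; the lr and lr_decay_step_size values shown are just your script's current defaults, so please verify them against the tutorial rather than taking them from this sketch):

```python
import argparse

# Sketch: the settings discussed above, as overrides of the script's
# argparse defaults. batch_size=1 (or 4) is the key change; the lr and
# lr_decay_step_size values are the script's current defaults -- verify
# them against the tutorial before relying on them.
parser = argparse.ArgumentParser(description='MD17 SphereNet')
parser.add_argument('--batch_size', type=int, default=1)  # instead of 32; 4 also works better
parser.add_argument('--lr', type=float, default=0.0005)  # confirm against the tutorial
parser.add_argument('--lr_decay_step_size', type=int, default=200)  # confirm against the tutorial
```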
Thanks.