Data shuffling during training of the OGB example with padding and JIT support
skrsna opened this issue · 3 comments
Hi,
Thanks for such a great library. I have a question about data shuffling during training in the ogb_example. Can we shuffle the data between training epochs, like DataLoader(shuffle=True) in PyTorch? Will this affect padding and JIT compilation? Sorry if this is already implemented in make_generator. I'm not sure I understand this correctly: during the initial epoch each batch is padded to some n dimensions, and for the next epoch we shuffle the indices and make new batches. Does this entail new padding and JIT compilation?
I implemented the molecular graph featurization described in the MPNN paper, and I'm loading the dataset with the PyTorch DataLoader API like so:
from torch.utils import data

class CSVDataLoader(data.DataLoader):
    def __init__(self, dataset, batch_size=1,
                 shuffle=False, sampler=None,
                 batch_sampler=None, num_workers=0,
                 pin_memory=False, drop_last=False,
                 timeout=0, worker_init_fn=None):
        super().__init__(dataset,
                         batch_size=batch_size,
                         shuffle=shuffle,
                         sampler=sampler,
                         batch_sampler=batch_sampler,
                         num_workers=num_workers,
                         collate_fn=collate_molgraphs,
                         pin_memory=pin_memory,
                         drop_last=drop_last,
                         timeout=timeout,
                         worker_init_fn=worker_init_fn)
I used the following collate_fn:
import jax.numpy as jnp
import jraph

def collate_molgraphs(data):
    smiles, graph_tuples, labels = map(list, zip(*data))
    batched_graph_tuple = jraph.batch(graph_tuples)
    # Pad the batched graph so that jit sees only a few distinct batch shapes.
    padded_graph_tuple = pad_to_nearest_power_of_two(batched_graph_tuple)
    batched_labels = jnp.stack(labels, axis=0)
    return smiles, padded_graph_tuple, batched_labels
and the training iteration:
dataloader = CSVDataLoader(dataset, batch_size=128, drop_last=True)

for i, batch in enumerate(dataloader):
    smiles, batched_graph, labels = batch
    # Append a dummy label row for the padding graph added during padding.
    labels = jnp.concatenate([labels, jnp.array([[0]])])
    loss, grad = compute_loss_fn(params, batched_graph, labels)
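For reference, pad_to_nearest_power_of_two used in the collate function is my own helper, roughly along these lines (a sketch modelled on the padding in the jraph OGB example; the names and exact details of my version may differ):

import numpy as np
import jraph

def _nearest_bigger_power_of_two(x: int) -> int:
    y = 2
    while y < x:
        y *= 2
    return y

def pad_to_nearest_power_of_two(graphs_tuple: jraph.GraphsTuple) -> jraph.GraphsTuple:
    # Pad node/edge counts up to the next power of two (plus one extra node and
    # one extra graph reserved for the padding graph), so padded batch shapes
    # only take a handful of distinct values and jit only compiles for those.
    pad_n_node = _nearest_bigger_power_of_two(int(np.sum(graphs_tuple.n_node))) + 1
    pad_n_edge = _nearest_bigger_power_of_two(int(np.sum(graphs_tuple.n_edge)))
    pad_n_graph = graphs_tuple.n_node.shape[0] + 1
    return jraph.pad_with_graphs(graphs_tuple, pad_n_node, pad_n_edge, pad_n_graph)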
If I use a bigger batch size, the padded dimension is big enough (jraph is padding to larger powers of 2, e.g. 2048 or 4096) that it's the same with shuffle=True or shuffle=False, so I assume it's not leading to any additional JIT compilations during training. However, just iterating over the same dataset with shuffle=True is approximately 2.1x slower than with shuffle=False. Data loading is not a bottleneck for my use case, but I'm wondering what is leading to slower data loading with shuffle=True. Thanks!
Hey thanks for your question!
You're right that using the PyTorch DataLoader is a good solution if you have another dataset. We just wrote our own to manage dependencies and to be didactic 👍
All your code looks correct from what I can see, but without checking the data sizes I can't be sure that it isn't jit. It would seem strange for shuffling to cause this, though: unless there is some intrinsic order to the dataset, you wouldn't expect a different distribution of batch shapes.
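One rough way to sanity check that (a sketch assuming the CSVDataLoader and dataset defined above) is to count the padded batch shapes each setting actually produces; each distinct pair is a shape jit would have to compile for:

from collections import Counter
import numpy as np

def padded_shape_counts(loader):
    # (total padded nodes, total padded edges) per batch.
    return Counter(
        (int(np.sum(graph.n_node)), int(np.sum(graph.n_edge)))
        for _, graph, _ in loader
    )

print(padded_shape_counts(CSVDataLoader(dataset, batch_size=128, drop_last=True)))
print(padded_shape_counts(CSVDataLoader(dataset, batch_size=128, drop_last=True, shuffle=True)))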
A trick you can use to check jit compilation is to add print('JIT compiling') within your compute_loss_fn. This will print every time the function is compiled, but not when it is executed. This will tell you if you have excessive compilation.
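A minimal, self-contained sketch of how that behaves (a toy function, not your actual loss):

import jax
import jax.numpy as jnp

@jax.jit
def loss_fn(x):
    print('JIT compiling')   # runs only while JAX traces/compiles the function
    return jnp.sum(x ** 2)

loss_fn(jnp.ones(128))   # prints: first compilation for this input shape
loss_fn(jnp.ones(128))   # silent: the cached compilation is reused
loss_fn(jnp.ones(256))   # prints again: a new shape triggers a recompile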
Other than that, it would seem that the actual data loading is the slow bit, so it's probably worth checking how shuffling is done within the PyTorch DataLoader.
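One way to isolate that (a sketch, again reusing the CSVDataLoader and dataset from above) is to time a bare pass over the loader without any model work:

import time

def time_epoch(loader):
    start = time.perf_counter()
    for _ in loader:   # only collation + padding, no loss/grad computation
        pass
    return time.perf_counter() - start

print('shuffle=False:', time_epoch(CSVDataLoader(dataset, batch_size=128, drop_last=True)))
print('shuffle=True: ', time_epoch(CSVDataLoader(dataset, batch_size=128, drop_last=True, shuffle=True)))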
Thanks for the info and the trick about print inside compute_loss_fn. The issue can be closed!