thinng/GraphDTA

How to extract the graph data for a single graph from the DataLoader batch?


Dear Team,
I'm trying to understand how this code passes graphs of variable size, such as small molecules, to a deep learning model as a batch. When I print the batches produced by the DataLoader in training_validation.py with the default parameters, I get the following output:

```
cuda_name: cuda:0
Learning rate: 0.0005
Epochs: 1000

running on GCNNet_davis
Pre-processed data found: data/processed/davis_train.pt, loading ...
Pre-processed data found: data/processed/davis_test.pt, loading ...
Batch(batch=[16399], c_size=[512], edge_index=[2, 36172], target=[512, 1000], x=[16399, 78], y=[512])
Batch(batch=[16332], c_size=[512], edge_index=[2, 36122], target=[512, 1000], x=[16332, 78], y=[512])
Batch(batch=[16269], c_size=[512], edge_index=[2, 36000], target=[512, 1000], x=[16269, 78], y=[512])
Batch(batch=[16193], c_size=[512], edge_index=[2, 35794], target=[512, 1000], x=[16193, 78], y=[512])
Batch(batch=[16418], c_size=[512], edge_index=[2, 36284], target=[512, 1000], x=[16418, 78], y=[512])
```
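The shapes in this log follow from how PyTorch Geometric batches graphs. Instead of padding, it stacks all molecules into one large, disconnected graph. With 512 molecules per batch:

- `x=[16399, 78]` holds every atom in the batch, stacked row-wise, with 78 features per atom.
- `edge_index=[2, 36172]` holds all of their edges.
- `batch=[16399]` records which of the 512 graphs each atom belongs to.
- Per-pair attributes such as `y` and `target` keep one row per graph.

Below is a minimal, framework-free sketch of that merging step. It makes two simplifying assumptions: it uses plain Python lists instead of tensors, and it handles only `x` and `edge_index`. `Graph`, `ToyBatch` and `collate` are hypothetical names. This is not PyG's actual implementation.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class Graph:
    x: List[List[float]]         # node features, one row per atom
    edge_index: List[List[int]]  # [sources, targets] using local node ids


@dataclass
class ToyBatch:
    x: List[List[float]]         # all graphs' node features, stacked
    edge_index: List[List[int]]  # all edges, node ids shifted into the stacked x
    batch: List[int]             # graph id of every node
    ptr: List[int]               # start offset of each graph's nodes (num_graphs + 1 entries)


def collate(graphs: List[Graph]) -> ToyBatch:
    """Hypothetical helper: merge graphs into one disconnected graph."""
    x, src, dst, batch, ptr = [], [], [], [], [0]
    for gid, g in enumerate(graphs):
        offset = ptr[-1]  # number of nodes already placed in the batch
        x.extend(g.x)
        # Shift local node ids so they index into the concatenated x.
        src.extend(s + offset for s in g.edge_index[0])
        dst.extend(d + offset for d in g.edge_index[1])
        batch.extend([gid] * len(g.x))
        ptr.append(offset + len(g.x))
    return ToyBatch(x, [src, dst], batch, ptr)


# Usage: a 2-atom graph and a 3-atom graph.
g0 = Graph(x=[[1.0], [2.0]], edge_index=[[0, 1], [1, 0]])
g1 = Graph(x=[[3.0], [4.0], [5.0]], edge_index=[[0, 1, 1, 2], [1, 0, 2, 1]])
b = collate([g0, g1])
print(b.batch)  # [0, 0, 1, 1, 1]
```

Each graph's edge indices are shifted by the number of nodes collated before it. That shift is what lets graphs of any size share one `edge_index` without creating edges between different molecules.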

I understand that each batch contains 512 molecular graphs together with their target proteins and affinity values. What I don't understand is how to extract the data belonging to the 1st or 2nd graph of a batch returned by the DataLoader. I'm a beginner with PyTorch Geometric, so please explain in detail, even if this seems like a naive question.

I also have a second question: does c_size set an upper limit on the number of nodes in a batch? What happens if we omit the c_size attribute?
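The extraction part can be answered with the `batch` vector. Graph i consists of exactly the nodes whose entry in `batch` equals i. For a real batch object (called `data` here), `data.x[data.batch == i]` selects that graph's atom features, and `torch.bincount(data.batch)` gives the number of atoms in each graph. Recent PyG versions can also split a batch back into individual graphs with `Batch.to_data_list()` or `Batch.get_example(i)`.

On `c_size`: the log shows `c_size=[512]`, one entry per graph. That suggests it is a per-graph attribute (presumably each molecule's atom count) that is simply concatenated during batching, rather than a limit on the number of nodes. Assuming that interpretation, the same counts can be rebuilt from `batch` alone.

The sketch below shows the idea on a hypothetical two-graph batch, using plain lists in place of tensors.

```python
from typing import List

# Hypothetical, already-collated toy batch: graph 0 has 2 atoms, graph 1 has 3.
x = [[1.0], [2.0], [3.0], [4.0], [5.0]]  # stacked node features
batch = [0, 0, 1, 1, 1]                  # graph id of each node


def node_counts(batch: List[int], num_graphs: int) -> List[int]:
    """Nodes per graph, rebuilt from the batch vector (like torch.bincount)."""
    counts = [0] * num_graphs
    for gid in batch:
        counts[gid] += 1
    return counts


def node_features_of(x: List[List[float]], batch: List[int], i: int) -> List[List[float]]:
    """Rows of x belonging to graph i (the list analogue of x[batch == i])."""
    return [row for row, gid in zip(x, batch) if gid == i]


print(node_counts(batch, 2))          # [2, 3]
print(node_features_of(x, batch, 1))  # [[3.0], [4.0], [5.0]]
```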
I look forward to your reply, and thanks in advance!

Have you tried training.py?

Yes. After looking at training.py I understand how it works; previously I couldn't follow the custom DataLoader implementation. The PyTorch Geometric documentation made it clear that the collate function is what combines graphs of different sizes into one batch, and it explained how to recover the edges belonging to a graph from its index. Thanks for your excellent code, which helped me understand the details!
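Recovering a graph's edges from its index means undoing the offset that collation added. Keep the edges whose source node lies in graph i's node range, then subtract the range's start so the node ids begin at 0 again.

Below is a minimal sketch on a hypothetical two-graph batch using plain lists. `ptr` mirrors the node-offset attribute that recent PyG versions attach to a `Batch`. `edges_of` is a hypothetical helper; in PyG itself, `Batch.get_example(i)` performs this re-indexing for you.

```python
from typing import List

# Hypothetical toy batch: graph 0 owns nodes 0-1, graph 1 owns nodes 2-4.
ptr = [0, 2, 5]  # graph i owns nodes ptr[i] .. ptr[i + 1] - 1
edge_index = [[0, 1, 2, 3, 3, 4],
              [1, 0, 3, 2, 4, 3]]


def edges_of(edge_index: List[List[int]], ptr: List[int], i: int) -> List[List[int]]:
    """Return graph i's edges with node ids shifted back to start at 0."""
    start, end = ptr[i], ptr[i + 1]
    src, dst = [], []
    for s, d in zip(edge_index[0], edge_index[1]):
        # Batched graphs are disconnected from each other, so if the source
        # lies in graph i's range, the target does too.
        if start <= s < end:
            src.append(s - start)
            dst.append(d - start)
    return [src, dst]


print(edges_of(edge_index, ptr, 1))  # [[0, 1, 1, 2], [1, 0, 2, 1]]
```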