pyg-team/pytorch_geometric

Does PyTorch Geometric have a "groupby"?


❓ Questions & Help

I am trying to classify lists of DAGs, so a single training example is a list of DAGs.
I have not seen any examples of anyone doing this.
I think I can treat each DAG in an example as a separate graph in the same Batch and use the .batch attribute to tell them apart. (I plan to build these batches manually rather than with a DataLoader, to avoid shuffling.)

I would, however, also like to include multiple training examples in a batch.
I'm thinking I could do this with a custom attribute on the Data objects (called example) that records which training example each graph belongs to.

Here's an example of what I mean. The following is a batch of 3 examples. The first, second, and third examples have 5, 2, and 10 DAGs respectively. Ideally there would be one y label per example rather than one per graph, but in this case all graphs in an example share the same label.

In [6]: data                                                            
Out[6]: Batch(batch=[78], edge_attr=[69, 6], edge_index=[2, 69], example=[17], x=[78, 256], y=[17])

In [7]: data.batch                                                      
Out[7]: 
tensor([ 0,  0,  0,  0,  0,  0,  0,  0,  0,  1,  1,  1,  2,  2,  2,  2,  3,  3,
         3,  3,  3,  3,  4,  4,  4,  4,  4,  4,  5,  5,  5,  5,  5,  6,  6,  6,
         6,  6,  6,  7,  7,  7,  7,  8,  8,  8,  9,  9,  9, 10, 10, 10, 11, 11,
        11, 12, 12, 12, 12, 13, 13, 13, 13, 14, 14, 14, 14, 15, 15, 15, 15, 16,
        16, 16, 16, 16, 16, 16])

In [8]: data.example                                                                                                                              
Out[8]: tensor([0, 0, 0, 0, 0, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2])

I would like to recover the subgraph for each individual DAG and then group those subgraphs by training example.
Is there a way to do this nicely in PyTorch Geometric?
Is there any resource you can point me to for learning more about this issue?

Thanks,
Jack

If I understand you correctly, data.batch already contains the information needed to distinguish the different DAGs. You should be able to convert your graphs back via batch.to_data_list() and group them via standard Python methods. Please correct me if I am missing something.
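As a rough sketch of the "standard Python methods" part: the helper below groups per-graph objects by an integer example id. The name group_by_example is hypothetical, and plain strings stand in for the Data objects so the snippet runs without PyTorch Geometric installed. The ids mirror the data.example tensor from the question.

```python
from collections import defaultdict


def group_by_example(graphs, example_ids):
    """Group per-graph objects by the training example they belong to.

    graphs:      sequence with one entry per DAG
    example_ids: sequence with one integer example id per DAG
    Returns a dict mapping example id -> list of graphs, in input order.
    """
    if len(graphs) != len(example_ids):
        raise ValueError("need exactly one example id per graph")
    groups = defaultdict(list)
    for graph, example_id in zip(graphs, example_ids):
        groups[example_id].append(graph)
    return dict(groups)


# Stand-in graphs (plain strings) with the example ids from the question:
# 5 DAGs in example 0, 2 in example 1, 10 in example 2.
example_ids = [0] * 5 + [1] * 2 + [2] * 10
graphs = [f"dag_{i}" for i in range(len(example_ids))]

grouped = group_by_example(graphs, example_ids)
sizes = {example_id: len(members) for example_id, members in grouped.items()}
```

With a real Batch, graphs would presumably be the list returned by to_data_list() and example_ids would be data.example.tolist(). This assumes the custom example attribute holds one entry per graph, as in the output shown in the question.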

Yeah, you're correct. I'm probably just being paranoid about efficiency. (looking for standard functions to use to avoid writing code that might be inefficient...)

I'm trying to implement it now. Thanks!

You can always think about more efficient ways to do it, e.g. via torch_scatter, but that certainly depends on your use case.
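To illustrate the scatter idea without converting back to a list of graphs, here is a minimal sketch that pools node features per DAG and then per training example. It uses plain PyTorch index_add_ as a stand-in for a torch_scatter call. The function name two_level_sum_pool and the choice of sum pooling are assumptions for illustration only.

```python
import torch


def two_level_sum_pool(x, batch, example):
    """Sum node features per graph, then sum graph embeddings per example.

    x:       [num_nodes, dim] node features
    batch:   [num_nodes] graph index of each node
    example: [num_graphs] example index of each graph
    """
    num_graphs = int(batch.max()) + 1
    num_examples = int(example.max()) + 1

    # Node level -> graph level: accumulate each node row into its graph's row.
    graph_emb = torch.zeros(num_graphs, x.size(1), dtype=x.dtype, device=x.device)
    graph_emb.index_add_(0, batch, x)

    # Graph level -> example level: accumulate each graph row into its example's row.
    example_emb = torch.zeros(num_examples, x.size(1), dtype=x.dtype, device=x.device)
    example_emb.index_add_(0, example, graph_emb)
    return graph_emb, example_emb


# Toy input: 5 nodes spread over 3 graphs, which belong to 2 examples.
x = torch.ones(5, 2)
batch = torch.tensor([0, 0, 1, 2, 2])
example = torch.tensor([0, 0, 1])

graph_emb, example_emb = two_level_sum_pool(x, batch, example)
```

Both steps stay on tensors, so there is no Python loop over graphs. Whether a sum is the right reduction (rather than a mean or max) depends on the model.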