pyg-team/pytorch_geometric

Modifying edge connections on existing Data/Batch objects


❓ Questions & Help

First of all, thanks for the great repo! I migrated from DGL and have had a really positive experience using this repo.

I am working on an application that requires me to modify the edge connections of a graph after it has passed through a GCN. My current strategy is to create new Data objects with the new connections, whose node embeddings are copied from the output of the GCN (the input is a batch of Data objects, i.e. a Batch object). However, copying doesn't seem elegant to me. Is there a more efficient way to modify the edge connections of an existing Batch (or Data) object?

Sample code is provided below:

# Create a batch of `Data` objects at the beginning
data_list = [Data(x=x, edge_index=original_edge_index) for x in mydata]
batch = Batch.from_data_list(data_list)
output = net(batch)

# Assign new connections using the node embedding from network outputs
output.edge_index = None
<code to be filled in>

Well, I do not think that it is necessary to copy any node embeddings. However, I am not sure how to help you without more details, e.g., it is unclear to me how your GCN modifies edge connections. Are you learning edge scores to filter out unimportant edges? Then something like the following will work:

row, col = edge_index
edge_score = self.lin(torch.cat([x[row], x[col]], dim=-1)).view(-1)  # [num_edges]
_, idx = edge_score.topk(k)
edge_index = edge_index[:, idx]  # edge_index has shape [2, num_edges], so select columns
edge_weight = edge_score[idx]

Please be aware that you cannot modify edge connections in a differentiable way without some kind of continuous edge probabilities, e.g., encoded as edge weights. There will be no grads w.r.t. edge_index, only w.r.t. its values!
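To make the snippet above concrete, here is a minimal, self-contained sketch of top-k edge filtering in plain PyTorch. The toy graph, the feature sizes, the value of `k`, and the linear scorer `lin` are all assumptions made for illustration; they are not taken from the thread. It only shows that gradients reach the scorer through the kept edge scores, while the integer-valued `edge_index` carries no gradient.

```python
import torch

torch.manual_seed(0)

# Hypothetical toy graph: 5 nodes with 8 features each, 6 directed edges.
num_nodes, num_feats, k = 5, 8, 3
x = torch.randn(num_nodes, num_feats)
edge_index = torch.tensor([[0, 1, 2, 3, 4, 0],
                           [1, 2, 3, 4, 0, 2]])  # shape [2, num_edges]

# Hypothetical edge scorer: one score per (source, target) feature pair.
lin = torch.nn.Linear(2 * num_feats, 1)

row, col = edge_index
edge_score = lin(torch.cat([x[row], x[col]], dim=-1)).view(-1)  # [num_edges]

# Keep the k highest-scoring edges; topk returns both the values and their positions.
edge_weight, idx = edge_score.topk(k)
kept_edge_index = edge_index[:, idx]  # select columns, shape [2, k]

# Gradients flow through the continuous edge weights, not through the indices.
edge_weight.sum().backward()
```

The filtered `kept_edge_index` and `edge_weight` could then be passed to a layer that accepts edge weights, so that the edge scores take part in training.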

Hi Matthias,

Thanks for the detailed reply. I think it would be clearer if I provide a bit of example code. Suppose that I have created two GCNConv networks (thus, each has its own weights): gcnconv1 and gcnconv2. Also, I am given a set of initial node embeddings stored in mydata.

I am trying to apply gcnconv1 on mydata using a set of connections edge_index1. Afterward, I use the output from gcnconv1 as new node embeddings to construct another graph to be convolved by gcnconv2, but using edge_index2.

The following pseudocode is my current approach, but constructing the batch object (also, data_list) twice doesn't seem to be an efficient way to do this:

edge_index1 = <some connections based on the nodes>
edge_index2 = <different connections but the same number of nodes>
data_list = [Data(x=x, edge_index=edge_index1) for x in mydata]
batch = Batch.from_data_list(data_list)
output = gcnconv1(batch) # doing graph convolution on nodes using `edge_index1`

# doing graph convolution on the nodes whose embeddings are the output of the previous network
# Also, the connections here are different
new_node_embeddings = output
data_list = [Data(x=x, edge_index=edge_index2) for x in new_node_embeddings]
batch = Batch.from_data_list(data_list)
final_output = gcnconv2(batch)

Ah I see. Your example should work just fine. Alternatively, you can write

edge_index1 = <some connections based on the nodes>
edge_index2 = <different connections but same nodes>
data_list = [Data(x=x, edge_index1=edge_index1, edge_index2=edge_index2) for x in mydata]
batch = Batch.from_data_list(data_list)
output = gcnconv1(batch.x, batch.edge_index1)
final_output = gcnconv2(output, batch.edge_index2)

which should be a bit cleaner :)

Oh!! I see. That seems much more elegant. Thanks!