geoelements/gns

In graph_network.py of the gns module, the InteractionNetwork class does not seem to update edge features.


Describe the bug
I checked the InteractionNetwork class in graph_network.py and found that message() returns the newly computed edge_features directly instead of using them to replace the original ones. As a result, update() still receives the initial edge_features tensor, and the residual connection in forward() simply doubles that original tensor.
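In equation form, the contrast looks like this. The notation is chosen here for illustration and is not taken from the gns code or paper: $e_{ij}$ is the feature vector of edge $(i, j)$, $x_i$ and $x_j$ are the endpoint node features, and $\phi^{e}$ stands for `edge_fn`. The first line is the residual edge update the report expects, and the second is what the code computes when update() hands back its unchanged edge_features argument:

$$
e'_{ij} = e_{ij} + \phi^{e}\big(x_i,\, x_j,\, e_{ij}\big) \qquad \text{(expected)}
$$

$$
e'_{ij} = e_{ij} + e_{ij} = 2\,e_{ij} \qquad \text{(observed)}
$$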

class InteractionNetwork(MessagePassing):
  def __init__(
      self,
      nnode_in: int,
      nnode_out: int,
      nedge_in: int,
      nedge_out: int,
      nmlp_layers: int,
      mlp_hidden_dim: int,
  ):
    # Aggregate features from neighbors
    super(InteractionNetwork, self).__init__(aggr='add')
    # Node MLP
    self.node_fn = nn.Sequential(*[build_mlp(nnode_in + nedge_out,
                                             [mlp_hidden_dim
                                              for _ in range(nmlp_layers)],
                                             nnode_out),
                                   nn.LayerNorm(nnode_out)])
    # Edge MLP
    self.edge_fn = nn.Sequential(*[build_mlp(nnode_in + nnode_in + nedge_in,
                                             [mlp_hidden_dim
                                              for _ in range(nmlp_layers)],
                                             nedge_out),
                                   nn.LayerNorm(nedge_out)])

  def forward(self,
              x: torch.tensor,
              edge_index: torch.tensor,
              edge_features: torch.tensor):
 
    # Save particle state and edge features
    x_residual = x
    edge_features_residual = edge_features
    # Start propagating messages.
    # Takes in the edge indices and all additional data which is needed to
    # construct messages and to update node embeddings.
    x, edge_features = self.propagate(
        edge_index=edge_index, x=x, edge_features=edge_features)

    return x + x_residual, edge_features + edge_features_residual

  def message(self,
              x_i: torch.tensor,
              x_j: torch.tensor,
              edge_features: torch.tensor) -> torch.tensor:
    # Concat edge features with a final shape of [nedges, latent_dim*3]
    edge_features = torch.cat([x_i, x_j, edge_features], dim=-1)
    edge_features = self.edge_fn(edge_features)  # <-- Here is the question (line 198 in graph_network.py)
    return edge_features

  def update(self,
             x_updated: torch.tensor,
             x: torch.tensor,
             edge_features: torch.tensor):  # <-- edge_features here are still the original ones
    
    # Concat node features with a final shape of
    # [nparticles, latent_dim (or nnode_in) *2]
    x_updated = torch.cat([x_updated, x], dim=-1)
    x_updated = self.node_fn(x_updated)
    return x_updated, edge_features
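The reported dataflow can be modeled without PyG or torch. This is a hypothetical, heavily simplified sketch, not PyG's implementation. Plain Python lists stand in for tensors. The toy `edge_fn`/`node_fn` sum their inputs instead of running MLPs. `mini_propagate` assumes PyG's documented behavior that update() receives the arguments originally passed to propagate().

```python
# Toy model of the reported dataflow; plain lists stand in for tensors.

def add_rows(a, b):
    """Elementwise sum of two lists of equal-length rows (the residual add)."""
    return [[p + q for p, q in zip(ra, rb)] for ra, rb in zip(a, b)]


def edge_fn(x_i, x_j, e):
    # Hypothetical stand-in for the edge MLP: elementwise sum of its inputs.
    return [a + b + c for a, b, c in zip(x_i, x_j, e)]


def node_fn(aggr, x):
    # Hypothetical stand-in for the node MLP.
    return [a + b for a, b in zip(aggr, x)]


def mini_propagate(edge_index, x, edge_features):
    """Simplified model of MessagePassing.propagate with aggr='add'."""
    src, dst = edge_index  # default flow: x_j = source node, x_i = target node
    # message(): one new edge feature per edge, built from the original kwargs.
    messages = [edge_fn(x[i], x[j], e) for j, i, e in zip(src, dst, edge_features)]
    # aggregate(): sum the messages onto their target nodes.
    aggr = [[0.0] * len(messages[0]) for _ in x]
    for i, m in zip(dst, messages):
        aggr[i] = [a + b for a, b in zip(aggr[i], m)]
    # update(): gets the aggregate plus the ORIGINAL edge_features and, like the
    # update() quoted above, returns them unchanged; `messages` never leaves.
    x_new = [node_fn(a, xi) for a, xi in zip(aggr, x)]
    return x_new, edge_features


def forward(x, edge_index, edge_features):
    x_new, e_new = mini_propagate(edge_index, x, edge_features)
    # Residual connections, as in InteractionNetwork.forward().
    return add_rows(x_new, x), add_rows(e_new, edge_features)


x = [[1.0, 1.0], [2.0, 2.0]]
edge_index = [[0, 1], [1, 0]]
edge_attr = [[1.0, 1.0], [2.0, 2.0]]
x_out, e_out = forward(x, edge_index, edge_attr)
print(e_out)  # [[2.0, 2.0], [4.0, 4.0]]
```

With the same inputs as the reproduction below, the toy model yields the same doubled edge output. The freshly computed messages only reach the nodes, and the edge residual adds the input to itself.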

To Reproduce
I instantiated this class separately to verify the issue. The code is as follows:

from gns.graph_network import InteractionNetwork
import torch
from torch_geometric.data import Data
simulator = InteractionNetwork(
    nnode_in=2,
    nnode_out=2,
    nedge_in=2,
    nedge_out=2,
    nmlp_layers=2,
    mlp_hidden_dim=2,
)
edge_index = torch.tensor([[0, 1],
                           [1, 0]], dtype=torch.long)
x = torch.tensor([[1, 1], [2, 2]], dtype=torch.float)
edge_attr = torch.tensor([[1, 1], [2, 2]], dtype=torch.float)
data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr)
print(edge_attr)
print(simulator(x=x, edge_index=edge_index, edge_features=edge_attr)[1])

The outputs are as follows:

tensor([[1., 1.],
        [2., 2.]])
tensor([[2., 2.],
        [4., 4.]])

Expected behavior
Maybe you could store the new tensor in a member variable, for example:

  def message(self,
              x_i: torch.tensor,
              x_j: torch.tensor,
              edge_features: torch.tensor) -> torch.tensor:
 
    # Concat edge features with a final shape of [nedges, latent_dim*3]
    edge_features = torch.cat([x_i, x_j, edge_features], dim=-1)
    edge_features = self.edge_fn(edge_features)
    self.new_edge_features = edge_features
    return edge_features

  def update(self,
             x_updated: torch.tensor,
             x: torch.tensor,
             edge_features: torch.tensor):
    
    # Concat node features with a final shape of
    # [nparticles, latent_dim (or nnode_in) *2]
    x_updated = torch.cat([x_updated, x], dim=-1)
    x_updated = self.node_fn(x_updated)
    return x_updated, self.new_edge_features
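The proposed member-variable approach can be checked with the same kind of toy model. This is a hypothetical sketch. `MiniInteractionNetwork`, its list-based `edge_fn`/`node_fn`, and its hand-written `propagate` are stand-ins (no torch, no PyG, no MLPs). They only mirror the message() and update() bodies suggested above.

```python
# Toy model of the proposed fix; plain lists stand in for tensors.

def add_rows(a, b):
    """Elementwise sum of two lists of equal-length rows (the residual add)."""
    return [[p + q for p, q in zip(ra, rb)] for ra, rb in zip(a, b)]


class MiniInteractionNetwork:
    def edge_fn(self, x_i, x_j, e):
        # Hypothetical stand-in for the edge MLP.
        return [a + b + c for a, b, c in zip(x_i, x_j, e)]

    def node_fn(self, aggr, x):
        # Hypothetical stand-in for the node MLP.
        return [a + b for a, b in zip(aggr, x)]

    def message(self, x_i, x_j, edge_features):
        new = [self.edge_fn(a, b, e) for a, b, e in zip(x_i, x_j, edge_features)]
        self.new_edge_features = new  # stash the updated edges, as proposed
        return new

    def update(self, aggr, x, edge_features):
        x_new = [self.node_fn(a, xi) for a, xi in zip(aggr, x)]
        # Return the stashed value instead of the edge_features argument,
        # which still holds the original values passed to propagate().
        return x_new, self.new_edge_features

    def propagate(self, edge_index, x, edge_features):
        # Simplified model of MessagePassing.propagate with aggr='add'.
        src, dst = edge_index
        msgs = self.message([x[i] for i in dst], [x[j] for j in src], edge_features)
        aggr = [[0.0] * len(msgs[0]) for _ in x]
        for i, m in zip(dst, msgs):
            aggr[i] = [a + b for a, b in zip(aggr[i], m)]
        return self.update(aggr, x, edge_features)

    def forward(self, x, edge_index, edge_features):
        x_new, e_new = self.propagate(edge_index, x, edge_features)
        return add_rows(x_new, x), add_rows(e_new, edge_features)


net = MiniInteractionNetwork()
x = [[1.0, 1.0], [2.0, 2.0]]
edge_index = [[0, 1], [1, 0]]
edge_attr = [[1.0, 1.0], [2.0, 2.0]]
_, e_out = net.forward(x, edge_index, edge_attr)
print(e_out)  # [[5.0, 5.0], [7.0, 7.0]]
```

The edge output is now the new edge features plus the residual rather than twice the input. One trade-off of this design is that update() depends on state left behind by message(), so the two must always run in that order within the same propagate() call.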

Additional context
Maybe the code is correct and I have missed something, or I have misunderstood the formulas in the paper. I would greatly appreciate it if you could respond as soon as possible.

Thank you for leaving the comment. The edge features are first updated in the message function. The update function does not update the edge features itself; it takes in the updated edge features computed by the message function and returns them.


Thanks for your reply. However, update() does not seem to take in the updated edge features; it uses the initial ones instead. According to the PyG documentation, the edge_features argument of update() is the tensor initially passed to propagate(), not the updated edge features computed by message(). If it works the way you describe, could you please explain how?

I found this in the PyG documentation. According to it, propagate() first calls message(), which can take as input any argument that was initially passed to propagate(). message() constructs the messages, which here are essentially the updated edge features, and aggregate() then takes the output of message() as its input. update() takes the output of aggregation as its first argument, plus any argument that was initially passed to propagate(). So update() uses the output computed by message() (I am referring to message_passing.py). I hope this helps. I will also double-check the part you pointed out where the edge features are doubled.
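The call order described in this exchange can be traced with a toy model. It is a hypothetical simplification, not PyG's code, which lives in message_passing.py. The functions below only record the order of calls and what update() actually receives. Under PyG's documented semantics, update() sees the output of message() only after aggregate() has reduced it to one row per node. Its edge_features argument is the original object passed to propagate().

```python
# Toy trace of the propagate() call order; plain lists stand in for tensors.
trace = []


def message(x_i, x_j, edge_features):
    trace.append("message")
    # Hypothetical stand-in for edge_fn: double every edge feature.
    return [[2 * v for v in e] for e in edge_features]


def aggregate(messages, index, num_nodes):
    trace.append("aggregate")
    out = [[0.0] * len(messages[0]) for _ in range(num_nodes)]
    for i, m in zip(index, messages):
        out[i] = [a + b for a, b in zip(out[i], m)]
    return out


def update(aggr_out, x, edge_features):
    trace.append("update")
    return aggr_out, edge_features


def propagate(edge_index, x, edge_features):
    src, dst = edge_index
    msgs = message([x[i] for i in dst], [x[j] for j in src], edge_features)
    aggr_out = aggregate(msgs, dst, len(x))
    # Extra kwargs are forwarded from the propagate() call, not from message().
    return update(aggr_out, x=x, edge_features=edge_features)


x = [[1.0], [2.0], [3.0]]
edge_index = [[0, 2], [1, 1]]  # two edges, both pointing at node 1
edge_attr = [[1.0], [2.0]]
aggr_out, e_seen = propagate(edge_index, x, edge_attr)
print(trace)                # ['message', 'aggregate', 'update']
print(aggr_out)             # [[0.0], [6.0], [0.0]]  (one row per node)
print(e_seen is edge_attr)  # True  (the original edge features)
```

With three nodes and two edges, the shapes make the distinction visible. The aggregated output has one row per node, while the per-edge messages never reach update() directly.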

Thanks! I found this issue while stepping through the InteractionNetwork class in a debugger: the edge features inside update() are the original ones. Maybe this helps. Thanks again for your reply.

Thank you very much for your feedback! We have fixed the part you mentioned.