Aggregate node from sending nodes directly
TianrenWang opened this issue · 3 comments
What is the correct way to implement node aggregation from all sending nodes? Let's say I have nodes A, B, C, D, and E. A and B send to C, and D sends to E. I just want C = f(A) + f(B) and D = f(c). Edges are ignored in what I want to do.
The way that I have implemented it so far is with an InteractionNetwork:
import sonnet as snt
from graph_nets import blocks, modules

# depth, FLAGS and batch_of_graphs are defined elsewhere in my script.
graph_network = modules.InteractionNetwork(
    edge_model_fn=lambda: snt.nets.MLP(output_sizes=[1]),
    node_model_fn=lambda: snt.nets.MLP(output_sizes=[depth]))
global_block = blocks.GlobalBlock(
    global_model_fn=lambda: snt.nets.MLP(output_sizes=[depth]))
num_recurrent_passes = FLAGS.recurrences
previous_graphs = batch_of_graphs
for unused_pass in range(num_recurrent_passes):
    previous_graphs = graph_network(previous_graphs)
    previous_graphs = global_block(previous_graphs)
The output size of the edge model is 1 because all edges are just tf.constant([1]).
I am asking because my graph neural network's loss is stuck at a single value, and I am wondering whether the network is implemented properly.
EDIT: Actually, I figured out why my loss was stuck, but please verify that this is the correct way to do it anyway.
A and B send to C, and D sends to E. I just want C = f(A) + f(B) and D = f(c)
Do you mean: I just want C = f(A) + f(B) and E = f(D)?
Otherwise I am not sure I get what you want to do.
Apologies. Yes, that was a typo.
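With the typo corrected, the target is C = f(A) + f(B) and E = f(D). A minimal NumPy sketch of that target follows, assuming the nodes are indexed A=0, B=1, C=2, D=3, E=4 and the edges are stored as parallel senders/receivers index arrays (the convention graph_nets uses in GraphsTuple). The node features, the weights and the function f are hypothetical stand-ins for the real inputs and the learned MLP. Summing f over each node's incoming edges is the same as multiplying by the transposed adjacency matrix.

```python
import numpy as np

# Example graph from the thread: A -> C, B -> C, D -> E.
# Assumed node indices: A=0, B=1, C=2, D=3, E=4.
senders = np.array([0, 1, 3])
receivers = np.array([2, 2, 4])
num_nodes = 5

# Hypothetical node features: one 2-dimensional vector per node.
nodes = np.arange(num_nodes * 2, dtype=np.float64).reshape(num_nodes, 2)

# Hypothetical stand-in for the learned function f: a fixed linear map + ReLU.
weights = np.array([[1.0, -1.0], [0.5, 2.0]])


def f(x):
    return np.maximum(x @ weights, 0.0)


# Dense adjacency matrix: adjacency[s, r] counts the edges from node s to node r.
adjacency = np.zeros((num_nodes, num_nodes))
np.add.at(adjacency, (senders, receivers), 1.0)

# Row r of the result sums f(x_s) over every sender s with an edge s -> r,
# so row C holds f(A) + f(B), row E holds f(D), and all other rows are zero.
target = adjacency.T @ f(nodes)
```

The dense matrix is only for illustration; with many nodes, the sparse gather/scatter formulation is what keeps the cost proportional to the number of edges.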
In that case, the computation you want is much simpler than the model you are trying to use. It is essentially what is usually referred to as a Graph Convolutional Network: a single function "f" is computed on the nodes, and no computation happens on the edges. This can be written bottom-up in terms of our broadcast and aggregation operators for graphs in gn.blocks:
import graph_nets as gn
import sonnet as snt
import tensorflow as tf

# num_recurrent_passes and previous_graphs come from the snippet in the question.
model_fn = snt.nets.MLP(...)
# The aggregator has no trainable variables, so one instance can be reused.
aggregator = gn.blocks.ReceivedEdgesToNodesAggregator(tf.math.unsorted_segment_sum)
for unused_pass in range(num_recurrent_passes):
    # Update the node features with the function f.
    updated_nodes = model_fn(previous_graphs.nodes)
    temporary_graph = previous_graphs.replace(nodes=updated_nodes)
    # Send each node's features to the edges that node sends.
    nodes_at_edges = gn.blocks.broadcast_sender_nodes_to_edges(temporary_graph)
    temporary_graph = temporary_graph.replace(edges=nodes_at_edges)
    # Aggregate (sum) all of the edges received by every node.
    nodes_with_aggregated_edges = aggregator(temporary_graph)
    previous_graphs = previous_graphs.replace(nodes=nodes_with_aggregated_edges)
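To make the two primitives concrete, here is a rough NumPy analogue of one pass of the loop above. It assumes that nodes[senders] plays the role of broadcast_sender_nodes_to_edges (a gather of sender features onto edges) and that np.add.at over receivers plays the role of tf.math.unsorted_segment_sum. The helper names, the node function f, the weights and the feature values are hypothetical; the graph is the A..E example with A=0, B=1, C=2, D=3, E=4.

```python
import numpy as np


def f(x, weights):
    """Hypothetical node function: a fixed linear map followed by ReLU."""
    return np.maximum(x @ weights, 0.0)


def broadcast_sender_nodes(nodes, senders):
    """Copy each edge's sender features onto that edge (a gather)."""
    return nodes[senders]


def sum_received_edges(edges, receivers, num_nodes):
    """Sum edge features into their receiver nodes (a segment sum)."""
    out = np.zeros((num_nodes,) + edges.shape[1:], dtype=edges.dtype)
    np.add.at(out, receivers, edges)  # unbuffered, so repeated receivers accumulate
    return out


def message_passing_step(nodes, senders, receivers, weights):
    updated = f(nodes, weights)                       # 1. apply f on every node
    edges = broadcast_sender_nodes(updated, senders)  # 2. move sender features to edges
    return sum_received_edges(edges, receivers, len(nodes))  # 3. sum at receivers


# A -> C, B -> C, D -> E.
senders = np.array([0, 1, 3])
receivers = np.array([2, 2, 4])
nodes = np.arange(10, dtype=np.float64).reshape(5, 2)
weights = np.array([[1.0, -1.0], [0.5, 2.0]])

one_pass = message_passing_step(nodes, senders, receivers, weights)
```

After one pass, C holds f(A) + f(B) and E holds f(D). Nodes with no incoming edges (A, B and D here) come out as zeros, because the summed edges replace the node features and a sum over no edges is zero.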
More information about the ops and building blocks in gn.blocks is available in our paper "Relational inductive biases, deep learning, and graph networks".