How to do cross-graph attention?
llan-ml opened this issue · 0 comments
Hi everyone,
I'm trying to re-implement DeepMind's GMN model in JAX. This model computes a similarity score between two graphs, and needs the cross-graph node-wise attention weights between the nodes of the two graphs. Since different graphs can have different numbers of nodes, the cross-graph attention has to be computed pair by pair.
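For a single pair of graphs the computation itself is easy to write down. Here is a minimal JAX sketch of the per-pair step, following the equations in the docstring below (dot-product similarity hard-coded; the helper name is mine, the original passes a similarity function as an argument):

```python
import jax
import jax.numpy as jnp

def pair_cross_attention(x, y):
  """Cross-attention for one pair of graphs: x is [n1, D], y is [n2, D]."""
  a = x @ y.T                       # [n1, n2] pairwise similarities sim(x_i, y_j)
  a_x = jax.nn.softmax(a, axis=1)   # a_{i->j}: each x_i attends over all y_j
  a_y = jax.nn.softmax(a, axis=0)   # a_{j->i}: each y_j attends over all x_i
  attention_x = a_x @ y             # sum_j a_{i->j} y_j, shape [n1, D]
  attention_y = a_y.T @ x           # sum_i a_{j->i} x_i, shape [n2, D]
  return attention_x, attention_y
```

The difficulty is only in batching this over many pairs of graphs of different sizes.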
The original implementation is written in TensorFlow, where, given a node representation matrix, we can use tf.dynamic_partition to split it into a list of node representation matrices, one per graph, as follows:
def batch_block_pair_attention(data,
                               block_idx,
                               n_blocks,
                               similarity='dotproduct'):
  """Compute batched attention between pairs of blocks.

  This function partitions the batch data into blocks according to block_idx.
  For each pair of blocks, x = data[block_idx == 2i], and
  y = data[block_idx == 2i+1], we compute

  x_i attend to y_j:
  a_{i->j} = exp(sim(x_i, y_j)) / sum_j exp(sim(x_i, y_j))

  y_j attend to x_i:
  a_{j->i} = exp(sim(x_i, y_j)) / sum_i exp(sim(x_i, y_j))

  and

  attention_x = sum_j a_{i->j} y_j
  attention_y = sum_i a_{j->i} x_i.

  Args:
    data: NxD float tensor.
    block_idx: N-dim int tensor.
    n_blocks: integer.
    similarity: a string, the similarity metric.

  Returns:
    attention_output: NxD float tensor, each x_i replaced by attention_x_i.

  Raises:
    ValueError: if n_blocks is not an integer or not a multiple of 2.
  """
  if not isinstance(n_blocks, int):
    raise ValueError('n_blocks (%s) has to be an integer.' % str(n_blocks))
  if n_blocks % 2 != 0:
    raise ValueError('n_blocks (%d) must be a multiple of 2.' % n_blocks)

  sim = get_pairwise_similarity(similarity)

  results = []

  # This is probably better than doing boolean_mask for each i.
  partitions = tf.dynamic_partition(data, block_idx, n_blocks)

  # It is rather complicated to allow n_blocks to be a tf tensor and do this
  # in a dynamic loop, and probably unnecessary to do so. Therefore we are
  # restricting n_blocks to be an integer constant here and using a plain
  # for loop.
  for i in range(0, n_blocks, 2):
    x = partitions[i]
    y = partitions[i + 1]
    attention_x, attention_y = compute_cross_attention(x, y, sim)
    results.append(attention_x)
    results.append(attention_y)

  results = tf.concat(results, axis=0)
  # The shape of the first dimension is lost after concat; reset it back.
  results.set_shape(data.shape)
  return results
However, JAX has no counterpart to tf.dynamic_partition: under jit all shapes must be static, so a traced array cannot be split into a list of per-graph matrices whose sizes depend on block_idx.
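One workaround I can think of is to trade memory for static shapes: compute the full N x N similarity matrix once and mask out node pairs that do not belong to partnered graphs, so the whole batch is handled with fixed-shape ops. A rough sketch (dot-product similarity only; function name is mine; assumes every graph in the batch is non-empty):

```python
import jax
import jax.numpy as jnp

def masked_block_pair_attention(data, block_idx):
  """Dense, jit-friendly variant: O(N^2) memory instead of dynamic_partition.

  data: [N, D] node features; block_idx: [N] ints, graphs 2k and 2k+1 are pairs.
  Returns an [N, D] tensor with each node replaced by its cross-attention output.
  """
  # Graph 2k is paired with 2k+1 and vice versa, so the partner id is b XOR 1.
  partner_idx = jnp.bitwise_xor(block_idx, 1)
  # mask[i, j] is True iff node j lies in node i's partner graph.
  mask = partner_idx[:, None] == block_idx[None, :]
  sim = data @ data.T                   # [N, N] pairwise dot-product similarities
  sim = jnp.where(mask, sim, -jnp.inf)  # drop non-partner pairs
  attn = jax.nn.softmax(sim, axis=1)    # each row normalizes over partner nodes only
  return attn @ data                    # nodes stay in their original order
```

This keeps everything static-shaped (so it composes with jax.jit and vmap), at the cost of an N x N matrix. For large batches a segment-based softmax over block_idx (e.g. built on jax.ops.segment_sum, or jraph's segment utilities) might scale better, but I have not worked that out.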
@jg8610 Do you have any advice on how to do the cross-graph attention in JAX and Jraph, or have you handled similar cases internally? Thanks!