google-deepmind/jraph

Reference code for NLNN with multi-head attention?

Closed this issue · 9 comments

The graph networks paper points out the ability to express transformer-like architectures within the graph net framework. Is there some example code of making multi-head attention work within this framework or the graph_nets library? I'm getting hung up on the segment_fn semantics.

Hey,

A transformer is effectively a Graph Attention Network applied to a fully connected graph.

To build such a model with Jraph you would:

  • Create a GraphsTuple with a directed edge between every pair of nodes; that is roughly n^2 edges (n * (n - 1) directed edges without self-edges). This is accomplished using the sender and receiver attributes; see the construction sketch after the code below.
  • Use the GAT model from jraph:
import jax
import jax.numpy as jnp
import jraph
import haiku as hk

def attention_query_fn(feats: jnp.ndarray) -> jnp.ndarray:
  """Attention Query Function."""
  net = hk.Sequential(
      [hk.Linear(128), jax.nn.relu,
       hk.Linear(128)])
  return net(feats)

def attention_logit_fn(sender_features, receiver_features, edge_features):
  """Attention Logit Function."""
  del edge_features  # transformers do not use explicit edge features
  # Keep a trailing axis so the per-edge attention weights broadcast against
  # the node features when messages are aggregated.
  return (sender_features * receiver_features).sum(-1, keepdims=True)

transformer_like = jraph.GAT(attention_query_fn, attention_logit_fn)
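
For the first step, here is a minimal sketch (not from the original thread; sizes are illustrative) of building a fully connected GraphsTuple via the sender and receiver attributes:

import numpy as np
import jax.numpy as jnp
import jraph

# Hypothetical sizes: n nodes with d-dimensional features, no self-edges.
n, d = 4, 128
senders = np.repeat(np.arange(n), n)
receivers = np.tile(np.arange(n), n)
mask = senders != receivers          # drop self-edges; keep them if you want n^2 edges
senders, receivers = senders[mask], receivers[mask]

fully_connected_graph = jraph.GraphsTuple(
    nodes=jnp.zeros((n, d)),         # placeholder node features
    edges=None,                      # no explicit edge features in the transformer setting
    senders=jnp.asarray(senders),
    receivers=jnp.asarray(receivers),
    globals=None,
    n_node=jnp.array([n]),
    n_edge=jnp.array([senders.shape[0]]),
)

Note that because the query function above uses Haiku modules, applying transformer_like to this graph needs to happen inside an hk.transform.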

Multi-head attention is unlikely to be as performant as in a dedicated transformer library, since we use a lot of gathers and scatters, but if you have a graph you want to apply multi-head attention to, then you can use this.

I figured out what was confusing me. The example above treats sender_features/receiver_features the way you would treat input node features in a transformer: taking the dot product of an (N, d) matrix with an (N, d) matrix gives an (N, N) matrix, which corresponds to the fully connected attention matrix in a transformer. But here the first dimension of sender_features/receiver_features already indexes edges, so you actually just want to do something like (sender_features * receiver_features).sum(-1).
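
To make that concrete, a small sketch (not from the thread; sizes are made up) contrasting the dense (N, N) logits with the per-edge logits over an explicit sender/receiver list:

import numpy as np
import jax.numpy as jnp

n, d = 4, 8
q = jnp.asarray(np.random.randn(n, d))
k = jnp.asarray(np.random.randn(n, d))

# Dense transformer view: one logit per ordered node pair.
dense_logits = q @ k.T                               # (N, N)

# Edge-list view: one logit per edge, indexed by senders/receivers.
senders = np.repeat(np.arange(n), n)                 # fully connected, with self-edges
receivers = np.tile(np.arange(n), n)
edge_logits = (q[senders] * k[receivers]).sum(-1)    # (E,) with E = N * N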

This makes working with multi-head attention surprisingly easy (I'm not sure about performance, but I may work on a PR with docs/examples of how to do this, and maybe compare with DGL).
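
If it helps, here is a hedged multi-head sketch (my own, not from the thread). It assumes the query function may return per-head features of shape [N, H, Dh], with logits of shape [E, H, 1] broadcasting against them during aggregation; the head count, head size, and scaling are my own choices.

import jax.numpy as jnp
import haiku as hk
import jraph

NUM_HEADS, HEAD_DIM = 4, 32

def multihead_query_fn(feats: jnp.ndarray) -> jnp.ndarray:
  """Projects node features into per-head features of shape [N, H, Dh]."""
  out = hk.Linear(NUM_HEADS * HEAD_DIM)(feats)
  return out.reshape(-1, NUM_HEADS, HEAD_DIM)

def multihead_logit_fn(sender_features, receiver_features, edge_features):
  """Scaled dot-product logit per edge and per head, shape [E, H, 1]."""
  del edge_features
  logits = (sender_features * receiver_features).sum(-1, keepdims=True)
  return logits / jnp.sqrt(HEAD_DIM)

multihead_gat = jraph.GAT(multihead_query_fn, multihead_logit_fn)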

Thanks! I've updated the example.

You could also decorate the attention_logit_fn with @jax.vmap and then keep the plain dot product. The vmap simplifies some logic but can make things a little less transparent :)
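
For example, a minimal sketch of that vmap variant (assumptions mine): under jax.vmap the function sees one edge at a time, so a plain dot product works, and the trailing axis is kept so the result matches the non-vmapped version above.

import jax
import jax.numpy as jnp

@jax.vmap
def attention_logit_fn(sender_features, receiver_features, edge_features):
  """Dot-product logit for a single edge; jax.vmap maps it over all edges."""
  del edge_features  # no explicit edge features in the transformer setting
  return jnp.dot(sender_features, receiver_features)[None]  # keep a trailing axis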

I'm glad you think this makes it easier! We would be very happy if you'd like to contribute some docs/examples, especially if you think this way of working is good!

Hey, quick update--I really enjoy the API and have been able to make some interesting architectures. However:

Multi-head attention is unlikely to be as performant as in a dedicated transformer library

Unfortunately I think you're right, but I find even simpler message passing to be slow. Even with sparse connections for molecular graphs, preliminary experiments building on your examples (both the Haiku- and Flax-based ones) show jax/jraph being orders of magnitude slower than expected, compared to PyTorch/DGL implementations of transformers and GNN variants. GPU utilization sits at 0% most of the time, almost regardless of batch size. I'm assuming this means the bottleneck isn't in the matrix multiplications but almost entirely in the graph data-structure wrangling. Am I missing something obvious?

Hey,

That is a little concerning. Do you have a code gist or something I could look at?

I think it could be a few different things:

  • Is the graph net accidentally not jitted?
  • Is there a lot of extra jitting going on because a dynamic shape is slipping in? That would cause recompilation on every call of the function.
  • Are there many separate jitted programs? (This would lead to going to and from the host a lot, so most of the time would be spent in communication overhead.)
  • Is JAX picking up your GPU? You can check this with jax.local_devices(); see the snippet below.
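
A quick sketch of those checks (the config flag is a standard JAX option, but treat the snippet as illustrative):

import jax

print(jax.local_devices())    # should list GPU devices if JAX can see your GPU
print(jax.default_backend())  # 'gpu' if the GPU backend is the default, otherwise 'cpu'/'tpu'

# Log every XLA compilation; repeated compiles of the same function usually
# mean a dynamic shape is forcing recompilation on each call.
jax.config.update('jax_log_compiles', True)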

I would hope that is not happening in a fully jit-compiled function, since that would imply sparse operations are slower in JAX than in PyTorch.

Did a little bit more logging; now I see that for your minimal examples the bottleneck is the data_utils.DataReader, which for me takes 4.5s to get a batch of 128 graphs. The forward/backward pass is fine (backward ~0.2s, so if dataloading took 0s it'd be similar to my torch implementation). For the stuff I've done, I will do some more profiling, because I'm using a lot more operations that look like dense1(nodes)[receivers] * dense2(nodes)[senders] + dense3(globals_)[global_to_edge_broadcast_idx]. Maybe this kind of ad-hoc indexing can't be optimized well.
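
For reference, a self-contained sketch of the shapes in that pattern (all names and sizes are hypothetical stand-ins for the dense layers and indices mentioned above):

import numpy as np
import jax.numpy as jnp

N, E, F = 5, 12, 16                      # nodes, edges, feature size (one graph)
nodes = jnp.ones((N, F))
globals_ = jnp.ones((1, F))
senders = jnp.asarray(np.random.randint(0, N, size=E))
receivers = jnp.asarray(np.random.randint(0, N, size=E))
global_to_edge_broadcast_idx = jnp.zeros((E,), dtype=jnp.int32)  # every edge belongs to graph 0

w1, w2, w3 = (jnp.ones((F, F)) for _ in range(3))  # stand-ins for dense1/dense2/dense3
edge_messages = ((nodes @ w1)[receivers]
                 * (nodes @ w2)[senders]
                 + (globals_ @ w3)[global_to_edge_broadcast_idx])  # [E, F]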

How do you recommend dealing with the slow dataloader?

Hey,

I would recommend using either the TensorFlow data loader or the PyTorch data loader, with a wrapper that converts DGL graphs (or equivalent) to a GraphsTuple; a rough conversion sketch is below.

I think, if you are using GPUs (not TPUs), that sort of indexing should be fine. My understanding is that bottlenecks with indexing are due more to hardware than to XLA compilation (although there may be custom CUDA kernels that can speed things up in special cases). Let me know if you find evidence to the contrary.
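
A rough sketch of that wrapper idea (my own; the 'feat' keys and CPU tensors are assumptions, not a jraph API):

import numpy as np
import jraph

def dgl_to_graphs_tuple(g):
  """Converts a single (CPU) DGL graph to a GraphsTuple."""
  senders, receivers = (t.numpy() for t in g.edges())
  return jraph.GraphsTuple(
      nodes=np.asarray(g.ndata['feat']),
      edges=np.asarray(g.edata['feat']) if 'feat' in g.edata else None,
      senders=senders,
      receivers=receivers,
      globals=None,
      n_node=np.array([g.num_nodes()]),
      n_edge=np.array([g.num_edges()]),
  )

You would apply something like this in (or right after) the DataLoader's collate step, so batching and file reads stay in the existing pipeline.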

Thank you for this suggestion. Similar to @sooheon, DataReader takes about 4 or more seconds for me, and I have been working on optimizing it throughout the week. Maybe I should use the TensorFlow data loader or the PyTorch data loader as you suggested.