How to handle heterogeneous graph features?
sooheon opened this issue · 4 comments
Say you want to differentiate edge/node types explicitly, and have differently parametrized functions operate on each type. This is different from types being implicitly encoded in the input embeddings, because implicit encoding doesn't let you dispatch to different functions. There is no jraph-native way to hold types, so the only method I can see is to keep features as dicts with, say, "type" and "feature" keys. Then most of the update/aggregate functions would first need to filter by the appropriate type key (the default GraphNetwork probably won't work with this straight away). Are there any plans to support this kind of thing in jraph?
I guess another question is how this kind of "type dispatch" would work with JAX jit compilation.
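On the jit question, one point worth making concrete: filtering like `x[types == k]` produces a data-dependent shape, which `jax.jit` cannot trace. A common workaround is to run every type's function on all nodes and keep each node's own result with a mask. This is a minimal sketch, assuming integer type ids stored alongside the features; `node_types`, `init_params` and the per-type linear layers are hypothetical, not a jraph API.

```python
import jax
import jax.numpy as jnp

NUM_TYPES = 2

def init_params(key, d_in, d_out, num_types=NUM_TYPES):
    # Hypothetical per-type linear layers: one weight matrix per node type.
    return jax.random.normal(key, (num_types, d_in, d_out)) * 0.1

@jax.jit
def typed_update(params, features, node_types):
    # features: [n_node, d_in]; node_types: [n_node] integer type ids.
    out = jnp.zeros((features.shape[0], params.shape[-1]))
    for t in range(params.shape[0]):  # static loop, unrolled at trace time
        y_t = features @ params[t]             # type-t function on ALL nodes
        mask = (node_types == t)[:, None]      # True only for nodes of type t
        out = jnp.where(mask, y_t, out)        # keep type-t results for those rows
    return out

key = jax.random.PRNGKey(0)
params = init_params(key, d_in=3, d_out=4)
features = jnp.ones((5, 3))
node_types = jnp.array([0, 1, 1, 0, 1])
out = typed_update(params, features, node_types)
```

Shapes stay static, so this compiles once, but every type's function runs on every node. With the same stacked `params`, an alternative is to gather one weight matrix per node, e.g. `jnp.einsum('ni,nio->no', features, params[node_types])`, which trades the repeated compute for extra memory.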
Hey, sorry for the delay in getting back to you.
This is a question we've thought about quite a bit, and at the moment there is no neat way to accomplish this.
I think you have two options:
(1) Make multiple jraph.GraphsTuples (one for each type), run separate graph networks for each, and then combine them.
Depending on how you want to handle your graph net, this could simply be a function that embeds all of the nodes/edges into vectors of the same embedding size, after which you can continue your latent processing. Finally, you could use two graph networks to decode.
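A minimal sketch of the embedding step in option (1), assuming two node types "A" and "B" with different raw feature sizes. The encoders are hypothetical single linear layers; the resulting `latent_nodes`, `senders` and `receivers` arrays are the kind of inputs you would then place in a single jraph.GraphsTuple for the shared latent processing.

```python
import jax
import jax.numpy as jnp

LATENT = 8  # shared embedding size (assumed)

def make_encoder(key, d_in, d_out=LATENT):
    # Hypothetical per-type encoder: one linear layer followed by ReLU.
    w = jax.random.normal(key, (d_in, d_out)) * 0.1
    return lambda x: jax.nn.relu(x @ w)

k_a, k_b = jax.random.split(jax.random.PRNGKey(0))
enc_a = make_encoder(k_a, d_in=3)  # type "A" nodes carry 3 raw features
enc_b = make_encoder(k_b, d_in=5)  # type "B" nodes carry 5 raw features

nodes_a = jnp.ones((4, 3))
nodes_b = jnp.ones((2, 5))

# Embed each type to the same latent size, then stack into one node array.
latent_nodes = jnp.concatenate([enc_a(nodes_a), enc_b(nodes_b)], axis=0)

# An edge from A-node 0 to B-node 1: "B" indices are offset by the number of "A" nodes.
offset_b = nodes_a.shape[0]
senders = jnp.array([0])
receivers = jnp.array([1 + offset_b])

# After latent processing, slice at the same offset to feed per-type decoders.
latent_a, latent_b = latent_nodes[:offset_b], latent_nodes[offset_b:]
```

Offsetting the "B" indices is what lets both types share one homogeneous node array; the same offset splits the latents again for the two decoders.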
(2) Use jax.lax.cond to conditionally execute depending on the 'type' feature. In this case, you must create 'dummy' feature values so that each 'feature' type in your dict has shape [n_node, feature] (or [n_edge, feature]). This may be wasteful.
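A rough sketch of option (2), using `jax.lax.switch` (the multi-branch form of `jax.lax.cond`) vmapped over nodes. It assumes a dict with a "type" key plus one key per type, where rows belonging to other types hold dummy zeros; the weights `w_a`/`w_b` are hypothetical.

```python
import jax
import jax.numpy as jnp

LATENT = 4
k_a, k_b = jax.random.split(jax.random.PRNGKey(0))
w_a = jax.random.normal(k_a, (3, LATENT))  # hypothetical weights for type 0 ("a")
w_b = jax.random.normal(k_b, (5, LATENT))  # hypothetical weights for type 1 ("b")

# Every key holds a full [n_node, dim] array; rows of the "wrong" type are dummy zeros.
nodes = {
    "type": jnp.array([0, 1, 0]),
    "a": jnp.array([[1., 2., 3.], [0., 0., 0.], [4., 5., 6.]]),
    "b": jnp.array([[0.] * 5, [1., 1., 1., 1., 1.], [0.] * 5]),
}

def update_one(node_type, feat_a, feat_b):
    # Choose the branch for this node's type; all branches must return the same shape.
    return jax.lax.switch(
        node_type,
        [lambda a, b: a @ w_a, lambda a, b: b @ w_b],
        feat_a, feat_b,
    )

@jax.jit
def update_nodes(nodes):
    # vmap maps update_one over the node axis of each array.
    return jax.vmap(update_one)(nodes["type"], nodes["a"], nodes["b"])

out = update_nodes(nodes)
```

Note that when the branch index is batched under `vmap`, JAX lowers the switch to a select, so every branch is evaluated for every node. Together with the dummy rows, that is where the waste comes from.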
In terms of jit compilation, option (2) would use the same padding strategy we currently use. For option (1), you may set different padding sizes for each of the jraph.GraphsTuples. Let me know if that is helpful!
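To illustrate per-type padding budgets for option (1): jraph ships `jraph.pad_with_graphs` for whole GraphsTuples, but the idea can be shown on bare arrays. In this sketch `pad_to` and the `BUDGET` sizes are hypothetical; the point is that every batch of a given type reaches the jitted function with the same shape, so it is compiled once per budget rather than once per batch.

```python
import jax
import jax.numpy as jnp
import numpy as np

def pad_to(x, size):
    # Hypothetical helper (runs outside jit): zero-pad the leading axis of x
    # up to `size` rows and return a mask marking the real rows.
    n = x.shape[0]
    if n > size:
        raise ValueError(f"{n} rows exceed the padding budget {size}")
    pad = np.zeros((size - n,) + x.shape[1:], dtype=x.dtype)
    mask = np.arange(size) < n
    return np.concatenate([x, pad], axis=0), mask

# Separate (hypothetical) padding budgets for each per-type graph.
BUDGET = {"a": 8, "b": 4}

@jax.jit
def masked_sum(x, mask):
    # Padded rows are masked out so they do not affect the result.
    return jnp.sum(jnp.where(mask[:, None], x, 0.0), axis=0)

results = []
for n_a in (3, 5):  # two batches with different numbers of type-"a" nodes
    xa, ma = pad_to(np.ones((n_a, 2), dtype=np.float32), BUDGET["a"])
    results.append(masked_sum(xa, ma))  # same (8, 2) shape, one compilation

xb, mb = pad_to(np.ones((2, 2), dtype=np.float32), BUDGET["b"])
result_b = masked_sum(xb, mb)  # different budget, compiled once for (4, 2)
```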
Thanks for the response. I agree it doesn't seem so easy right now, and the further you go down this path, the more you're reimplementing something like DGL.
(1) makes sense, but it may get complicated when you also want to pass messages between graphs (do you also create separate graphs for the A->B or B->A messages?)
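One possible answer to the A->B / B->A question is to keep one edge list per directed relation rather than one graph per relation, similar in spirit to the (source type, destination type) keyed layout used by PyTorch Geometric for heterogeneous graphs. This is a minimal sketch with plain JAX, assuming node types are already embedded to a common size `D`; the relation weights and `hetero_message_pass` are hypothetical, and the optional `relations` argument restricts message passing to a chosen subset of edge types.

```python
import jax
import jax.numpy as jnp

D = 4
k1, k2 = jax.random.split(jax.random.PRNGKey(0))

# Node features per type, already embedded to a common size D.
nodes = {"A": jnp.ones((3, D)), "B": jnp.ones((2, D))}

# One (senders, receivers) edge list per directed relation (src_type, dst_type).
edges = {
    ("A", "B"): (jnp.array([0, 1, 2]), jnp.array([0, 0, 1])),
    ("B", "A"): (jnp.array([1]), jnp.array([2])),
}

# Hypothetical per-relation message weights.
weights = {
    ("A", "B"): jax.random.normal(k1, (D, D)) * 0.1,
    ("B", "A"): jax.random.normal(k2, (D, D)) * 0.1,
}

def hetero_message_pass(nodes, edges, weights, relations=None):
    # Sum incoming messages into each destination node, one relation at a time.
    relations = list(edges) if relations is None else relations
    out = {t: jnp.zeros_like(x) for t, x in nodes.items()}
    for src_t, dst_t in relations:
        senders, receivers = edges[(src_t, dst_t)]
        messages = nodes[src_t][senders] @ weights[(src_t, dst_t)]
        out[dst_t] = out[dst_t] + jax.ops.segment_sum(
            messages, receivers, num_segments=nodes[dst_t].shape[0])
    return out

updated = hetero_message_pass(nodes, edges, weights)
only_a_to_b = hetero_message_pass(nodes, edges, weights, relations=[("A", "B")])
```

The loop over relations is plain Python, so under `jax.jit` it would be unrolled at trace time and the set of relations would be fixed per compiled function.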
(2) does seem wasteful, whether you have one giant map with a "type" key plus as many keys as there are types, or multiple dicts where each dict holds a "type" key and a "features" key with just a single vector.
Hi @sooheon, did you find a way around this? I'm looking for a data structure similar to https://pytorch-geometric.readthedocs.io/en/latest/notes/heterogeneous.html, where I can do message passing on specified edge types only.