microsoft/tf2-gnn

How to add a RGAT layer to a custom keras model?

Jorvan758 opened this issue · 4 comments

I have been struggling for a while trying to do this, but I'm still more or less a noob, so I need your help.

Here is one of my attempts:

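import tensorflow as tf
from keras import Model, layers
from tf2_gnn.layers.message_passing.rgat import RGAT
from tf2_gnn.layers.message_passing.message_passing import MessagePassingInput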

inputLayer_X = layers.Input(shape=tf.TensorShape(dims=(None, 3)),name="Input_X")
inputLayer_A = layers.Input(shape=tuple(tf.TensorShape(dims=(None, 2)) for _ in range(3)),name="Input_A")
rgatLayer_1 = RGAT({'aggregation_function': 'sum', 'hidden_dim': 10,
                    'message_activation_before_aggregation': False,
                    'message_activation_function': 'relu', 'num_heads': 5}, name="RGAT_1")(MessagePassingInput(inputLayer_X, inputLayer_A))
modelo = Model([inputLayer_X, inputLayer_A], rgatLayer_1, name="The_model")

Which returns:


TypeError Traceback (most recent call last)
in ()
1 inputLayer_X = layers.Input(shape=tf.TensorShape(dims=(None, 3)),name="Input_X")
----> 2 inputLayer_A = layers.Input(shape=tuple(tf.TensorShape(dims=(None, 2)) for _ in range(3)),name="Input_A")
3 rgatLayer_1 = RGAT({'aggregation_function': 'sum', 'hidden_dim': 10,
4 'message_activation_before_aggregation': False,
5 'message_activation_function': 'relu', 'num_heads': 5}, name="RGAT_1")(MessagePassingInput(inputLayer_X, inputLayer_A))

/usr/local/lib/python3.7/dist-packages/keras/utils/traceback_utils.py in error_handler(*args, **kwargs)
65 except Exception as e: # pylint: disable=broad-except
66 filtered_tb = _process_traceback_frames(e.traceback)
---> 67 raise e.with_traceback(filtered_tb) from None
68 finally:
69 del filtered_tb

/usr/local/lib/python3.7/dist-packages/six.py in raise_from(value, from_value)

TypeError: Dimension value must be integer or None or have an index method, got value 'TensorShape([None, 2])' with type '<class 'tensorflow.python.framework.tensor_shape.TensorShape'>'
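This error comes from passing a tuple of `TensorShape` objects as a single `shape`. `layers.Input` describes one tensor, so its `shape` should be a flat tuple of dimension sizes (integers or `None`) that does not include the batch dimension. The adjacency information is really several separate tensors, one per edge type.

The NumPy sketch below shows that structure for a toy graph. It assumes three edge types and adjacency lists stored as `(num_edges, 2)` arrays of node indices. The variable names are illustrative only.

```python
import numpy as np

# Toy graph: 4 nodes, each with a 3-dimensional feature vector.
node_features = np.random.rand(4, 3).astype(np.float32)

# One adjacency list per edge type; each row is a (source, target) pair of
# node indices. The lists have different numbers of edges, so they cannot
# share one fixed shape and each needs its own tensor (and its own Input).
adjacency_lists = (
    np.array([[0, 1], [1, 2]], dtype=np.int32),          # edge type 0
    np.array([[2, 3]], dtype=np.int32),                  # edge type 1
    np.array([[3, 0], [0, 2], [1, 3]], dtype=np.int32),  # edge type 2
)

for edge_type, adj in enumerate(adjacency_lists):
    print(edge_type, adj.shape, adj.dtype)
# 0 (2, 2) int32
# 1 (1, 2) int32
# 2 (3, 2) int32
```

In Keras terms, each of these arrays would get its own integer `Input` whose per-example shape is `(2,)`. The edge count then plays the role of the leading dimension. This is consistent with the `(None, 2)` shapes of `Input_A1`, `Input_A2` and `Input_A3` in the working summary later in this thread.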

And here's another:

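import tensorflow as tf
from keras import Model, layers
from tf2_gnn.layers.message_passing.rgat import RGAT
from tf2_gnn.layers.message_passing.message_passing import MessagePassingInput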

inputLayer_X = layers.Input(shape=tf.TensorShape(dims=(None, 3)),name="Input_X")
inputLayer_A1 = layers.Input(shape=tf.TensorShape(dims=(None, 2)),name="Input_A1")
inputLayer_A2 = layers.Input(shape=tf.TensorShape(dims=(None, 2)),name="Input_A2")
inputLayer_A3 = layers.Input(shape=tf.TensorShape(dims=(None, 2)),name="Input_A3")
rgatLayer_1 = RGAT({'aggregation_function': 'sum', 'hidden_dim': 10,
                    'message_activation_before_aggregation': False,
                    'message_activation_function': 'relu', 'num_heads': 5}, name="RGAT_1")(MessagePassingInput(inputLayer_X,
                                                                                                               [inputLayer_A1, inputLayer_A2, inputLayer_A3]))
modelo = Model([inputLayer_X, inputLayer_A1, inputLayer_A2, inputLayer_A3], rgatLayer_1, name="The_model")

That yields:


TypeError Traceback (most recent call last)
in ()
6 'message_activation_before_aggregation': False,
7 'message_activation_function': 'relu', 'num_heads': 5}, name="RGAT_1")(MessagePassingInput(inputLayer_X,
----> 8 [inputLayer_A1, inputLayer_A2, inputLayer_A3]))
9 modelo = Model([inputLayer_X, inputLayer_A1, inputLayer_A2, inputLayer_A3], rgatLayer_1, name="The_model")

/usr/local/lib/python3.7/dist-packages/keras/utils/traceback_utils.py in error_handler(*args, **kwargs)
65 except Exception as e: # pylint: disable=broad-except
66 filtered_tb = _process_traceback_frames(e.traceback)
---> 67 raise e.with_traceback(filtered_tb) from None
68 finally:
69 del filtered_tb

/usr/local/lib/python3.7/dist-packages/tensorflow/python/autograph/impl/api.py in wrapper(*args, **kwargs)
697 except Exception as e: # pylint:disable=broad-except
698 if hasattr(e, 'ag_error_metadata'):
--> 699 raise e.ag_error_metadata.to_exception(e)
700 else:
701 raise

TypeError: Exception encountered when calling layer "RGAT_1" (type RGAT).

in user code:

File "/usr/local/lib/python3.7/dist-packages/tf2_gnn/layers/message_passing/message_passing.py", line 116, in call  *
    messages_per_type = self._calculate_messages_per_type(
File "/usr/local/lib/python3.7/dist-packages/tf2_gnn/layers/message_passing/message_passing.py", line 190, in _calculate_messages_per_type  *
    type_to_num_incoming_edges = calculate_type_to_num_incoming_edges(
File "/usr/local/lib/python3.7/dist-packages/tf2_gnn/layers/message_passing/message_passing.py", line 256, in calculate_type_to_num_incoming_edges  *
    num_incoming_edges = tf.scatter_nd(

TypeError: Value passed to parameter 'indices' has DataType float32 not in list of allowed values: int32, int64

Call arguments received:
• inputs=MessagePassingInput(node_embeddings='tf.Tensor(shape=(None, None, 3), dtype=float32)', adjacency_lists=['tf.Tensor(shape=(None, None, 2), dtype=float32)', 'tf.Tensor(shape=(None, None, 2), dtype=float32)', 'tf.Tensor(shape=(None, None, 2), dtype=float32)'])
• training=False
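In this attempt the shapes get past `layers.Input`, but the adjacency inputs default to `float32`. Per the traceback, the layer passes them as `indices` to `tf.scatter_nd`, which only accepts `int32` or `int64` indices.

NumPy has the same constraint. The sketch below is an analogue, not tf2-gnn's code: it counts incoming edges per node with `np.bincount`, assuming column 1 of each adjacency list holds the target node.

```python
import numpy as np

num_nodes = 4
adj_int = np.array([[0, 1], [2, 1], [3, 0]], dtype=np.int32)
adj_float = adj_int.astype(np.float32)  # what a default float32 Input would carry

# Count incoming edges per node (assumes column 1 is the target node).
incoming = np.bincount(adj_int[:, 1], minlength=num_nodes)
print(incoming)  # [1 2 0 0]

# Float indices are rejected, much like tf.scatter_nd rejects float32 indices.
try:
    np.bincount(adj_float[:, 1], minlength=num_nodes)
except TypeError as err:
    print("TypeError:", err)
```

The corresponding fix on the Keras side is to declare the adjacency inputs with an integer dtype, e.g. `dtype=tf.int32`, as the later attempts in this thread do.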

I'll keep trying to overcome it (and will update if I manage to), but if someone can shed some light on the matter, I would be very grateful 🙏

I've been studying a fair amount and I think I'm pretty close to solving it. Right now, I've got this to run:

from tf2_gnn.layers.message_passing.rgat import RGAT
from tf2_gnn.layers.message_passing.message_passing import MessagePassingInput
from keras import Model, layers
import tensorflow as tf
inputLayer_X = layers.Input(shape=tf.TensorShape(dims=(None, 7)),name="Input_X")
inputLayer_A1 = layers.Input(shape=tf.TensorShape(dims=2),name="Input_A1", dtype=tf.int32)
inputLayer_A2 = layers.Input(shape=tf.TensorShape(dims=2),name="Input_A2", dtype=tf.int32)
inputLayer_A3 = layers.Input(shape=tf.TensorShape(dims=2),name="Input_A3", dtype=tf.int32)
rgatLayer_1 = RGAT({'aggregation_function': 'sum', 'hidden_dim': 10,
                    'message_activation_before_aggregation': False,
                    'message_activation_function': 'relu', 'num_heads': 5}, name="RGAT_1")(MessagePassingInput(inputLayer_X,
                                                                                                               [inputLayer_A1, inputLayer_A2, inputLayer_A3]))
modelo = Model([inputLayer_X, inputLayer_A1, inputLayer_A2, inputLayer_A3], rgatLayer_1, name="The_model")
modelo.summary()

Model: "The_model"

 Layer (type)              Output Shape          Param #   Connected to
=================================================================================
 Input_X (InputLayer)      [(None, None, 7)]     0         []
 Input_A1 (InputLayer)     [(None, 2)]           0         []
 Input_A2 (InputLayer)     [(None, 2)]           0         []
 Input_A3 (InputLayer)     [(None, 2)]           0         []
 RGAT_1 (RGAT)             (None, 10)            270       ['Input_X[0][0]',
                                                            'Input_A1[0][0]',
                                                            'Input_A2[0][0]',
                                                            'Input_A3[0][0]']
=================================================================================
Total params: 270
Trainable params: 270
Non-trainable params: 0
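The 270 parameters reported here can be reconciled with the variables that the Lambda error further down lists for `RGAT_1`. Each of the three edge types has a `(7, 10)` kernel and a `(5, 4)` attention tensor. No bias is listed.

The code reads the `(5, 4)` shape as `num_heads` rows of `2 * hidden_dim / num_heads` values. That reading is an assumption about RGAT's internals that happens to match the shape.

```python
# Hyperparameters taken from the RGAT config and Input_X in this thread.
num_edge_types = 3   # Input_A1, Input_A2, Input_A3
input_dim = 7        # last dimension of Input_X
hidden_dim = 10
num_heads = 5

# Per edge type: a (7, 10) kernel (no bias variable is listed)...
kernel_params = input_dim * hidden_dim
# ...and a (5, 4) attention tensor, assumed to be num_heads x (2 * head_dim).
head_dim = hidden_dim // num_heads
attention_params = num_heads * 2 * head_dim

total_params = num_edge_types * (kernel_params + attention_params)
print(total_params)  # 270
```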


While it's usable, it's far from ideal, given that it wouldn't work with multiple graphs at the same time (which is what I need). Of course, I tried expanding the input shape of the adjacency lists, but the RGAT layer seems to only work with one graph at a time. Because of that, I'm now searching for a workaround (one that at least processes multiple graphs sequentially).
I'll update as soon as I find it. However, if anyone can help, I would appreciate it 👀

I think I'm almost there, but it's getting trickier. I have two relevant attempts. The first one is this:

from tf2_gnn.layers.message_passing.rgat import RGAT
from tf2_gnn.layers.message_passing.message_passing import MessagePassingInput
from keras import Model, layers
import tensorflow as tf

inputLayer_X = layers.Input(shape=tf.TensorShape(dims=(None, 7)),name="Input_X")
inputLayer_A1 = layers.Input(shape=tf.TensorShape(dims=(None, 2)),name="Input_A1", dtype=tf.int32)
inputLayer_A2 = layers.Input(shape=tf.TensorShape(dims=(None, 2)),name="Input_A2", dtype=tf.int32)
inputLayer_A3 = layers.Input(shape=tf.TensorShape(dims=(None, 2)),name="Input_A3", dtype=tf.int32)
rgatLayer_1 = RGAT({'aggregation_function': 'sum', 'hidden_dim': 10,
                    'message_activation_before_aggregation': False,
                    'message_activation_function': 'relu', 'num_heads': 5}, name="RGAT_1")
lambdaLayer_1 = layers.Lambda(lambda x: tf.map_fn(lambda y: rgatLayer_1(MessagePassingInput(y[0],[y[1],y[2],y[3]])),
                                                  (x[0],x[1],x[2],x[3]), dtype=tf.float32),
                              name="Lambda_1")((inputLayer_X, inputLayer_A1, inputLayer_A2, inputLayer_A3))
modelo = Model([inputLayer_X, inputLayer_A1, inputLayer_A2, inputLayer_A3], lambdaLayer_1, name="The_model")
modelo.summary()

Which returns:

ValueError: Exception encountered when calling layer "Lambda_1" (type Lambda).

The following Variables were created within a Lambda layer (Lambda_1)
but are not tracked by said layer:
<tf.Variable 'Lambda_1/map/while/RGAT_1/edge_type_0/kernel:0' shape=(7, 10) dtype=float32>
<tf.Variable 'Lambda_1/map/while/RGAT_1/edge_type_0/Edge_attention_parameters_0:0' shape=(5, 4) dtype=float32>
<tf.Variable 'Lambda_1/map/while/RGAT_1/edge_type_1/kernel:0' shape=(7, 10) dtype=float32>
<tf.Variable 'Lambda_1/map/while/RGAT_1/edge_type_1/Edge_attention_parameters_1:0' shape=(5, 4) dtype=float32>
<tf.Variable 'Lambda_1/map/while/RGAT_1/edge_type_2/kernel:0' shape=(7, 10) dtype=float32>
<tf.Variable 'Lambda_1/map/while/RGAT_1/edge_type_2/Edge_attention_parameters_2:0' shape=(5, 4) dtype=float32>
The layer cannot safely ensure proper Variable reuse across multiple
calls, and consquently this behavior is disallowed for safety. Lambda
layers are not well suited to stateful computation; instead, writing a
subclassed Layer is the recommend way to define layers with
Variables.

Call arguments received:
• inputs=('tf.Tensor(shape=(None, None, 7), dtype=float32)', 'tf.Tensor(shape=(None, None, 2), dtype=int32)', 'tf.Tensor(shape=(None, None, 2), dtype=int32)', 'tf.Tensor(shape=(None, None, 2), dtype=int32)')
• mask=None
• training=None

I tried to create a custom layer, but it's fairly difficult for me, so I'm looking for other options.
The second attempt is this one:

from tf2_gnn.layers.message_passing.rgat import RGAT
from tf2_gnn.layers.message_passing.message_passing import MessagePassingInput
from keras import Model, layers
import tensorflow as tf

inputLayer_X = layers.Input(shape=tf.TensorShape(dims=(None, 7)),name="Input_X")
inputLayer_A1 = layers.Input(shape=tf.TensorShape(dims=(None, 2)),name="Input_A1", dtype=tf.int32)
inputLayer_A2 = layers.Input(shape=tf.TensorShape(dims=(None, 2)),name="Input_A2", dtype=tf.int32)
inputLayer_A3 = layers.Input(shape=tf.TensorShape(dims=(None, 2)),name="Input_A3", dtype=tf.int32)

lambdaLayer_1 = layers.Lambda(lambda x: tf.map_fn(lambda y: MessagePassingInput(y[0],[y[1],y[2],y[3]]),
                                                  (x[0],x[1],x[2],x[3]), dtype=MessagePassingInput),
                              name="Lambda_1")((inputLayer_X, inputLayer_A1, inputLayer_A2, inputLayer_A3))
rgatLayer_1 = RGAT({'aggregation_function': 'sum', 'hidden_dim': 10,
                    'message_activation_before_aggregation': False,
                    'message_activation_function': 'relu', 'num_heads': 5}, name="RGAT_1")(lambdaLayer_1)
modelo = Model([inputLayer_X, inputLayer_A1, inputLayer_A2, inputLayer_A3], rgatLayer_1, name="The_model")
modelo.summary()

And this gives me:

TypeError: Exception encountered when calling layer "Lambda_1" (type Lambda).

Cannot convert value <class 'tf2_gnn.layers.message_passing.message_passing.MessagePassingInput'> to a TensorFlow DType.

Call arguments received:
• inputs=('tf.Tensor(shape=(None, None, 7), dtype=float32)', 'tf.Tensor(shape=(None, None, 2), dtype=int32)', 'tf.Tensor(shape=(None, None, 2), dtype=int32)', 'tf.Tensor(shape=(None, None, 2), dtype=int32)')
• mask=None
• training=None

I'll keep pushing, but I really hope that someone could lend me a hand 😰

mmjb commented

I'm not familiar with Keras and hence can't really help on that front. However, it seems to me that you're unfamiliar with how batching is usually performed in (sparse) GNN implementations: the idea is to represent a batch of graphs as a single graph made of disconnected components. As information is only exchanged along edges, these two views are equivalent.
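To make the equivalence concrete, here is a NumPy sketch in which a deliberately simple message-passing step stands in for RGAT: each node sums the features of its in-neighbours. The sketch runs that step on two graphs separately, then on their disjoint union, with the second graph's node indices shifted by the first graph's node count. The function name `sum_neighbours` is hypothetical.

```python
import numpy as np

def sum_neighbours(features, adjacency):
    """Toy message-passing step: every target node sums the features of
    its source nodes along the (source, target) rows of `adjacency`."""
    out = np.zeros_like(features)
    np.add.at(out, adjacency[:, 1], features[adjacency[:, 0]])
    return out

# Graph A has 3 nodes, graph B has 2 nodes.
x_a = np.array([[1.0], [2.0], [3.0]])
adj_a = np.array([[0, 1], [1, 2], [2, 0]])
x_b = np.array([[10.0], [20.0]])
adj_b = np.array([[0, 1], [1, 0]])

# Process each graph on its own, then stack the results.
separate = np.concatenate([sum_neighbours(x_a, adj_a), sum_neighbours(x_b, adj_b)])

# Disjoint union: stack node features and shift graph B's indices by len(x_a).
x_union = np.concatenate([x_a, x_b])
adj_union = np.concatenate([adj_a, adj_b + len(x_a)])
merged = sum_neighbours(x_union, adj_union)

print(np.array_equal(separate, merged))  # True
```

Because no edge crosses from graph A to graph B, no message can leak between them.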

Suitable code to batch graphs like this can be found in https://github.com/microsoft/tf2-gnn/blob/master/tf2_gnn/data/graph_dataset.py#L192-L246.
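A simplified sketch of that batching idea for several edge types follows. It is not the code in `graph_dataset.py`, just an illustration. It assumes node features come as `(num_nodes, dim)` arrays and each edge type has a `(num_edges, 2)` integer adjacency list. `batch_graphs` and the dictionary keys are hypothetical.

The sketch also builds a node-to-graph map, which lets per-node outputs be pooled back into one result per graph.

```python
import numpy as np

def batch_graphs(graphs, num_edge_types):
    """Merge a list of graphs into one graph of disconnected components.

    Each graph is a dict with "node_features" of shape (num_nodes, dim) and
    "adjacency_lists": a list of (num_edges, 2) int arrays, one per edge type.
    """
    features, node_to_graph = [], []
    adjacency = [[] for _ in range(num_edge_types)]
    offset = 0
    for graph_idx, graph in enumerate(graphs):
        num_nodes = graph["node_features"].shape[0]
        features.append(graph["node_features"])
        node_to_graph.append(np.full(num_nodes, graph_idx, dtype=np.int32))
        for edge_type in range(num_edge_types):
            # Shift node indices so they point into the merged feature matrix.
            adjacency[edge_type].append(graph["adjacency_lists"][edge_type] + offset)
        offset += num_nodes
    return (
        np.concatenate(features),
        [np.concatenate(adj).astype(np.int32) for adj in adjacency],
        np.concatenate(node_to_graph),
    )

graphs = [
    {"node_features": np.ones((3, 4), dtype=np.float32),
     "adjacency_lists": [np.array([[0, 1], [1, 2]]), np.zeros((0, 2), dtype=np.int64)]},
    {"node_features": np.ones((2, 4), dtype=np.float32),
     "adjacency_lists": [np.array([[0, 1]]), np.array([[1, 0]])]},
]
x, adjs, node_to_graph = batch_graphs(graphs, num_edge_types=2)
print(x.shape)                 # (5, 4)
print(adjs[0].tolist())        # [[0, 1], [1, 2], [3, 4]]
print(adjs[1].tolist())        # [[4, 3]]
print(node_to_graph.tolist())  # [0, 0, 0, 1, 1]

# Per-graph readout: sum the node vectors that belong to each graph.
per_graph = np.zeros((len(graphs), x.shape[1]), dtype=x.dtype)
np.add.at(per_graph, node_to_graph, x)
print(per_graph[:, 0])  # [3. 2.]
```

In Keras terms, this suggests declaring the node features with `shape=(7,)`, so that Keras's leading dimension becomes the total node count of the merged graph. The adjacency inputs with a per-example shape of `(2,)` already treat it as the edge count in the same way. This is an untested suggestion, not a confirmed recipe for tf2-gnn.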

I'll take a look at it later. For now, I think it's best that I work on other stuff (I'll update when I find a confirmed solution 👍)