google-deepmind/graph_nets

Passing kwargs to modules.GraphNetwork


I am trying to pass an "is_training" flag to my node_function so that Dropout and BatchNorm are applied during training but not during validation.
I saw that your last update #122 tackles exactly this, but I am not sure how to use it.
Here is a simple code example:

def MLP_model(is_training=False):
    if is_training:
      layers = [snt.Linear(256),
                tf.nn.relu,
                tf.keras.layers.Dropout(rate=0.5),
                snt.Linear(256),
                tf.nn.relu,
                tf.keras.layers.Dropout(rate=0.5),
                snt.Linear(40)
                ]
    else:
      layers = [snt.Linear(256, with_bias=True),
                tf.nn.relu,
                snt.Linear(256),
                tf.nn.relu,
                snt.Linear(40)
                ]
    return snt.Sequential(layers)

gnn = gn.modules.GraphIndependent(node_model_fn=lambda: MLP_model())
out = gnn(x_train, node_model_kwargs={'is_training':True})

This yields the error:

TypeError: __call__() got an unexpected keyword argument 'is_training'

Also, when following the instructions in the README for installing with pip, the version that is installed does not have the "Passing kwargs" update.

Thank you for your message.

when following the instructions in the README for installing with pip, the version that is installed does not have the "Passing kwargs" update.

That is expected: the feature was added in the current development version, but a new official release has not been made yet, so it is not available on PyPI. To get this feature you would need to install directly from GitHub; here are some examples of how to do that.
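For instance, pip can install straight from the repository (a sketch; this assumes a recent pip with git available, and uses this project's repository path):

pip install git+https://github.com/google-deepmind/graph_nets.git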

This yields the error:

The problem is that GraphIndependent tries to pass is_training to the snt.Sequential returned by your model fn, and snt.Sequential does not take it as an argument. Something like the following may work if you are using TF1/Sonnet 1 (but probably won't work in Sonnet 2):

def MLP_model():
  # Return a callable that takes the node features plus an `is_training`
  # kwarg, rather than a bare snt.Sequential (which does not accept it).
  def model(inputs, is_training=False):
    if is_training:
      layers = [snt.Linear(256),
                tf.nn.relu,
                tf.keras.layers.Dropout(rate=0.5),
                snt.Linear(256),
                tf.nn.relu,
                tf.keras.layers.Dropout(rate=0.5),
                snt.Linear(40)
                ]
    else:
      layers = [snt.Linear(256),
                tf.nn.relu,
                snt.Linear(256),
                tf.nn.relu,
                snt.Linear(40)
                ]
    return snt.Sequential(layers)(inputs)
  return model
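With that variant the kwargs are forwarded to the returned callable, so the wiring from your snippet should work once the dev version is installed (untested sketch):

gnn = gn.modules.GraphIndependent(node_model_fn=MLP_model)
out = gnn(x_train, node_model_kwargs={'is_training': True})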

But what you should really do is build your own Sonnet module that takes that is_training parameter (e.g. in Sonnet 2):

class MyModule(snt.Module):

  def __init__(self, name=None):
    super(MyModule, self).__init__(name=name)

  @snt.once
  def _initialize(self, x):
    # Create the layers lazily, on the first call.
    self._linear_layers = [snt.Linear(256), snt.Linear(256), snt.Linear(40)]

  def __call__(self, x, is_training):
    self._initialize(x)
    next_input = x
    for layer in self._linear_layers[:-1]:
      next_input = tf.nn.relu(layer(next_input))
      if is_training:
        # Only apply dropout while training.
        next_input = tf.nn.dropout(next_input, rate=0.5)

    return self._linear_layers[-1](next_input)

And then just:

gn.modules.GraphIndependent(node_model_fn=MyModule)
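The is_training flag can then be passed per call via node_model_kwargs (a sketch, assuming the dev version with kwargs support; x_train and x_val stand in for your input graphs):

gnn = gn.modules.GraphIndependent(node_model_fn=MyModule)
train_out = gnn(x_train, node_model_kwargs={'is_training': True})
valid_out = gnn(x_val, node_model_kwargs={'is_training': False})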

Note that for your particular example, since you are essentially just building an MLP, you could simply do:
gn.modules.GraphIndependent(node_model_fn=lambda: snt.nets.MLP(output_sizes=[256, 256, 40], dropout_rate=0.5))

snt.nets.MLP already takes an is_training argument when it is called.
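For example, in Sonnet 2 the flag is passed at call time (a small sketch; the input shape is made up for illustration):

mlp = snt.nets.MLP(output_sizes=[256, 256, 40], dropout_rate=0.5)
train_out = mlp(tf.zeros([4, 16]), is_training=True)   # dropout active
valid_out = mlp(tf.zeros([4, 16]), is_training=False)  # dropout disabled

so it works directly with node_model_kwargs as shown above.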

That is clear! Thank you!