google-deepmind/sonnet

Model Save and Load

yiwc opened this issue · 1 comment

yiwc commented

When we load a model like the following, can we convert it back into a Sonnet module object? I noticed that the loaded object is a _UserObject.

import tensorflow as tf

loaded = tf.saved_model.load("/tmp/example_saved_model")
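For context, here is a minimal sketch of what tf.saved_model.load returns, assuming plain TensorFlow only. The Toy module and the temporary directory are hypothetical and exist only for illustration. The Python class itself is not recreated, but tracked variables, traced tf.functions and signatures are attached to the returned object.

```python
import tempfile

import tensorflow as tf


class Toy(tf.Module):  # Hypothetical module, for illustration only.
  def __init__(self):
    super().__init__()
    self.w = tf.Variable(2.0, name="w")

  @tf.function(input_signature=[tf.TensorSpec([], tf.float32)])
  def __call__(self, x):
    return x * self.w


path = tempfile.mkdtemp()
tf.saved_model.save(Toy(), path)
loaded = tf.saved_model.load(path)

# The original Python class is not recreated on load...
print(type(loaded).__name__, isinstance(loaded, Toy))
# ...but tracked variables, traced functions and signatures are restored.
print(loaded(tf.constant(3.0)).numpy())  # 6.0
print(loaded.w.numpy())                  # 2.0
print(list(loaded.signatures.keys()))
```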

I don't think TF has a public API that lets us hook into tf.saved_model.load to recreate your module as a Sonnet module, but it should be possible to do this with a wrapper:

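import sonnet as snt
import tensorflow as tf

class RestoredModule(snt.Module):
  def __init__(self, obj):
    super().__init__()
    self.obj = obj
    # Assumes the SavedModel has a 'serving_default' signature. tf.Module (the
    # base class of snt.Module) finds variables by walking instance attributes,
    # so this list is what .variables / .trainable_variables report.
    self._all_variables = list(obj.signatures['serving_default'].variables)

  def __call__(self, *args):
    # Delegate to the restored, already-traced __call__ function.
    return self.obj(*args)

def load_snt(path: str) -> RestoredModule:
  obj = tf.saved_model.load(path)  # NOTE: This is a _UserObject from TF.
  return RestoredModule(obj)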
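Storing the signature's variables in a list attribute is enough because snt.Module subclasses tf.Module, whose variables and trainable_variables properties find tf.Variable objects by reflecting over the instance's attributes, including lists. Below is a minimal sketch of that mechanism with plain tf.Module. Holder is a hypothetical stand-in for RestoredModule, and the variables are created directly rather than restored from disk.

```python
import tensorflow as tf


class Holder(tf.Module):  # Hypothetical stand-in for RestoredModule.
  def __init__(self, variables):
    super().__init__()
    # tf.Module.variables reflects over instance attributes (and the
    # lists/dicts they contain) to find tf.Variable objects.
    self._all_variables = list(variables)


w = tf.Variable(tf.ones([1, 10]), name="w")
b = tf.Variable(tf.zeros([10]), trainable=False, name="b")
h = Holder([w, b])

print(len(h.variables))            # 2
print(len(h.trainable_variables))  # 1, since b has trainable=False
```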

It should be straightforward to use:

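class MyModel(snt.Module):
  @snt.once
  def create_vars(self, x):
    # Runs only on the first call: creates variables from the input shape.
    self.w = tf.Variable(tf.ones([x.shape[-1], 10]), name='w')
    self.b = tf.Variable(tf.zeros([10]), trainable=False)

  # With an input_signature and no explicit signatures argument,
  # tf.saved_model.save exports this function as 'serving_default'.
  @tf.function(input_signature=[tf.TensorSpec([1, 1])])
  def __call__(self, x):
    self.create_vars(x)
    return tf.matmul(x, self.w) + self.b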

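m = MyModel()
m(tf.ones([1, 1]))  # Trace once so the variables exist before saving.
tf.saved_model.save(m, '/tmp/model/')

r = load_snt('/tmp/model/')  # Use our special loader.
r(tf.ones([1, 1]))
assert isinstance(r, snt.Module)
assert not isinstance(r, MyModel)  # The original Python class is not restored.
assert len(r.variables) == 2  # w and b
assert len(r.trainable_variables) == 1  # Only w; b has trainable=False.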
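Because the wrapper exposes trainable_variables, the restored model can in principle be fine-tuned like any other module, since restored tf.functions support gradients. Below is a minimal sketch of one manual SGD step, assuming plain TensorFlow. To keep it runnable without Sonnet, tf.Module stands in for snt.Module, and the variables are created in __init__ instead of with snt.once. PlainModel, the temporary directory and the learning rate are hypothetical choices for illustration.

```python
import tempfile

import tensorflow as tf


class PlainModel(tf.Module):  # Hypothetical; tf.Module stands in for snt.Module.
  def __init__(self):
    super().__init__()
    self.w = tf.Variable(tf.ones([1, 10]), name="w")
    self.b = tf.Variable(tf.zeros([10]), trainable=False, name="b")

  @tf.function(input_signature=[tf.TensorSpec([1, 1])])
  def __call__(self, x):
    return tf.matmul(x, self.w) + self.b


class RestoredModule(tf.Module):  # Same idea as the Sonnet wrapper above.
  def __init__(self, obj):
    super().__init__()
    self.obj = obj
    self._all_variables = list(obj.signatures["serving_default"].variables)

  def __call__(self, *args):
    return self.obj(*args)


path = tempfile.mkdtemp()
tf.saved_model.save(PlainModel(), path)
r = RestoredModule(tf.saved_model.load(path))

x = tf.ones([1, 1])
learning_rate = 0.1  # Arbitrary value for the sketch.
with tf.GradientTape() as tape:
  loss = tf.reduce_sum(r(x))
grads = tape.gradient(loss, r.trainable_variables)
for g, v in zip(grads, r.trainable_variables):
  v.assign_sub(learning_rate * g)  # Plain SGD update; b is left untouched.

# d(loss)/dw is all ones here, so every entry of w is now about 0.9.
print(r.trainable_variables[0].numpy())
```

With Sonnet installed, the same loop should work on the r returned by load_snt, and an optimizer such as snt.optimizers.SGD could replace the manual assign_sub.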