Model Save and Load
yiwc opened this issue · 1 comments
yiwc commented
When we load a model like this, can we convert it into a Sonnet module object? I noticed that the loaded object is a _UserObject.
loaded = tf.saved_model.load("/tmp/example_saved_model")
tomhennigan commented
I don't think there is a public API in TF allowing us to hook saved_model.load
to recreate your module as a Sonnet module, but it should be possible to do this in a wrapper:
class RestoredModule(snt.Module):
    """Sonnet module that wraps an object restored from a SavedModel.

    The wrapped object is kept as-is; calls are delegated to it. The
    variables of its ``'serving_default'`` signature are stored on this
    module as well.
    """

    def __init__(self, obj):
        """Wrap ``obj`` and record the variables of its serving signature.

        Args:
          obj: The restored object. It must be callable and must provide a
            ``'serving_default'`` entry in ``obj.signatures``; a missing
            entry raises ``KeyError``.
        """
        super().__init__()
        self.obj = obj
        # Materialise the signature's variables into a list attribute.
        # NOTE(review): presumably this is what lets the module's
        # `variables` / `trainable_variables` report them -- confirm.
        # Only variables used by the serving signature are collected.
        self._all_variables = list(obj.signatures['serving_default'].variables)

    def __call__(self, *args, **kwargs):
        """Delegate the call to the wrapped object.

        Keyword arguments are forwarded too, so callers are not limited to
        positional invocation.
        """
        return self.obj(*args, **kwargs)
def load_snt(path: str) -> RestoredModule:
    """Load the SavedModel stored at ``path`` and wrap it in a ``RestoredModule``."""
    # NOTE: tf.saved_model.load returns a _UserObject from TF, which is why
    # the result is wrapped before being handed back.
    return RestoredModule(tf.saved_model.load(path))
Should be straightforward to use:
class MyModel(snt.Module):
    """Example module computing ``x @ w + b`` with 10 output features."""

    @snt.once
    def create_vars(self, x):
        """Create ``w`` and ``b``; ``w`` is sized from the last dimension of ``x``."""
        in_features = x.shape[-1]
        self.w = tf.Variable(tf.ones([in_features, 10]), name='w')
        # The bias is deliberately excluded from training.
        self.b = tf.Variable(tf.zeros([10]), trainable=False)

    @tf.function(input_signature=[tf.TensorSpec([1, 1])])
    def __call__(self, x):
        """Apply the affine map to ``x`` (input signature fixed to shape [1, 1])."""
        # Variables must exist before they are read below.
        self.create_vars(x)
        projected = tf.matmul(x, self.w)
        return projected + self.b
# End-to-end example: build a model, save it, reload it through the wrapper,
# and check what kind of object comes back.
m = MyModel()
# Call the model once before saving; __call__ runs create_vars, which
# creates w and b.
m(tf.ones([1, 1]))
tf.saved_model.save(m, '/tmp/model/')
r = load_snt('/tmp/model/') # Use our special loader.
# The restored wrapper is callable with the same input shape.
r(tf.ones([1, 1]))
# The wrapper is a Sonnet module, but it is not the original class.
assert isinstance(r, snt.Module)
assert not isinstance(r, MyModel)
# Both w and b are reported; only w is trainable (b has trainable=False).
assert len(r.variables) == 2
assert len(r.trainable_variables) == 1