Custom Keras RNN with constants changes constants shape when saving
Closed this issue · 3 comments
Constants shape changes from rank 2 during inference to rank 3 tensor during model saving
System Info:
- TensorFlow version (Cuda): 2.15
- OS platform and distribution: Windows WSL
- Python version: 3.11
from tensorflow.python.keras.layers.recurrent import (
DropoutRNNCellMixin,
_config_for_enable_caching_device,
_caching_device,
)
import tensorflow as tf
class RNNWithConstants(
    DropoutRNNCellMixin, tf.keras.__internal__.layers.BaseRandomLayer
):
    """RNN cell that concatenates per-batch constants onto each step's input
    before delegating to an inner GRUCell."""

    def __init__(
        self,
        units,
        activation,
        recurrent_activation,
        dropout,
        recurrent_dropout,
        **kwargs,
    ):
        super().__init__(**kwargs)
        # state_size / output_size are required by tf.keras.layers.RNN to
        # build state placeholders for this custom cell.
        self.state_size = units
        self.output_size = units
        self.units = units
        self.recurrent_activation = recurrent_activation
        self.dropout = dropout
        self.recurrent_dropout = recurrent_dropout
        # Inner cell that consumes the [inputs, constants] concatenation.
        self.cell = tf.keras.layers.GRUCell(
            units=units,
            activation=activation,
            recurrent_activation=recurrent_activation,
            dropout=dropout,
            recurrent_dropout=recurrent_dropout,
        )

    @tf.function
    def call(self, inputs, states, constants):
        """One recurrence step: print traced shapes, merge constants into the
        step input, and run the GRU cell. Returns (output, new_state)."""
        print(f"inputs {inputs.shape}")
        print(f"states {states[0].shape}")
        print(f"constants {constants[0].shape}")
        # Shape mismatch reported in the issue surfaces on this concat.
        merged = tf.concat([inputs, constants[0]], axis=-1)
        h, _ = self.cell(merged, states)
        return h, h
class ConstantsModel(tf.keras.models.Model):
    """Model that wraps RNNWithConstants in a tf.keras.layers.RNN and feeds
    the constants tensor through to every recurrence step."""

    def __init__(self, units, **kwargs):
        super().__init__(**kwargs)
        self.units = units
        self.cell = RNNWithConstants(units, "sigmoid", "sigmoid", 0.1, 0.1)
        self.rnn = tf.keras.layers.RNN(self.cell)

    @tf.function
    def call(self, inputs, constants, training):
        # Constants are handed to the RNN layer, which forwards them
        # unchanged to each cell invocation.
        return self.rnn(inputs, constants=constants)
# Build the model, trace it once so variables get created, then attempt the
# SavedModel export that triggers the reported shape change.
model = ConstantsModel(10)
print("initializing...")
_ = model(tf.random.normal(shape=(100, 50, 10)), tf.random.normal(shape=(100, 10)))
print("\nsaving....")
model.save("./const_model")
Relevant Log Output
initializing...
inputs (100, 10)
states (100, 10)
constants (100, 10)
inputs (100, 10)
states (100, 10)
constants (100, 10)
saving....
inputs (None, 10)
states (None, 10)
constants (None, 10)
inputs (None, 10)
states (None, 10)
constants (None, None, 10)
@sachinprasadhs,
I was able to reproduce the issue on tensorflow v2.15 with Keras2. Kindly find the gist of it here.
Triage Note: keras-v3 does not support adding constants to the call method, which is different from the tf-keras implementation.
Hi @claCase! Could you try using the new Keras v3 (.keras) format for saving? Alternatively, we suggest you use .save_weights
instead. In general saved model is not as recommended as newer methods.
Hi @grasskin
By utilizing the .keras extension and slightly refactoring the model I was able to successfully save it. Please see the reference code below.
Many thanks! I'm closing the issue.
import tensorflow as tf
from tensorflow.python.keras.layers.recurrent import (
DropoutRNNCellMixin,
_config_for_enable_caching_device,
_caching_device,
)
# Clear the custom-object registry up front so re-running this script does
# not clash with previously registered classes of the same name.
tf.keras.saving.get_custom_objects().clear()
@tf.keras.saving.register_keras_serializable(package="custom_package")
class RNNWithConstants(
    DropoutRNNCellMixin, tf.keras.__internal__.layers.BaseRandomLayer
):
    """Serializable RNN cell: each step runs a GRUCell on the concatenation
    of the step input and the per-batch constants."""

    def __init__(
        self,
        units,
        activation,
        recurrent_activation,
        dropout,
        recurrent_dropout,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.units = units
        self.recurrent_activation = recurrent_activation
        self.dropout = dropout
        self.recurrent_dropout = recurrent_dropout
        # Inner cell consuming the merged [inputs, constants] tensor.
        self.cell = tf.keras.layers.GRUCell(
            units=units,
            activation=activation,
            recurrent_activation=recurrent_activation,
            dropout=dropout,
            recurrent_dropout=recurrent_dropout,
        )
        # Required by tf.keras.layers.RNN for state bookkeeping.
        self.state_size = units
        self.output_size = units

    def build(self, inputs_shape, *args):
        # Log the shape Keras passes in during build / deserialization.
        print(inputs_shape)
        super().build(inputs_shape)

    @tf.function
    def call(self, inputs, states, constants, **kwargs):
        """One recurrence step; extra Keras kwargs (e.g. training) are
        accepted via **kwargs so saving/restoring does not break the call."""
        print(f"inputs {inputs.shape}")
        print(f"states {states[0].shape}")
        print(f"constants {constants[0].shape}")
        step_input = tf.concat([inputs, constants[0]], axis=-1)
        output, _ = self.cell(step_input, states)
        return output, output
@tf.keras.saving.register_keras_serializable(package="custom_package")
class ConstantsModel(tf.keras.models.Model):
    """Serializable model whose call takes a single (sequence, constants)
    tuple, making the signature compatible with the .keras save format."""

    def __init__(self, units, **kwargs):
        super().__init__(**kwargs)
        self.units = units
        self.cell = RNNWithConstants(units, "sigmoid", "sigmoid", 0.1, 0.1)
        self.rnn = tf.keras.layers.RNN(self.cell)

    @tf.function
    def call(self, inputs, training):
        # Unpack the packed argument: position 0 is the time-major sequence,
        # position 1 is the constants tensor forwarded to every step.
        sequence, constants = inputs[0], inputs[1]
        return self.rnn(sequence, constants=constants)
# Build and trace the model, save it in the Keras v3 (.keras) format, then
# reload and run the restored copy to confirm a round trip succeeds.
model = ConstantsModel(10)
print("initializing...")
_ = model((tf.random.normal(shape=(100, 50, 10)), tf.random.normal(shape=(100, 10))), False)
model.save("./const_model.keras")
print("\nloading....")
restored = tf.keras.models.load_model("./const_model.keras")
print("\ninitializing new model....")
_ = restored((tf.random.normal(shape=(100, 50, 10)), tf.random.normal(shape=(100, 10))), False)
initializing...
(100, 10)
inputs (100, 10)
states (100, 10)
constants (100, 10)
inputs (100, 10)
states (100, 10)
constants (100, 10)
loading....
(100, 10)
inputs (100, 10)
states (100, 10)
constants (100, 10)
inputs (100, 10)
states (100, 10)
constants (100, 10)
initializing new model....
inputs (100, 10)
states (100, 10)
constants (100, 10)