keras-team/tf-keras

Custom Keras RNN with constants changes constants shape when saving

Closed this issue · 3 comments

The shape of the constants changes from a rank-2 tensor during inference to a rank-3 tensor during model saving.

System Info:

  • TensorFlow version (Cuda): 2.15
  • OS platform and distribution: Windows WSL
  • Python version: 3.11
from tensorflow.python.keras.layers.recurrent import (
    DropoutRNNCellMixin,
    _config_for_enable_caching_device,
    _caching_device,
)
import tensorflow as tf

class RNNWithConstants(
    DropoutRNNCellMixin, tf.keras.__internal__.layers.BaseRandomLayer
):
    """RNN cell that concatenates external per-batch constants onto each input step.

    Wraps a GRUCell and is driven by `tf.keras.layers.RNN(..., constants=...)`.
    This is the failing reproduction: in eager mode `constants[0]` is rank 2,
    but during SavedModel export the second trace sees it as rank 3, which
    breaks the `tf.concat` in `call` (see the log output below).
    """

    def __init__(
        self,
        units,
        activation,
        recurrent_activation,
        dropout,
        recurrent_dropout,
        **kwargs,
    ):
        super(RNNWithConstants, self).__init__(**kwargs)
        self.units = units
        self.dropout = dropout
        self.recurrent_dropout = recurrent_dropout
        self.recurrent_activation = recurrent_activation
        # Inner GRU cell performs the actual recurrent computation.
        self.cell = tf.keras.layers.GRUCell(
            units=units,
            activation=activation,
            recurrent_activation=recurrent_activation,
            recurrent_dropout=recurrent_dropout,
            dropout=dropout,
        )
        # state_size / output_size are read by the tf.keras.layers.RNN wrapper.
        self.state_size = units
        self.output_size = units

    @tf.function
    def call(self, inputs, states, constants):
        # Debug prints: all three are rank 2 in eager mode; during the
        # save-time retrace `constants[0]` shows up as rank 3 (the bug).
        print(f"inputs {inputs.shape}")
        print(f"states {states[0].shape}")
        print(f"constants {constants[0].shape}")
        inputs = tf.concat([inputs, constants[0]], axis=-1)  # error due to shape change
        h, _ = self.cell(inputs, states)
        return h, h
    
class ConstantsModel(tf.keras.models.Model):
    """Model wrapping the custom cell in a `tf.keras.layers.RNN` layer.

    This variant passes `constants` as a separate positional argument to
    `call`; it is the form that fails to save with the legacy SavedModel
    format.
    """

    def __init__(self, units, **kwargs):
        super().__init__(**kwargs)
        self.units = units 
        self.cell = RNNWithConstants(units, "sigmoid", "sigmoid", 0.1, 0.1)
        self.rnn = tf.keras.layers.RNN(self.cell)

    @tf.function
    def call(self, inputs, constants, training):
        # Forward the constants to the RNN layer; they are handed to the
        # cell's `call` on every timestep.
        return self.rnn(inputs, constants=constants)
    
const = ConstantsModel(10)
print("initializing...")
# Eager call builds the model: inputs (batch=100, time=50, features=10),
# constants (100, 10).
# NOTE(review): `call` declares `training` but it is not passed here —
# presumably Keras fills it in via its call convention; confirm.
_ = const(tf.random.normal(shape=(100, 50, 10)), tf.random.normal(shape=(100, 10)))
print("\nsaving....")
# Saving in the legacy SavedModel format retraces `call`; that retrace is
# where `constants` appears with rank 3 (see the log output).
const.save("./const_model")

Relevant Log Output

initializing...
inputs (100, 10)
states (100, 10)
constants (100, 10)
inputs (100, 10)
states (100, 10)
constants (100, 10)

saving....
inputs (None, 10)
states (None, 10)
constants (None, 10)
inputs (None, 10)
states (None, 10)
constants (None, None, 10)

Keras Issue

@sachinprasadhs,
I was able to reproduce the issue on tensorflow v2.15 with Keras2. Kindly find the gist of it here.

Triage Note: keras-v3 does not support adding constants to the call method, which is different from the tf-keras implementation.

Hi @claCase! Could you try using the new Keras v3 (.keras) format for saving? Alternatively, we suggest you use .save_weights instead. In general, the SavedModel format is not as recommended as the newer methods.

Hi @grasskin

By utilizing the .keras extension and slightly refactoring the model, I was able to save it successfully. Please see the reference code below.
Many thanks! I'm closing the issue.

import tensorflow as tf 
from tensorflow.python.keras.layers.recurrent import (
    DropoutRNNCellMixin,
    _config_for_enable_caching_device,
    _caching_device,
)

# Start from an empty custom-object registry so repeated runs don't clash
# with previously registered classes of the same name.
tf.keras.saving.get_custom_objects().clear()

@tf.keras.saving.register_keras_serializable(package="custom_package")
class RNNWithConstants(
    DropoutRNNCellMixin, tf.keras.__internal__.layers.BaseRandomLayer
):
    """Working variant of the constants cell, registered for `.keras` saving.

    Same GRU-backed cell as the failing repro; with the Keras v3 `.keras`
    format the `constants` tensor keeps rank 2 through save and load (see
    the log output below).
    """

    def __init__(
        self,
        units,
        activation,
        recurrent_activation,
        dropout,
        recurrent_dropout,
        **kwargs,
    ):
        super(RNNWithConstants, self).__init__(**kwargs)
        self.units = units
        self.dropout = dropout
        self.recurrent_dropout = recurrent_dropout
        self.recurrent_activation = recurrent_activation
        # Inner GRU cell performs the recurrent step.
        self.cell = tf.keras.layers.GRUCell(
            units=units,
            activation=activation,
            recurrent_activation=recurrent_activation,
            recurrent_dropout=recurrent_dropout,
            dropout=dropout,
        )
        # state_size / output_size are read by the tf.keras.layers.RNN wrapper.
        self.state_size = units
        self.output_size = units

    def build(self, inputs_shape, *args):
        # Debug print of the shape the layer is built with.
        print(inputs_shape)
        super().build(inputs_shape)

    @tf.function
    def call(self, inputs, states, constants, **kwargs):
        # Debug prints: with the `.keras` format all shapes stay rank 2.
        print(f"inputs {inputs.shape}")
        print(f"states {states[0].shape}")
        print(f"constants {constants[0].shape}")
        inputs = tf.concat([inputs, constants[0]], axis=-1)
        h, _ = self.cell(inputs, states)
        return h, h

@tf.keras.saving.register_keras_serializable(package="custom_package")
class ConstantsModel(tf.keras.models.Model):
    """Refactored model: sequence and constants arrive packed in one tuple.

    Dropping the separate `constants` parameter from `call` (packing it into
    `inputs` instead) is the refactor that lets the model save successfully
    in the `.keras` format.
    """

    def __init__(self, units, **kwargs):
        super().__init__(**kwargs)
        self.units = units 
        self.cell = RNNWithConstants(units, "sigmoid", "sigmoid", 0.1, 0.1)
        self.rnn = tf.keras.layers.RNN(self.cell)

    @tf.function
    def call(self, inputs, training):
        # Unpack the (sequence, constants) tuple passed as a single argument.
        inputs, constants = inputs[0], inputs[1]
        return self.rnn(inputs, constants=constants)
   
const = ConstantsModel(10)
print("initializing...")
# Build the model once: (sequence, constants) packed as a tuple, training=False.
_ = const((tf.random.normal(shape=(100, 50, 10)), tf.random.normal(shape=(100, 10))), False)
# Saving with the `.keras` (Keras v3) format avoids the rank change.
const.save("./const_model.keras")

print("\nloading....")
# Round-trip: reload the saved model and call it again to confirm the
# constants stay rank 2 (see the log output below).
const2 = tf.keras.models.load_model("./const_model.keras")
print("\ninitializing new model....")
_ = const2((tf.random.normal(shape=(100, 50, 10)), tf.random.normal(shape=(100, 10))), False)
initializing...
(100, 10)
inputs (100, 10)
states (100, 10)
constants (100, 10)
inputs (100, 10)
states (100, 10)
constants (100, 10)

loading....
(100, 10)
inputs (100, 10)
states (100, 10)
constants (100, 10)
inputs (100, 10)
states (100, 10)
constants (100, 10)

initializing new model....
inputs (100, 10)
states (100, 10)
constants (100, 10)