keras-team/keras-nlp

`SentencePieceTokenizer` inside a `keras.models.Model` fails to be reconstructed during `keras.saving.load_model()`


Describe the bug

When a SentencePieceTokenizer is integrated into a model built with the functional API (i.e. keras.models.Model), it cannot be properly reconstructed from a saved model.keras file.

While untested, I would expect any other custom Keras object that relies on load_assets() before it can compute an output spec for a given input tensor to exhibit the same behavior.

To Reproduce

https://colab.research.google.com/drive/1XMNYLQrJo25_BkIv8GT02bMJZMjw5RoC?usp=sharing
Refer to cell no. 6
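For reference, a minimal sketch of the failing pattern (not copied from the notebook; the file path and sequence_length are illustrative, and a trained SentencePiece proto is assumed to exist):

import keras
import keras_nlp

# Read a serialized SentencePiece model (hypothetical path).
with open("tokenizer.model", "rb") as f:
    proto = f.read()

tokenizer = keras_nlp.tokenizers.SentencePieceTokenizer(
    proto=proto, sequence_length=16
)

# Functional model: string input -> token ids.
inputs = keras.Input(shape=(), dtype="string")
outputs = tokenizer(inputs)
model = keras.Model(inputs, outputs)

model.save("model.keras")
# Fails before the fix: the functional graph is rebuilt during loading
# before load_assets() has restored the tokenizer's vocabulary.
restored = keras.saving.load_model("model.keras")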

Expected behavior

Proper reconstruction of SentencePieceTokenizer.

Additional context

When keras.saving.load_model() is called on a saved Functional model, the model is reconstructed by running a KerasTensor through the model. Because this happens before the vocabulary is loaded via SentencePieceTokenizer.load_assets(), an error is raised upon encountering the tokenizer in the model.

The above functionality can be found in keras.saving.saving_lib: _load_state(), which is responsible for calling load_assets(), is invoked on L178, after deserialize_keras_object() on L155.
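As a sketch of the failure mode (my reading of it; the snippet below only mimics the tokenizer's internal state at that point, it is not what load_model() literally does): rebuilding the functional graph calls each layer symbolically, and a SentencePieceTokenizer whose proto has not yet been restored cannot tokenize anything.

import keras
import keras_nlp

# A tokenizer with no proto set, roughly the state the deserialized
# tokenizer is in before load_assets() runs during load_model().
tokenizer = keras_nlp.tokenizers.SentencePieceTokenizer(sequence_length=16)

x = keras.Input(shape=(), dtype="string")
y = tokenizer(x)  # raises, since there is no vocabulary to tokenize with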

Would you like to help us fix it?
Defining SentencePieceTokenizer.compute_output_spec() seems to be sufficient to construct the model graph, allowing the loading function to continue to _load_state().

Cell no. 3 in the colab notebook is a working example.

After a quick skim of the repository, it appears that BytePairTokenizer, WordPieceTokenizer, and SentencePieceTokenizer all save and load their vocabularies with save_assets() and load_assets(), and so are affected by this issue.

The aforementioned compute_output_spec() method (copied below) should work for each of them.

def compute_output_spec(self, input_spec) -> keras.KerasTensor:
    # Output shape is the input shape with a token dimension appended;
    # the spec is marked sparse when no fixed sequence_length is set.
    return keras.KerasTensor(
        input_spec.shape + (self.sequence_length,),
        dtype=self.compute_dtype,
        sparse=not self.sequence_length,
    )
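Until a fix lands, one possible workaround (a sketch; the package name and class name are illustrative) is to register a subclass carrying that method and use it in place of the stock tokenizer:

import keras
import keras_nlp

@keras.saving.register_keras_serializable(package="patched")
class PatchedSentencePieceTokenizer(keras_nlp.tokenizers.SentencePieceTokenizer):
    # Same compute_output_spec as proposed above.
    def compute_output_spec(self, input_spec):
        return keras.KerasTensor(
            input_spec.shape + (self.sequence_length,),
            dtype=self.compute_dtype,
            sparse=not self.sequence_length,
        )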

-> #1523

@briango28 can this be marked as fixed?