`SentencePieceTokenizer` inside a `keras.models.Model` fails to be reconstructed during `keras.saving.load_model()`
Describe the bug
When a SentencePieceTokenizer is integrated into a model using the functional API (i.e., keras.models.Model), it cannot be properly reconstructed from a saved model.keras file.
While untested, I would expect any other custom Keras object that needs load_assets() to run before it can compute an output spec for a given input tensor to exhibit the same behavior.
To Reproduce
https://colab.research.google.com/drive/1XMNYLQrJo25_BkIv8GT02bMJZMjw5RoC?usp=sharing
Refer to cell no. 6
Expected behavior
Proper reconstruction of SentencePieceTokenizer.
Additional context
When keras.saving.load_model() is called on a saved Functional model, the model is reconstructed by running a KerasTensor through the model. Because this happens before the vocabulary is loaded via SentencePieceTokenizer.load_assets(), an error is raised upon encountering the tokenizer in the model.
This logic lives in keras.saving.saving_lib: _load_state(), which is responsible for calling load_assets(), is invoked on L178, after deserialize_keras_object() on L155.
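The ordering described above can be sketched without keras at all. The following is a minimal pure-Python mock (all names here are hypothetical stand-ins, not keras APIs) showing why graph construction fails when the tokenizer's call path needs the vocabulary, and why a spec computed from shape alone sidesteps the problem:

```python
# Hypothetical mock of the load_model() flow: the graph is traced
# (deserialize_keras_object, ~L155) before assets are restored
# (_load_state -> load_assets, ~L178).

class FakeTokenizer:
    def __init__(self):
        self.vocab = None  # populated only later, by load_assets()

    def load_assets(self, vocab):
        self.vocab = vocab

    def __call__(self, text):
        # Tracing the graph through __call__ before load_assets() runs
        # would hit this branch and raise, mirroring the reported bug.
        if self.vocab is None:
            raise ValueError("vocabulary not loaded yet")
        return [self.vocab.index(t) for t in text.split()]

    def compute_output_spec(self, input_shape, sequence_length=8):
        # A spec is derivable from shape alone, without the vocabulary,
        # so graph construction can succeed before load_assets().
        return input_shape + (sequence_length,)

def load_model(tokenizer):
    # Step 1: build the graph (no assets available yet).
    spec = tokenizer.compute_output_spec((None,))
    # Step 2: only now are assets restored.
    tokenizer.load_assets(["hello", "world"])
    return spec
```

With compute_output_spec() defined, step 1 never touches the vocabulary and loading completes; afterwards the tokenizer works normally.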
Would you like to help us fix it?
Defining SentencePieceTokenizer.compute_output_spec() seems to be sufficient to construct the model graph, allowing the loading function to continue on to _load_state(). Cell no. 3 in the colab notebook is a working example.
After a quick skim of the repository, BytePairTokenizer, WordPieceTokenizer, and SentencePieceTokenizer all save and load their vocabularies with save_assets() & load_assets(), and so are affected by this issue. The compute_output_spec() method below should work for each of them.
```python
def compute_output_spec(self, input_spec) -> keras.KerasTensor:
    return keras.KerasTensor(
        input_spec.shape + (self.sequence_length,),
        dtype=self.compute_dtype,
        sparse=not self.sequence_length,
    )
```
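As a sanity check on the shape/sparsity logic, here is a dependency-free stand-in (output_spec_shape is an invented name, not a keras API) that mirrors what the spec would report for the two cases, a fixed sequence_length and sequence_length=None:

```python
# Hypothetical helper mirroring the proposed compute_output_spec logic:
# with a fixed sequence_length the tokenizer emits a dense tensor of
# that length; with sequence_length=None the output is ragged, which
# the spec reports via sparse=True.
def output_spec_shape(input_shape, sequence_length):
    return input_shape + (sequence_length,), not sequence_length
```

For a batch input of shape (None,) and sequence_length=128 this yields a dense (None, 128) spec; with sequence_length=None it yields a sparse spec with a trailing None dimension.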
-> #1523
@briango28 can this be marked as fixed?