Error when checkpointing a dataset that uses SentencepieceTokenizer
Opened this issue · 1 comments
chrisc36 commented
I am running into an error when checkpointing a tf.data.Dataset
iterator that uses a SentencepieceTokenizer
for tokenization. It fails with:
tensorflow.python.framework.errors_impl.FailedPreconditionError: {{function_node __wrapped__SerializeIterator_device_/job:localhost/replica:0/task:0/device:CPU:0}} SentencepieceTokenizeOp is stateful. [Op:SerializeIterator] name:
As a result I cannot checkpoint datasets that use SentencepieceTokenizer
. Is there a fix or work-around that would resolve the issue for me?
Code to reproduce the issue:

import tensorflow as tf
import tensorflow_text as tf_text

# Load a serialized SentencePiece model and build the tokenizer
with open("/path/to/tokenizer.model", "rb") as f:
    sp_model = f.read()
tokenizer = tf_text.SentencepieceTokenizer(sp_model)

ds = tf.data.Dataset.from_tensor_slices(dict(data=["ex1", "ex2", "ex3"]))

def _map(ex):
    return dict(data=tokenizer.tokenize(ex["data"]))

ds: tf.data.Dataset = ds.map(_map)

# Checkpointing the iterator raises FailedPreconditionError
iterator = iter(ds)
ckpt = tf.train.Checkpoint(iterator=iterator)
ckpt.write("/tmp/iterator")
chrisc36 commented
In case anyone else has this issue, one workaround is to use tf.numpy_function with a regular Python tokenizer while setting stateful=False.
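The workaround can be sketched as follows. The whitespace tokenizer below is a hypothetical stand-in kept self-contained for illustration; in practice you would call a Python-side SentencePiece processor (e.g. sentencepiece.SentencePieceProcessor.encode) inside py_tokenize. Marking the op stateful=False is what allows the iterator to be serialized.

```python
import numpy as np
import tensorflow as tf

def py_tokenize(text):
    # Stand-in tokenizer: split on whitespace and emit dummy integer ids.
    # A real implementation would encode with a SentencePiece model here.
    tokens = text.decode("utf-8").split()
    return np.array([len(t) for t in tokens], dtype=np.int64)

def _map(ex):
    # stateful=False lets the resulting op be serialized when the
    # dataset iterator is checkpointed.
    ids = tf.numpy_function(py_tokenize, [ex["data"]], tf.int64,
                            stateful=False)
    ids.set_shape([None])  # numpy_function loses shape information
    return dict(data=ids)

ds = tf.data.Dataset.from_tensor_slices(dict(data=["ex1 a", "ex2 b", "ex3 c"]))
ds = ds.map(_map)

iterator = iter(ds)
next(iterator)  # advance the iterator before checkpointing
ckpt = tf.train.Checkpoint(iterator=iterator)
path = ckpt.write("/tmp/iterator")
```

Note that tf.numpy_function runs the tokenizer in the Python interpreter, so the mapped dataset loses graph-level parallelism for that step; the trade-off is that the iterator state becomes checkpointable.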