tensorflow/text

Error when checkpointing a dataset that uses SentencepieceTokenizer


I am running into an error when checkpointing a tf.data.Dataset iterator that uses a SentencepieceTokenizer for tokenization. It fails with:

tensorflow.python.framework.errors_impl.FailedPreconditionError: {{function_node __wrapped__SerializeIterator_device_/job:localhost/replica:0/task:0/device:CPU:0}} SentencepieceTokenizeOp is stateful. [Op:SerializeIterator] name:

As a result I cannot checkpoint datasets that use SentencepieceTokenizer. Is there a fix or workaround that would resolve the issue? I saw

ALLOW_STATEFUL_OP_FOR_DATASET_FUNCTIONS("SentencepieceTokenizeOp");
which makes it look like this is supposed to be possible.

Code to reproduce the issue:

import tensorflow as tf
import tensorflow_text as tf_text

with open("/path/to/tokenizer.model", "rb") as f:
    sp_model = f.read()
tokenizer = tf_text.SentencepieceTokenizer(sp_model)
ds = tf.data.Dataset.from_tensor_slices(dict(data=["ex1", "ex2", "ex3"]))

def _map(ex):
    # The map function captures the tokenizer, and with it the
    # stateful SentencepieceTokenizeOp.
    return dict(data=tokenizer.tokenize(ex["data"]))

ds: tf.data.Dataset = ds.map(_map)
iterator = iter(ds)
ckpt = tf.train.Checkpoint(iterator=iterator)
ckpt.write("/tmp/iterator")  # raises FailedPreconditionError

Colab: https://colab.research.google.com/drive/1kGYP4GJ2YVGBVQaxNzcIm1M3VxO9yRse?authuser=1#scrollTo=nZ5PVQk-BRP7

In case anyone else has this issue, one workaround is to use tf.numpy_function with a regular Python tokenizer while setting stateful=False; a sketch is below.
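
A minimal sketch of that workaround, assuming the sentencepiece Python package is installed and reusing the model path from the repro above (the helper name _py_tokenize is illustrative, not from the original report):

import numpy as np
import sentencepiece as spm
import tensorflow as tf

# Plain Python SentencePiece processor; it lives outside the TF graph.
sp = spm.SentencePieceProcessor(model_file="/path/to/tokenizer.model")

def _py_tokenize(text):
    # text arrives as a bytes scalar; return a variable-length 1-D int32 array.
    return np.asarray(sp.encode(text.decode("utf-8"), out_type=int), dtype=np.int32)

def _map(ex):
    # stateful=False marks the wrapped op as stateless, so SerializeIterator
    # no longer rejects it when the iterator is checkpointed.
    ids = tf.numpy_function(_py_tokenize, [ex["data"]], tf.int32, stateful=False)
    ids.set_shape([None])  # numpy_function drops static shape info; restore the rank.
    return dict(data=ids)

ds = tf.data.Dataset.from_tensor_slices(dict(data=["ex1", "ex2", "ex3"]))
ds = ds.map(_map)
iterator = iter(ds)
ckpt = tf.train.Checkpoint(iterator=iterator)
ckpt.write("/tmp/iterator")  # succeeds

The trade-off is that tf.numpy_function drops back into eager Python for each element, so the map stage runs under the GIL and cannot be serialized into a SavedModel; you give up graph-level tokenization in exchange for a checkpointable iterator.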