tf_text.FastSentencepieceTokenizer causes ValueError: cannot create std::vector larger than max_size()
hanneshapke opened this issue · 2 comments
I tried to switch from `tf_text.SentencepieceTokenizer` to `tf_text.FastSentencepieceTokenizer`, but encountered the issue below.
```
ValueError                                Traceback (most recent call last)
<ipython-input-150-df5232e30d34> in <module>
      7 # sp_model = open("new.model", "r").read()
      8
----> 9 tokenizer = tf_text.FastSentencepieceTokenizer(sp_model)

/usr/local/lib/python3.8/dist-packages/tensorflow_text/python/ops/fast_sentencepiece_tokenizer.py in __init__(self, model, reverse, add_bos, add_eos)
     48
     49   def __init__(self, model, reverse=False, add_bos=False, add_eos=False):
---> 50     converted_model = pywrap_model_converter.convert_sentencepiece_model(model)
     51     converted_model_detokenizer = pywrap_model_converter.convert_sentencepiece_model_for_decoder(
     52         model)

ValueError: cannot create std::vector larger than max_size()
```
I can load the sentencepiece model proto with `tf_text.SentencepieceTokenizer`, but switching to `tf_text.FastSentencepieceTokenizer` causes the issue. Is there a limitation in `tf_text.FastSentencepieceTokenizer`?
I have tried to trace the C++ code, but didn't find any model limitation.
Here is a Colab Notebook to reproduce the error.
Loading the sentencepiece model works with `SentencePiece` and `SentencepieceTokenizer`, but not with `FastSentencepieceTokenizer`.
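Roughly what the notebook does, as a minimal sketch (the `model_proto=` keyword of `SentencePieceProcessor` and reading `new.model` in binary mode are my assumptions here):

```python
import sentencepiece as sp
import tensorflow_text as tf_text

# Serialized SentencePiece model proto, read as bytes.
sp_model = open("new.model", "rb").read()

# Both of these accept the proto without complaint ...
sp_processor = sp.SentencePieceProcessor(model_proto=sp_model)
slow_tokenizer = tf_text.SentencepieceTokenizer(model=sp_model)
print(slow_tokenizer.tokenize(["hello world"]))

# ... but this raises:
# ValueError: cannot create std::vector larger than max_size()
fast_tokenizer = tf_text.FastSentencepieceTokenizer(sp_model)
```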
TF Version: 2.11.0
TF Text Version: 2.11.0
SentencePiece Version: 0.1.97
Hi, @hanneshapke. I found the reason why this happens.
The RoBERTa base tokenizer doesn't need to normalize input sentences (see link 1), so I passed an empty `precompiled_charsmap` to the `normalizer_spec` of the model protobuf message. But `FastSentencepieceTokenizer` tries to parse the `precompiled_charsmap` assuming it isn't empty (see link 2), so the model isn't convertible.
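You can confirm this by parsing the model proto directly. A quick sketch, assuming the `sentencepiece_model_pb2` bindings that ship with the sentencepiece pip package and the `new.model` file from the traceback:

```python
from sentencepiece import sentencepiece_model_pb2 as sp_pb2

model_proto = sp_pb2.ModelProto()
model_proto.ParseFromString(open("new.model", "rb").read())

# For a model built without normalization, the precompiled charsmap is empty
# bytes, which is the case the tf_text converter doesn't handle.
print(model_proto.normalizer_spec.name)
print(len(model_proto.normalizer_spec.precompiled_charsmap))  # 0 -> conversion fails
```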
I tried converting the model with an arbitrary charsmap from https://github.com/google/sentencepiece/blob/master/src/normalization_rule.h, and the conversion succeeded. But the results were obviously different 😅 (maybe because of the Unicode normalization).
As you can see in link (3) below, when the identity normalizer is used, I think the charsmap is set to empty bytes in the sentencepiece codebase.
- tokenizer config of Roberta base: https://huggingface.co/roberta-base/blob/main/tokenizer.json
... "normalizer": null, ...
- Related TensorFlow Text code: tensorflow_text/core/kernels/sentencepiece/model_converter.cc, lines 44 to 61 (commit 6877342)
- sentencepiece code for building charsmap: https://github.com/google/sentencepiece/blob/31656da0c9cccfc47d4f0e69fc32d55faac3e1e9/src/builder.cc#L274-L294
- Sentencepiece model proto (precompiled_charsmap): https://github.com/google/sentencepiece/blob/31656da0c9cccfc47d4f0e69fc32d55faac3e1e9/src/sentencepiece_model.proto#L238-L241
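For completeness, a sketch of how a model ends up with an empty charsmap in the first place, per link (3): training with the identity normalization rule (the corpus path and vocab size below are placeholders):

```python
import sentencepiece as spm

# With normalization_rule_name="identity" the trainer performs no input
# normalization, so normalizer_spec.precompiled_charsmap stays empty --
# which is exactly what FastSentencepieceTokenizer's converter chokes on.
spm.SentencePieceTrainer.train(
    input="corpus.txt",                  # placeholder corpus
    model_prefix="new",                  # produces new.model / new.vocab
    vocab_size=8000,                     # placeholder vocab size
    normalization_rule_name="identity",
)
```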