huggingface/tokenizers

Incremental Detokenization

Closed this issue · 12 comments

Hello, thank you for building such a great foundational library.

I work on the vllm-project, and we have some nasty, slow code related to the challenges of incremental detokenization for streaming use cases. It is needed to work around the cleanup logic in decode, where the tokenizer decides whether or not to add a space depending on the surrounding ids. Relevant code:

We are trying to optimize this code as it can be expensive for high batch size serving. Before we do this, I was wondering if tokenizers has any plans to handle incremental detokenization internally?

Hey! 🤗
Yes, if this is key for you, for sure!

Could you help me identify exactly what that consists of?

Would a raw decoding, which directly returns the correct incremental string, work?

I can work on this as soon as I have input / output in a small snippet of what you want!

Thanks @ArthurZucker ! I will send some more details later today

Hey @ArthurZucker, I'm looking at the following case. We have an inference server generating tokens one at a time, and we want to convert each token into its corresponding string before streaming it back to the user:

  • We run into the following problem, where the tokenizer removes spaces when detokenizing one at a time:
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")
tokens = tokenizer("Hello my name is Robert and I work on vLLM.").input_ids

PROMPT_LEN = 5
detokenized_prompt = tokenizer.decode(tokens[:PROMPT_LEN], skip_special_tokens=True)

one_by_one = detokenized_prompt
for i in range(PROMPT_LEN, len(tokens)):
    one_by_one += tokenizer.decode(tokens[i], skip_special_tokens=True)

print(f"{one_by_one=}")
# one_by_one='Hello my name isRobertandIworkonvLLM.'
  • To solve this issue, we do something that looks like the following:
TRAIL_LEN = 4
with_trail = detokenized_prompt
for i in range(PROMPT_LEN, len(tokens)):
    previous = tokenizer.decode(tokens[i-TRAIL_LEN:i])
    including = tokenizer.decode(tokens[i-TRAIL_LEN:i+1])
    new = including[len(previous):]
    with_trail += new

print(f"{with_trail=}")
# with_trail='Hello my name is Robert and I work on vLLM.'

This code is both slow (we cannot do any batching, and we have to run these loops in Python) and bug-prone. This is also a problem in TGI (I believe the vLLM implementation comes from TGI).

I was wondering if tokenizers had any plan to expose some type of stateful API for online decoding like this. We would definitely adopt this feature if made available!

I do recognize this is a bit complicated, as the tokenizer API would need to become stateful to recreate the functionality we have. So if this is not in scope, we would understand.

Another option if stateful tokenizers are not possible would be to expose an API that accepts the previous 5 tokens and returns the string corresponding to the last token with logic equivalent to the above. This would also simplify things for us, especially if this call could have batching in Rust to speed it up
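
To make it concrete, here is a rough sketch of the kind of helper we have in mind; the name `decode_last_token` and its signature are purely hypothetical and not an existing tokenizers API, and ideally something like this would be batched in Rust:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")

def decode_last_token(token_ids, trail_len=4):
    """Hypothetical helper: return the text contributed by token_ids[-1],
    by decoding a small trailing window with and without the newest token."""
    window = token_ids[-(trail_len + 1):]
    previous = tokenizer.decode(window[:-1])
    including = tokenizer.decode(window)
    return including[len(previous):]

tokens = tokenizer("Hello my name is Robert and I work on vLLM.").input_ids
# text contributed by the most recent token (exact value depends on the vocab)
print(repr(decode_last_token(tokens)))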

I think @njhill implemented something similar in IBM's fork of TGI -> https://github.com/IBM/text-generation-inference/blob/main/router/src/decoder.rs

Another option if stateful tokenizers are not possible would be to expose an API that accepts the previous 5 tokens and returns the string corresponding to the last token with logic equivalent to the above. This would also simplify things for us, especially if this call could have batching in Rust to speed it up

this sounds easy TBH!

Regarding the statefulness, I have been trying to make the python API more and more "pythonic" so you can already access things like tokenizer._tokenizer.pre_tokenizer and tokenizer._tokenizer.decoder.add_prefix_space:

>>> from transformers import AutoTokenizer
>>> tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")

>>> tokenizer._tokenizer.decoder
Sequence(decoders=[Replace(pattern=String("▁"), content=" "), ByteFallback(), Fuse(), Strip(content=" ", start=1, stop=0)])

In most, if not all, tokenizers what you want to do is this:

from tokenizers.decoders import Replace
tokenizer._tokenizer.decoder = Replace(pattern="▁", content=" ")

You could also include the byte fallback and fuse steps, but TL;DR: remove the forced strip.

This gives:

one_by_one=' Hello my name is Robert and I work on vLLM.'

With this you have just an extra space at the beginning.
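
Putting that together, here is a minimal sketch of the decoder swap described above, keeping the ByteFallback and Fuse steps and only dropping the Strip; it assumes the Mistral decoder pipeline printed earlier, so adapt it for other tokenizers:

from transformers import AutoTokenizer
from tokenizers.decoders import Sequence, Replace, ByteFallback, Fuse

tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")

# Same decoder pipeline as before, minus the Strip step that eats the
# leading space of each individually decoded token.
tokenizer._tokenizer.decoder = Sequence([
    Replace(pattern="▁", content=" "),
    ByteFallback(),
    Fuse(),
])

tokens = tokenizer("Hello my name is Robert and I work on vLLM.").input_ids
one_by_one = "".join(
    tokenizer.decode(t, skip_special_tokens=True) for t in tokens
)
print(f"{one_by_one=}")
# one_by_one=' Hello my name is Robert and I work on vLLM.'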

I can also work on what you asked for (supporting passing the previous tokens), but the issue is that we would probably have to update all the decoders, which sounds less simple.

https://github.com/IBM/text-generation-inference/blob/main/router/src/decoder.rs

this is nice, but yeah for sure we want to add this to tokenizers instead of everyone having a hard time!

Thanks guys! Is there a demo of how to use the API?

@ArthurZucker

Hi, thanks for your work. I want to know: is there a demo where the DecodeStream can start from the generated tokens, like in the code below?

tokens = tokenizer("Hello my name is Robert and I work on vLLM.").input_ids
# 'Hello my name is Robert' is the input prompt
# ' and I work on vLLM.' is the generated tokens

# how ?
stream.step(tokenizer, tokens[5]) == " and"

Or should I do the step operations myself for the input prompt, like below? I think this approach is not very efficient.

# input
stream.step(tokenizer, tokens[0])
stream.step(tokenizer, tokens[1])
stream.step(tokenizer, tokens[2])
stream.step(tokenizer, tokens[3])
stream.step(tokenizer, tokens[4])

# generate
out = stream.step(tokenizer, tokens[5]) # which output ' and'
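
For reference, a compact version of that priming loop, assuming the DecodeStream API from recent tokenizers releases; note that `step` expects the underlying `tokenizers.Tokenizer` (here reached via `tokenizer._tokenizer` of a fast transformers tokenizer), and the exact constructor arguments may differ between versions:

from transformers import AutoTokenizer
from tokenizers.decoders import DecodeStream

tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")
tokens = tokenizer("Hello my name is Robert and I work on vLLM.").input_ids

PROMPT_LEN = 5
stream = DecodeStream(skip_special_tokens=True)

# prime the stream with the prompt tokens (their returned text is discarded)
for tid in tokens[:PROMPT_LEN]:
    stream.step(tokenizer._tokenizer, tid)

# from here on, each step returns the text contributed by that token
for tid in tokens[PROMPT_LEN:]:
    print(repr(stream.step(tokenizer._tokenizer, tid)))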