Ziya2-13B-Chat怎么流式输出
Dingxiangxiang opened this issue · 1 comment
Dingxiangxiang commented
Ziya2-13B-Chat怎么流式输出
amanoooo commented
from queue import Queue

from transformers.generation.streamers import BaseStreamer
class ChatStreamer(BaseStreamer):
    """Streamer that exposes generated text as a blocking iterator.

    The generation loop pushes token ids in through :meth:`put` and calls
    :meth:`end` when it has finished. A consumer iterates over the instance
    and receives the decoded text one token at a time. Only batch size 1 is
    supported.
    """

    def __init__(self, tokenizer) -> None:
        """Store the tokenizer and create the hand-over queue.

        Args:
            tokenizer: Object whose ``decode(ids, skip_special_tokens=True)``
                converts token ids into text.
        """
        super().__init__()
        self.tokenizer = tokenizer
        self.text_queue = Queue()
        # Sentinel that ``end()`` places on the queue to stop iteration.
        self.stop_signal = None
        # transformers' ``generate()`` calls ``put()`` with the prompt ids
        # before any new token is produced; that first call must not be
        # streamed out as if it were generated text.
        self._awaiting_prompt = True

    def put(self, value):
        """Decode the newest token id in ``value`` and queue its text.

        Args:
            value: Token ids with a ``shape`` attribute; either 1-D or 2-D
                with a leading batch dimension of exactly 1.

        Raises:
            ValueError: If ``value`` carries more than one batch row.
        """
        if len(value.shape) > 1 and value.shape[0] > 1:
            raise ValueError("ChatStreamer only supports batch size 1")
        elif len(value.shape) > 1:
            value = value[0]

        if self._awaiting_prompt:
            # First call of a generation run: these are the prompt ids.
            self._awaiting_prompt = False
            return

        # NOTE(review): decoding a single id at a time may yield partial
        # characters with byte-level tokenizers -- confirm for the model used.
        token = self.tokenizer.decode([value[-1]], skip_special_tokens=True)
        # "<eoa>" is an end marker and is never forwarded to the consumer.
        if token.strip() != "<eoa>":
            self.text_queue.put(token)

    def end(self):
        """Signal that generation is over and re-arm for the next run."""
        self.text_queue.put(self.stop_signal)
        self._awaiting_prompt = True

    def __iter__(self):
        """Return the streamer itself; it is its own iterator."""
        return self

    def __next__(self):
        """Block until the next piece of text is available and return it.

        Raises:
            StopIteration: Once the sentinel queued by ``end()`` is reached.
        """
        value = self.text_queue.get()
        # Identity check: the sentinel is the very object ``end()`` enqueued.
        if value is self.stop_signal:
            raise StopIteration
        return value
# Create the streamer that receives tokens from the generation loop.
streamer = ChatStreamer(tokenizer=self.tokenizer)
# Ids the tokenizer produces for "</s>", passed on as the end-of-sequence ids.
eos_ids = self.tokenizer.encode("</s>")
# NOTE(review): generate() presumably returns only after generation has
# finished; to receive tokens while they are produced, iterate `streamer`
# from another thread -- confirm against the calling code.
self.model.generate(input_ids, eos_token_id=eos_ids, streamer=streamer)