IDEA-CCNL/Fengshenbang-LM

How to get streaming output from Ziya2-13B-Chat

Dingxiangxiang opened this issue · 1 comment

How can I get streaming (token-by-token) output from Ziya2-13B-Chat?

from queue import Queue

from transformers.generation.streamers import BaseStreamer

class ChatStreamer(BaseStreamer):
    """Queue-backed streamer: generate() pushes decoded tokens in, the caller iterates them out."""

    def __init__(self, tokenizer) -> None:
        super().__init__()
        self.tokenizer = tokenizer
        self.text_queue = Queue()
        self.stop_signal = None  # sentinel pushed by end() to terminate iteration

    def put(self, value):
        # generate() calls put() with the token ids produced so far.
        if len(value.shape) > 1 and value.shape[0] > 1:
            raise ValueError("ChatStreamer only supports batch size 1")
        elif len(value.shape) > 1:
            value = value[0]
        # Decode only the newest token and enqueue it unless it is the <eoa> stop marker.
        token = self.tokenizer.decode([value[-1]], skip_special_tokens=True)
        if token.strip() != "<eoa>":
            self.text_queue.put(token)

    def end(self):
        # generate() calls end() when generation finishes; unblock the consumer.
        self.text_queue.put(self.stop_signal)

    def __iter__(self):
        return self

    def __next__(self):
        # Block until the next decoded token arrives, stopping at the sentinel.
        value = self.text_queue.get()
        if value == self.stop_signal:
            raise StopIteration()
        return value


# Inside your chat wrapper class, where self.model / self.tokenizer are the
# loaded Ziya2-13B-Chat model and tokenizer:
streamer = ChatStreamer(tokenizer=self.tokenizer)

self.model.generate(
    input_ids,
    eos_token_id=self.tokenizer.encode("</s>"),
    streamer=streamer,
)