LowinLi/transformers-stream-generator

sample_stream has errors when eos_token_id is a list with more than one element

llltttppp opened this issue

Line 987 of main.py should be changed to:

[image: proposed replacement for line 987 of main.py]
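The original code multiplies `unfinished_sequences` by an array whose elements are not restricted to 0 or 1. Summing the per-EOS comparisons counts how many EOS ids a token differs from, so once there are two or more EOS ids, a token that matches one of them still gets a nonzero value. As a result, a sequence that emits an EOS token is never marked as finished, and the entries for non-EOS tokens grow instead of staying at 1.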

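For example, with

eos_token_id = [1, 2]
next_tokens = torch.LongTensor([1, 2, 3, 4, 5])

`sum(next_tokens != i for i in eos_token_id).long()` evaluates to `[1, 1, 2, 2, 2]`, which is wrong: the mask should be `[0, 0, 1, 1, 1]` (0 for the EOS tokens 1 and 2, 1 for every other token).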
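A minimal sketch of one way to build a proper 0/1 mask, assuming `next_tokens` is a 1-D tensor of shape `(batch,)` as in the example above. The helper name `unfinished_mask` is hypothetical and not part of transformers-stream-generator; it is not necessarily the fix shown in the screenshot. Instead of summing the comparisons, it requires the token to differ from *every* EOS id (a logical AND via `all`), so a match against any EOS id yields 0.

```python
import torch


def unfinished_mask(next_tokens: torch.LongTensor, eos_token_id) -> torch.LongTensor:
    """Hypothetical helper: 1 where the token is not any EOS id, 0 where it matches one."""
    if isinstance(eos_token_id, int):
        eos_token_id = [eos_token_id]
    eos = torch.tensor(eos_token_id, device=next_tokens.device)
    # Shape (num_eos, batch): True where the token differs from that EOS id.
    differs = next_tokens.unsqueeze(0) != eos.unsqueeze(1)
    # AND across EOS ids: "not finished" only if the token differs from all of them.
    return differs.all(dim=0).long()


eos_token_id = [1, 2]
next_tokens = torch.LongTensor([1, 2, 3, 4, 5])

buggy = sum(next_tokens != i for i in eos_token_id).long()
fixed = unfinished_mask(next_tokens, eos_token_id)
print(buggy.tolist())  # [1, 1, 2, 2, 2]
print(fixed.tolist())  # [0, 0, 1, 1, 1]

# Multiplying by a 0/1 mask marks sequences that emitted an EOS token as finished.
unfinished_sequences = torch.ones_like(next_tokens)
unfinished_sequences = unfinished_sequences.mul(fixed)
```

With a single integer `eos_token_id` the helper behaves like a plain `next_tokens != eos_token_id` check, so the multi-EOS case is the only one whose result changes.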