syncdoth/RetNet

ValueError: not enough values to unpack (expected 2, got 1)

pathoncyp opened this issue · 3 comments

Hey,

Thank you for this great work!
I get the following error when using the model to generate text:

  File "E:\RetNet-main-huggingface\retnet\modeling_retnet.py", line 368, in forward
    batch_size, seq_length = input_ids.shape
  ValueError: not enough values to unpack (expected 2, got 1)
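The traceback means `input_ids` reached `forward` as a 1-D tensor, while the unpacking expects `(batch_size, seq_length)`. A minimal reproduction, assuming the next token is picked with something like `torch.argmax(logits, dim=-1)`, which drops the vocabulary dimension and leaves no sequence dimension. The `logits` values and the `unpack_shape` helper are made up for illustration; the helper only mirrors line 368:

```python
import torch

# Hypothetical logits for the last position: batch_size=1, vocab_size=5
logits = torch.tensor([[0.1, 2.0, 0.3, -1.0, 0.5]])

# argmax over the vocabulary removes that dimension: shape (1, 5) -> (1,)
token = torch.argmax(logits, dim=-1)


def unpack_shape(input_ids: torch.Tensor) -> tuple:
    # Same unpacking as the failing line in modeling_retnet.py
    batch_size, seq_length = input_ids.shape
    return batch_size, seq_length


try:
    unpack_shape(token)  # token is 1-D, so this raises
except ValueError as err:
    print(err)  # not enough values to unpack (expected 2, got 1)

print(unpack_shape(token.unsqueeze(0)))  # (1, 1) once a second dimension is added
```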

Dune-Z commented

@pathoncyp Hi,
A quick fix: add a batch dimension to the generated token:

  generated.append(token)
  if early_stopping and (token == eos_token_id).all():
      break
  token = token.unsqueeze(0)  # add this line: give the next input a 2-D shape again
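A sketch of the loop with the fix in place, assuming (as the fix implies) that `token` has shape `(batch_size,)` and is passed back as the next `input_ids`. `toy_forward` and `greedy_generate` are hypothetical stand-ins, not the repo's API. One PyTorch detail worth noting: for a tensor of shape `(batch_size,)`, `unsqueeze(0)` gives `(1, batch_size)` while `unsqueeze(-1)` gives `(batch_size, 1)`. The two agree for a single prompt, but only `unsqueeze(-1)` keeps the batch and sequence dimensions in the right order for batched generation, so the sketch uses it:

```python
from typing import Optional

import torch
import torch.nn.functional as F


def toy_forward(input_ids: torch.Tensor, vocab_size: int = 10) -> torch.Tensor:
    """Hypothetical model stand-in; like RetNet's forward it requires 2-D input."""
    batch_size, seq_length = input_ids.shape  # raises on 1-D input, as in the issue
    # Fake deterministic logits that favour (last token id + 1) % vocab_size
    next_ids = (input_ids[:, -1] + 1) % vocab_size
    return F.one_hot(next_ids, vocab_size).float()  # (batch_size, vocab_size)


def greedy_generate(
    input_ids: torch.Tensor,
    max_new_tokens: int,
    eos_token_id: Optional[int] = None,
    early_stopping: bool = True,
) -> torch.Tensor:
    generated = []
    token = input_ids  # first step sees the whole prompt
    for _ in range(max_new_tokens):
        logits = toy_forward(token)
        token = torch.argmax(logits, dim=-1)  # shape (batch_size,)
        generated.append(token)
        if early_stopping and eos_token_id is not None and (token == eos_token_id).all():
            break
        # (batch_size,) -> (batch_size, 1); unsqueeze(0) would give (1, batch_size)
        token = token.unsqueeze(-1)
    return torch.stack(generated, dim=1)  # (batch_size, num_generated)
```

For example, `greedy_generate(torch.tensor([[1, 2], [5, 6]]), max_new_tokens=3)` runs without the unpacking error and returns `[[3, 4, 5], [7, 8, 9]]` for the toy model.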


Thank you very much