Minor shape error
anruigu opened this issue · 1 comment
anruigu commented
Flagging in case anyone else ran into this:
train.py errored for me initially on line 400 of megabyte.py:
start_tokens, logits = logits[:, 0, :1, :], logits[..., 1:, :]
I reshaped start_tokens so it has shape (4, 1, 256) instead of (4, 1, 5, 256), and the code runs fine.
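For anyone hitting the same thing, here is a minimal sketch of the shape mismatch. It uses NumPy as a stand-in for PyTorch (basic slicing behaves the same way here), and the hierarchy sizes (2, 3, 5) plus the helper `take_start_tokens` are hypothetical, chosen only to reproduce the shapes above. It is one possible way to generalize the indexing, not necessarily the fix that landed in the repo.

```python
import numpy as np

# Hypothetical shapes: batch=4, three hierarchy levels (2, 3, 5), 256 logits per position.
logits = np.zeros((4, 2, 3, 5, 256))

# The indexing from the issue assumes a single patch axis after the batch axis,
# so with three hierarchy levels the innermost sequence axis (size 5) is left in place.
start_tokens_bad = logits[:, 0, :1, :]
print(start_tokens_bad.shape)  # (4, 1, 5, 256)

def take_start_tokens(x):
    """Pick index 0 on every hierarchy axis except the innermost one,
    where a length-1 slice is kept. Hypothetical helper for illustration."""
    num_outer = x.ndim - 3  # axes between batch and the final (seq, dim) pair
    index = (slice(None),) + (0,) * num_outer + (slice(0, 1), slice(None))
    return x[index]

start_tokens = take_start_tokens(logits)
print(start_tokens.shape)  # (4, 1, 256)
```

With only two hierarchy levels (a 4-D input), the helper reduces to the same `[:, 0, :1, :]` indexing as the original line.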
lucidrains commented
@anruigu oh, I forgot to generalize the start token for more than 2 hierarchies
should be fixed!