lucidrains/MEGABYTE-pytorch

Minor shape error

anruigu opened this issue · 1 comment

Flagging in case anyone else runs into this:
train.py initially errored for me at line 400 of megabyte.py:
start_tokens, logits = logits[:, 0, :1, :], logits[..., 1:, :]
I reshaped start_tokens so that its shape is (4, 1, 256) instead of (4, 1, 5, 256), and the code now runs fine.
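
For anyone who wants to see the mismatch in isolation, here is a minimal sketch. It uses NumPy as a stand-in, since basic integer and slice indexing behaves the same way as in PyTorch for these expressions. The tensor sizes are assumptions chosen to match the shapes reported above, and `start_tokens_any_depth` is a hypothetical helper for illustration, not necessarily the fix that landed in the repo.

```python
import numpy as np

def start_tokens_fixed_index(logits):
    # The indexing from the reported line: assumes exactly two stage dims
    # between the batch dim and the vocab dim.
    return logits[:, 0, :1, :]

def start_tokens_any_depth(logits):
    # Hypothetical generalization: take index 0 on every stage dim except
    # the last one, then keep the first position of that last stage dim.
    leading_zeros = (0,) * (logits.ndim - 3)
    return logits[(slice(None), *leading_zeros, slice(0, 1))]

# Assumed shapes: batch of 4, vocab of 256, made-up stage sizes.
two_stage = np.zeros((4, 3, 5, 256))       # 2 hierarchies
three_stage = np.zeros((4, 2, 3, 5, 256))  # 3 hierarchies

print(start_tokens_fixed_index(two_stage).shape)    # (4, 1, 256)
print(start_tokens_fixed_index(three_stage).shape)  # (4, 1, 5, 256)
print(start_tokens_any_depth(two_stage).shape)      # (4, 1, 256)
print(start_tokens_any_depth(three_stage).shape)    # (4, 1, 256)
```

With a fixed `[:, 0, :1, :]`, the third hierarchy's dimension survives in `start_tokens`, which would explain the extra 5 in the reported shape.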

@anruigu Oh, I forgot to generalize the start token for more than 2 hierarchies.

Should be fixed now!