RotaryEmbedding XPOS doesn't work with mems
pfeatherstone opened this issue · 5 comments
pfeatherstone commented
Repro:
```python
import torch
from x_transformers import ContinuousTransformerWrapper, Decoder

lm = ContinuousTransformerWrapper(
    dim_in = 2,
    dim_out = 36,
    max_seq_len = 0,
    max_mem_len = 100,
    attn_layers = Decoder(
        dim = 512,
        depth = 4,
        heads = 4,
        rotary_xpos = True,
        attn_flash = True,
        attn_num_mem_kv = 20
    )
)

B, M, D, depth = 1, lm.max_mem_len, lm.attn_layers.dim, lm.attn_layers.depth
x      = torch.randn(B, 1024, 2)
length = torch.randint(100, x.shape[1], size = (x.shape[0],))
mask   = torch.arange(x.shape[1])[None, :] < length[:, None]
mems   = [torch.randn(B, M, D) for _ in range(depth)]

out, new_mems = lm(x, mask = mask, mems = mems, return_mems = True)
```
You get the error:
```
t = (t * freqs.cos() * scale) + (rotate_half(t) * freqs.sin() * scale)
RuntimeError: The size of tensor a (1024) must match the size of tensor b (1124) at non-singleton dimension
```
pfeatherstone commented
Presumably in `apply_rotary_pos_emb()` we need to add `scale = scale[-seq_len:, :]`?
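For illustration, a minimal sketch of what that slice would look like inside a simplified `apply_rotary_pos_emb` (not the library's actual code; `rotate_half` here is a stand-in for the real helper, and the slicing assumes `freqs`/`scale` cover mems plus the current segment):
```python
import torch

def rotate_half(x):
    # split the last dim into two halves and rotate: (x1, x2) -> (-x2, x1)
    x1, x2 = x.chunk(2, dim = -1)
    return torch.cat((-x2, x1), dim = -1)

def apply_rotary_pos_emb(t, freqs, scale = 1.):
    # t:     (..., seq_len, dim)        queries/keys of the current segment only
    # freqs: (total_len, dim)           rotary angles covering mems + current segment
    # scale: (total_len, dim) or float  xpos decay factors over the same positions
    seq_len = t.shape[-2]
    freqs = freqs[-seq_len:, :]                       # keep trailing positions only
    if torch.is_tensor(scale):
        scale = scale[-seq_len:, :]                   # <-- the proposed slice
    return (t * freqs.cos() * scale) + (rotate_half(t) * freqs.sin() * scale)
```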
pfeatherstone commented
As an aside, why is all of the `RotaryEmbedding` code decorated with `@torch.cuda.amp.autocast(enabled = False)`? You can remove it with just a couple of tweaks and then it supports `torch.bfloat16`.
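A minimal sketch of what such a tweak could look like, assuming the angles are simply computed in float32 internally and cast back by the caller (the helper name is illustrative, not the library's API):
```python
import torch

def rotary_angles(positions, dim, theta = 10000.0):
    # do the angle computation in float32 for stability, independent of any
    # surrounding autocast context; the caller can cast the result afterwards
    inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype = torch.float32) / dim))
    angles = torch.einsum('i, j -> i j', positions.float(), inv_freq)
    return torch.cat((angles, angles), dim = -1)

# usable under bfloat16 without an autocast(enabled = False) guard
angles = rotary_angles(torch.arange(1024), dim = 64).to(torch.bfloat16)
```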
pfeatherstone commented
Also, I think the `scale` calculation is incorrect when using mems, since the positions are off. You have to use the same trick of starting from a negative position.
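Roughly, the idea would be to derive both the rotary angles and the xpos scale from positions that put the memories at negative indices. A sketch only, assuming an xpos-style scale (the helper name and default `scale_base` are assumptions, not the library's code):
```python
import torch

def xpos_scale(positions, dim, scale_base = 512):
    # xpos-style decay factors; 'positions' may be negative for cached memories
    base = (torch.arange(0, dim, 2, dtype = torch.float32) + 0.4 * dim) / (1.4 * dim)
    power = positions.float() / scale_base
    scale = base[None, :] ** power[:, None]           # (len(positions), dim // 2)
    return torch.cat((scale, scale), dim = -1)        # (len(positions), dim)

# memories take positions [-mem_len, -1] and the current segment starts at 0,
# so both segments share one consistent coordinate system for angles and scale
mem_len, seq_len, dim = 100, 1024, 64
positions = torch.arange(-mem_len, seq_len)
scale = xpos_scale(positions, dim)                    # covers mems + current tokens
```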
pfeatherstone commented
I believe this fixes it.
lucidrains commented
great job!