lucidrains/x-transformers

RotaryEmbedding XPOS doesn't work with mems

pfeatherstone opened this issue · 5 comments

Repro:

```python
import torch
from x_transformers import ContinuousTransformerWrapper, Decoder

lm = ContinuousTransformerWrapper(
    dim_in              = 2,
    dim_out             = 36,
    max_seq_len         = 0,
    max_mem_len         = 100,
    attn_layers = Decoder(
        dim             = 512,
        depth           = 4,
        heads           = 4,
        rotary_xpos     = True,
        attn_flash      = True,
        attn_num_mem_kv = 20
    )
)

B, M, D, depth = 1, lm.max_mem_len, lm.attn_layers.dim, lm.attn_layers.depth

x       = torch.randn(B, 1024, 2)                              # (batch, seq, dim_in)
length  = torch.randint(100, x.shape[1], size=(x.shape[0],))   # random valid lengths
mask    = torch.arange(x.shape[1])[None,:] < length[:,None]    # padding mask
mems    = [torch.randn(B, M, D) for _ in range(depth)]         # one memory per layer

out, new_mems = lm(x, mask=mask, mems=mems, return_mems=True)
```

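This fails with:

```
t = (t * freqs.cos() * scale) + (rotate_half(t) * freqs.sin() * scale)
RuntimeError: The size of tensor a (1024) must match the size of tensor b (1124) at non-singleton dimension
```

Note that 1124 = 1024 + 100 (max_mem_len), i.e. the length of the memories plus the current chunk.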

Should apply_rotary_pos_emb() also slice scale down to the last seq_len positions? Presumably by adding:

```python
scale = scale[-seq_len:, :]
```
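A minimal sketch of what that slice does, written as a standalone function rather than the library's actual apply_rotary_pos_emb (partial rotary dims are omitted, and the XPOS constants 0.4, 1.4 and the 512 scale base are assumptions for illustration). With 100 memory positions and 1024 current positions, freqs and scale both have 1124 rows. Queries cover only the current chunk, so both tensors must be cut to the last 1024 rows, while keys use all 1124.

```python
import torch


def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # Split the last dim into halves (x1, x2) and return (-x2, x1)
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(t, freqs, scale=1.0):
    # t: (..., seq_len, dim). freqs (and an XPOS scale tensor, if given)
    # have one row per position of [mems; current tokens].
    seq_len = t.shape[-2]
    freqs = freqs[-seq_len:, :]
    if isinstance(scale, torch.Tensor):
        # The proposed fix: keep only the rows for the last seq_len positions
        scale = scale[-seq_len:, :]
    return (t * freqs.cos() * scale) + (rotate_half(t) * freqs.sin() * scale)


# Shapes from the repro: 100 memory positions + 1024 current positions
mem_len, seq_len, dim = 100, 1024, 64
pos = torch.arange(mem_len + seq_len, dtype=torch.float32)
inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
freqs = torch.einsum("i,j->ij", pos, inv_freq)
freqs = torch.cat((freqs, freqs), dim=-1)                      # (1124, 64)

# XPOS-style per-dimension decay (constants assumed for illustration)
base = (torch.arange(0, dim, 2, dtype=torch.float32) + 0.4 * dim) / (1.4 * dim)
scale = base ** (pos / 512.0)[:, None]
scale = torch.cat((scale, scale), dim=-1)                      # (1124, 64)

q = torch.randn(1, 4, seq_len, dim)              # queries: current chunk only
k = torch.randn(1, 4, mem_len + seq_len, dim)    # keys: mems + current chunk
q_rot = apply_rotary_pos_emb(q, freqs, scale)          # uses last 1024 rows
k_rot = apply_rotary_pos_emb(k, freqs, scale ** -1.0)  # uses all 1124 rows
```

Without the slice, the query path multiplies a (…, 1024, 64) tensor by a (1124, 64) scale, which is exactly the broadcast error from the traceback.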

As an aside, why is all of the RotaryEmbedding code decorated with @torch.cuda.amp.autocast(enabled = False)?
With a couple of small tweaks the decorator can be removed, and the code then works with torch.bfloat16 as well.
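The exact tweaks aren't spelled out here. As one possible approach (an assumption, not necessarily what was meant), the precision that the decorator protects can be kept by upcasting explicitly inside the function and casting the result back to the input dtype. The helper name apply_rotary_pos_emb_fp32 is hypothetical.

```python
import torch


def apply_rotary_pos_emb_fp32(t, freqs, scale=1.0):
    # Hypothetical alternative to @torch.cuda.amp.autocast(enabled = False):
    # upcast explicitly, do the rotation in float32, cast back to t's dtype.
    orig_dtype = t.dtype
    seq_len = t.shape[-2]
    t32 = t.float()
    freqs = freqs[-seq_len:, :].float()
    if isinstance(scale, torch.Tensor):
        scale = scale[-seq_len:, :].float()
    x1, x2 = t32.chunk(2, dim=-1)
    rotated = torch.cat((-x2, x1), dim=-1)  # rotate_half
    out = (t32 * freqs.cos() * scale) + (rotated * freqs.sin() * scale)
    return out.to(orig_dtype)


# bfloat16 activations, float32 frequencies for 4 memory + 16 current positions
dim = 32
pos = torch.arange(20, dtype=torch.float32)
inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
freqs = torch.einsum("i,j->ij", pos, inv_freq)
freqs = torch.cat((freqs, freqs), dim=-1)  # (20, 32)

t = torch.randn(2, 16, dim, dtype=torch.bfloat16)
out = apply_rotary_pos_emb_fp32(t, freqs)  # stays bfloat16 on the way out
```

The cos/sin of large position-times-frequency products are where low precision hurts most. Doing that part in float32 and returning the caller's dtype keeps bfloat16 models working without switching autocast off.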

Also, I think the XPOS scale calculation is incorrect when mems are used, because the positions are off.
It needs the same trick of starting the positions from a negative offset.
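A sketch of one reading of the negative-position trick, under stated assumptions. Positions for [mems; current tokens] run from -mem_len to seq_len - 1, so the first current token always sits at 0. The scale is computed as base ** (pos / scale_base) with no centering offset; the library may center positions differently, and the function name xpos_scale is hypothetical. Because queries are scaled by s^pos and keys by s^-pos, a query/key pair ends up scaled by s^(pos_q - pos_k), which depends only on their distance.

```python
import torch


def xpos_scale(mem_len: int, seq_len: int, dim: int, scale_base: float = 512.0):
    # Positions for [mems; current tokens], shifted so the first current
    # token is at 0 and the memory tokens sit at -mem_len .. -1.
    pos = torch.arange(mem_len + seq_len, dtype=torch.float32) - mem_len
    # Per-dimension XPOS base (constants assumed, following the XPOS formulation)
    base = (torch.arange(0, dim, 2, dtype=torch.float32) + 0.4 * dim) / (1.4 * dim)
    scale = base ** (pos / scale_base)[:, None]
    return torch.cat((scale, scale), dim=-1)  # (mem_len + seq_len, dim)


mem_len, seq_len, dim = 100, 1024, 64
with_mems = xpos_scale(mem_len, seq_len, dim)
no_mems = xpos_scale(0, seq_len, dim)
# The query scales (last seq_len rows) are the same with or without memory
```

With this offset, the scales applied to the current chunk don't change when memory is prepended, and memory keys get scales consistent with their position before the chunk.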

#234

I believe this fixes it.

great job!