lucidrains/x-transformers

XL-recurrence with RotaryEmbedding and mems not working correctly.

pfeatherstone opened this issue · 34 comments

Note, this follows on from #216

I am trying to do XL-recurrence with:

  • RotaryEmbedding
  • attn_num_mem_kv > 0
  • mems and return_mems

I'm running a test which checks that the outputs are the same whether I pass mems=None or mems=torch.zeros(...). They are not.
I'm using the code below:

import torch
from x_transformers import ContinuousTransformerWrapper, Encoder

lm = ContinuousTransformerWrapper(
    dim_in              = 2,
    dim_out             = 36,
    max_seq_len         = 0,
    max_mem_len         = 100,
    attn_layers = Encoder(
        dim             = 512,
        depth           = 4,
        heads           = 4,
        rotary_pos_emb  = True,
        attn_flash      = True,
        attn_num_mem_kv = 20
    )
)

B, M, D, depth = 1, lm.max_mem_len, lm.attn_layers.dim, lm.attn_layers.depth

x       = torch.randn(B, 1024, 2)
length  = torch.randint(100, x.shape[1], size=(x.shape[0],))
mask    = torch.arange(x.shape[1]).unsqueeze(0) < length.unsqueeze(-1)
mems    = [torch.zeros(x.shape[0], M, D) for _ in range(depth)]

out1, mems1 = lm(x, mask=mask, return_mems=True)
out2, mems2 = lm(x, mask=mask, mems=mems, return_mems=True)
torch.testing.assert_close(out1, out2)
for m1, m2 in zip(mems1, mems2):
    torch.testing.assert_close(m1, m2)

I also tried changing

if exists(input_mask) and exists(mem):
    input_mask = pad_at_dim(input_mask, (mem.shape[-2], 0), dim = -1, value = True)

to

if exists(input_mask) and exists(mem):
    attend = torch.any(mem)
    input_mask = pad_at_dim(input_mask, (mem.shape[-2], 0), dim = -1, value = attend)

but that doesn't help.
Any ideas?

I also tried changing:

freqs = freqs[-seq_len:, :]

to

freqs = freqs[:seq_len, :]

That made more sense to me. I think this makes the results match a bit better but not perfectly.
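
For illustration, here is a toy comparison of the two slicings (a standalone sketch with a stand-in freqs table, not the x-transformers internals):

import torch

# stand-in "freqs" table with one row per position, built for mem_len + seq_len positions
mem_len, seq_len = 3, 5
freqs = torch.arange(mem_len + seq_len).float().unsqueeze(-1)

last_rows  = freqs[-seq_len:, :]   # rows for positions 3..7: current tokens sit after the memories
first_rows = freqs[:seq_len, :]    # rows for positions 0..4: current tokens start at position 0

print(last_rows.squeeze(-1))   # tensor([3., 4., 5., 6., 7.])
print(first_rows.squeeze(-1))  # tensor([0., 1., 2., 3., 4.])

As it turns out further down the thread, the eventual fix keeps the current tokens at positions [0, seq_len) and gives the memories negative positions instead.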

If I set:

use_abs_pos_emb=True,
rotary_pos_emb=False

and keep the suggested change from above, i.e.

if exists(input_mask) and exists(mem):
    input_mask = pad_at_dim(input_mask, (mem.shape[-2], 0), dim = -1, value = True)

to

if exists(input_mask) and exists(mem):
    attend = torch.any(mem)
    input_mask = pad_at_dim(input_mask, (mem.shape[-2], 0), dim = -1, value = attend)

then it works.
My understanding was that RotaryEmbedding should work in this case. Maybe not.
@lucidrains can you confirm?

@pfeatherstone hey, does the equality work if you turn off rotary embeddings?

If I use

use_abs_pos_emb=True,
rotary_pos_emb=False

with the suggested change it works.

If I use:

rotary_pos_emb=False

it attempts to use AbsolutePositionalEmbedding, which I don't really want.

nice yea, i think i may know what's up. will look into it when i find a stretch of free time

Can you give me a hint? I can try to figure out the details.

@pfeatherstone i think the memories should be kept at negative positions, so say you have 2 memory tokens and 5 main tokens, the positions should be [-2, -1, 0, 1, 2, 3, 4] instead of [0..7). could be wrong, need to reread my code
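
For illustration, a toy sketch of that position layout (2 memory tokens and 5 main tokens, as in the comment above):

import torch

mem_len, seq_len = 2, 5

# memories take negative positions, the main tokens start at 0
pos = torch.arange(-mem_len, seq_len)
print(pos)  # tensor([-2, -1,  0,  1,  2,  3,  4])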

I will give it a go

@pfeatherstone what is the magnitude of the error?

The absolute error is around 0.008 on average

@pfeatherstone ok, it is likely what i said then, if you meant 'max' instead of 'average'

sorry, it's actually larger. More like 0.4 max absolute difference.
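
For reference, a quick way to report both numbers being discussed here (max vs. mean absolute difference), reusing out1 and out2 from the repro above:

# compare the two outputs directly instead of relying on assert_close's default tolerances
diff = (out1 - out2).abs()
print('max abs diff: ', diff.max().item())
print('mean abs diff:', diff.mean().item())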

i'll still try what you suggested

@lucidrains Yes it worked!

So the total changes are:

if exists(input_mask) and exists(mem):
    # treat all-zero mems as "nothing to attend to" (note: torch.any gives a single flag for the whole batch)
    attend = torch.any(mem)
    input_mask = pad_at_dim(input_mask, (mem.shape[-2], 0), dim = -1, value = attend)

at line 882 of x_transformers.py

and

if not exists(rotary_pos_emb) and exists(self.rotary_pos_emb):
    # give the memory tokens negative positions so the current tokens keep positions [0, seq_len)
    M = max(list(map(lambda m: m.shape[1] if exists(m) else 0, mems)))
    T = x.shape[1]
    t = torch.arange(-M, T)
    rotary_pos_emb = self.rotary_pos_emb.forward(t)

at line 1257 of x_transformers.py

@pfeatherstone 👍 💯 you mvp

want to try submitting a PR?

I can do. Without unit tests, PRs are easy ;)
Only thing is that some of the changes aren't ONNX-export friendly...
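
A toy illustration of the export problem with a data-dependent pad value (a sketch, not the pad_at_dim helper itself): the tensor has to be converted to a plain Python bool, and under tracing/ONNX export that conversion is baked in as a constant.

import torch
import torch.nn.functional as F

mem  = torch.zeros(1, 4, 8)
mask = torch.ones(1, 16, dtype=torch.bool)

attend = torch.any(mem)   # 0-dim bool tensor, depends on the data
# F.pad wants a plain Python number/bool for `value`, so the runtime tensor gets
# collapsed to a constant when the graph is traced for export
padded = F.pad(mask, (mem.shape[-2], 0), value=bool(attend))
print(padded.shape)  # torch.Size([1, 20])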

Also, the line:

attend = torch.any(mem)

doesn't work when the batch is mixed, i.e. when some items have non-zero mems and others don't, since torch.any() reduces over the whole batch to a single flag. So you would need to pad differently for each batch item. I'm looking into a fix.
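
One possible sketch of a per-item mask derived from the mems themselves (a hypothetical helper, not the fix that was eventually merged; the thread later settles on an explicit mem_masks argument instead, see the final example below):

import torch

def extend_mask_with_mems(input_mask, mem):
    # input_mask: (batch, seq_len) bool, mem: (batch, mem_len, dim)
    # derive a per-item, per-slot memory mask instead of a single torch.any(mem) flag
    mem_mask = mem.abs().sum(dim=-1) > 0               # (batch, mem_len): True where the slot holds data
    return torch.cat((mem_mask, input_mask), dim=-1)   # (batch, mem_len + seq_len)

# example: batch item 0 has empty mems, item 1 has two filled slots
mem  = torch.zeros(2, 4, 8)
mem[1, :2] = 1.
mask = torch.ones(2, 6, dtype=torch.bool)
print(extend_mask_with_mems(mask, mem)[:, :4])  # item 0: all False; item 1: True, True, False, False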

ok, at the very least you got it working for your case

this isn't really that big of a deal

i'll make the correction for rotary when i find some time

thanks for taking the initiative and working it out

So I've fixed the first issue, so that zero mems now behave the same as not attending to mems at all, and I've corrected the rotary embeddings.
The second issue I've come across is that mems are recorded before the pre-norm LayerNorm, yet on the next iteration they are prepended after it.
I tested it and was getting gibberish. I've fixed this by recording the new mems at exactly the point where the old mems are prepended. Now I get sensible results. FYI, I'm using sandwich norm, which uses pre-LN.
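
A minimal pre-LN sketch of the ordering issue described above (illustrative only, not the x-transformers internals; PreLNBlock and its layout are assumptions):

import torch
from torch import nn

class PreLNBlock(nn.Module):
    def __init__(self, dim, heads):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, heads, batch_first=True)

    def forward(self, x, mem):
        normed = self.norm(x)
        kv = torch.cat((mem, normed), dim=1)   # old mems are prepended AFTER the pre-norm
        out, _ = self.attn(normed, kv, kv)
        # record new mems at the same (post-norm) point where old mems get prepended,
        # so cached states and fresh states live at the same scale; recording them from
        # the un-normed x reproduces the mismatch described above
        new_mem = normed[:, -mem.shape[1]:].detach()
        return x + out, new_mem

With recurrence, new_mem from one segment is fed back in as mem for the next.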

@pfeatherstone ah interesting, you find sandwich norms helpful? i tried it for a contracting project and saw worse results

To be honest, I don't know anymore. The first time I tried it I thought it helped. But it could have been a coincidence.

@pfeatherstone yea i saw your PR, but i think it may need to be broken up. i think the zero mems is better dealt with with a mem_mask input

Ok cool. Though the code does create an appropriate mask: it assumes that all-zero mems shouldn't be attended to. I think that's a sensible default. Would someone ever want to explicitly attend to zeros?

@pfeatherstone i don't think there would be any issue, just that a mem_mask would lead to more flexibility, and solve your problem with needing an initial zero mems, which i assume is onnx related

on the sandwich norm question: just rerun twice with a change of a boolean and you'll have your answer

Yeah, it takes a couple of days for my models to train. There is a lot of augmentation and therefore randomness. Every time I run an experiment, without changing any parameters, convergence happens at different times and of course I get wildly different results.
So when I'm looking at convergence, it's hard to know whether an improvement was sheer luck or a genuine model enhancement. Stability, on the other hand, is pretty tied to the architecture. In my case, with or without sandwich norm, stability is the same.

@pfeatherstone sg

could you try the latest version? below runs fine for me now

import torch
from x_transformers import ContinuousTransformerWrapper, Encoder
from x_transformers import ContinuousAutoregressiveWrapper

lm = ContinuousTransformerWrapper(
    dim_in              = 2,
    dim_out             = 36,
    max_seq_len         = 0,
    max_mem_len         = 100,
    attn_layers = Encoder(
        dim             = 512,
        depth           = 4,
        heads           = 4,
        rotary_pos_emb  = True,
        attn_flash      = True,
        attn_num_mem_kv = 20
    )
)

B, M, D, depth = 1, lm.max_mem_len, lm.attn_layers.dim, lm.attn_layers.depth

x       = torch.randn(B, 1024, 2)
length  = torch.randint(100, x.shape[1], size=(x.shape[0],))
mask    = torch.arange(x.shape[1]).unsqueeze(0) < length.unsqueeze(-1)
mems    = [torch.zeros(x.shape[0], M, D) for _ in range(depth)]
mem_masks = [torch.zeros(x.shape[0], M, dtype = torch.bool) for _ in range(depth)] # memory mask

out1, mems1 = lm(x, mask=mask, return_mems=True)
out2, mems2 = lm(x, mask=mask, mems=mems, mem_masks = mem_masks, return_mems=True)
torch.testing.assert_close(out1, out2)

for m1, m2 in zip(mems1, mems2):
    torch.testing.assert_close(m1, m2)

@pfeatherstone you let me know what you see when you rerun the sandwich norm experiments. thinking about removing it

ok, i'm going to close this issue, i think it is good now

@pfeatherstone noticed you are using an Encoder instead of a Decoder in your example code. you have a working model based on this idea?

I'm actually using a Decoder. I used Encoder for the repro to make things simpler

@pfeatherstone ahh got it, you are using it correctly then, just checking

Out of interest, why would it not be OK to use this with an Encoder? The only difference between Encoder and Decoder is whether the mask is causal (triangular) or not. I use Decoder mainly because I don't want to attend to "future" tokens, which is desirable in a streaming architecture.
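
For illustration, the causal (triangular) mask referred to here is just a lower-triangular boolean matrix (a toy sketch, independent of the library):

import torch

seq_len = 5
# True means "may attend": position i only sees positions <= i
causal_mask = torch.ones(seq_len, seq_len).tril().bool()
print(causal_mask)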

@pfeatherstone depends on how you are sampling it