lucidrains/x-transformers

[Bug] ContinuousTransformerWrapper - return_mems doesn't work

Closed · 1 comment

Repro:

import torch
from x_transformers import ContinuousTransformerWrapper, Decoder

net = ContinuousTransformerWrapper(
    dim_out             = 259,
    max_seq_len         = 0,
    num_memory_tokens   = 20,
    max_mem_len         = 100,
    attn_layers = Decoder(
        dim             = 512,
        depth           = 6,
        heads           = 4,
        rotary_pos_emb  = True,
        shift_tokens    = 1,
        attn_flash      = True,
        attn_onnxable   = True,
        use_rmsnorm     = True,
        sandwich_norm   = True
    )
)

x = torch.randn(1, 1024, 512)
m = torch.randn(1, 1024) > 0
mems = [torch.zeros(1, 100, 512) for _ in range(6)]
logits, mems = net(x, mask=m, mems=mems, return_mems=True)
print(logits.shape)
print([m.shape for m in mems])

If you set mems=None instead, it works.
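For comparison, a minimal sketch of the working variant, reusing net, x and m from the repro above (the shapes noted in the comments are what I'd expect from the arguments, not verified output):

# letting the wrapper start with no memories works, and still returns fresh mems
logits, new_mems = net(x, mask = m, mems = None, return_mems = True)

print(logits.shape)                 # expected (1, 1024, 259) given dim_out = 259
print([t.shape for t in new_mems])  # one memory tensor per layer, length capped by max_mem_len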

@pfeatherstone yeah, I see

the mems are typically used in Transformer-XL style recurrence with a causal mask, which doesn't require the key padding mask you are passing in
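For context, a rough sketch of that recurrence pattern, where the mems returned for one segment are fed into the next forward call; long_x and the segment length are illustrative, and net is the wrapper from the repro above:

long_x = torch.randn(1, 4096, 512)  # a long continuous sequence, processed in chunks
mems = None                         # no memories for the first segment

for segment in long_x.split(1024, dim = 1):
    out, mems = net(segment, mems = mems, return_mems = True)
    # mems now carries up to max_mem_len hidden states per layer into the next segment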

however, I fixed it anyway for completeness