lucidrains/x-transformers

ContinuousTransformer num_memory_tokens bug

pfeatherstone opened this issue · 1 comment

Here is a repro:

import torch
from x_transformers import ContinuousTransformerWrapper, Decoder

lm = ContinuousTransformerWrapper(
    dim_in              = 4,
    dim_out             = 256 + 3,
    max_seq_len         = 0,   # no absolute positional embedding; rotary embeddings are used instead
    num_memory_tokens   = 20,
    attn_layers = Decoder(
        dim = 512,
        depth = 4,
        heads = 4,
        rotary_pos_emb  = True,
        attn_flash      = True,
        use_scalenorm   = True,
        attn_onnxable   = True,
        shift_tokens    = 1
    )
)

x = torch.randn(2, 1024, 4)

# random valid lengths per batch element, then a boolean padding mask of shape (2, 1024)
lens = torch.randint(100, x.shape[1], size=(x.shape[0],))
mask = torch.arange(x.shape[1]).unsqueeze(0) < lens.unsqueeze(-1)

out = lm(x, mask=mask)
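
For context, the failure looks like a length mismatch: the wrapper prepends num_memory_tokens learned tokens to the sequence, but the user-supplied mask keeps its original length, so the mask no longer lines up with the attended sequence. Below is a minimal sketch of the kind of mask padding needed; the helper name pad_mask_for_memory is hypothetical and not part of x-transformers:

import torch
import torch.nn.functional as F

def pad_mask_for_memory(mask: torch.Tensor, num_memory_tokens: int) -> torch.Tensor:
    # memory tokens are prepended to the sequence, so the boolean mask must be
    # left-padded with True to length seq_len + num_memory_tokens, keeping the
    # memory positions always attendable
    return F.pad(mask, (num_memory_tokens, 0), value=True)

# e.g. for the repro above: a (2, 1024) mask becomes (2, 1044) with num_memory_tokens = 20
mask = torch.arange(1024).unsqueeze(0) < torch.tensor([[100], [512]])
padded = pad_mask_for_memory(mask, 20)
assert padded.shape == (2, 1044)
assert padded[:, :20].all()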

Should be fixed.