ContinuousTransformer num_memory_tokens bug
pfeatherstone opened this issue · 1 comment
pfeatherstone commented
Here is a repro:
import torch
from x_transformers import ContinuousTransformerWrapper, Decoder

lm = ContinuousTransformerWrapper(
    dim_in = 4,
    dim_out = 256 + 3,
    max_seq_len = 0,
    num_memory_tokens = 20,
    attn_layers = Decoder(
        dim = 512,
        depth = 4,
        heads = 4,
        rotary_pos_emb = True,
        attn_flash = True,
        use_scalenorm = True,
        attn_onnxable = True,
        shift_tokens = 1
    )
)

x = torch.randn(2, 1024, 4)

# random valid length per batch element, then a boolean padding mask
# marking the first l[i] positions of each sequence as valid
l = torch.randint(100, x.shape[1], size = (x.shape[0],))
m = torch.arange(x.shape[1]).unsqueeze(0) < l.unsqueeze(-1)

x = lm(x, mask = m)
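This fails (presumably with a shape mismatch on the attention mask): the wrapper prepends the 20 memory tokens, so the internal sequence length becomes 1044 while the supplied mask is still length 1024.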
lucidrains commented
should be fixed
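The fix presumably pads the user-supplied mask on the left with True for the prepended memory tokens before attention is computed. A minimal sketch of that padding, assuming memory tokens sit at the front of the sequence and should always be attended to (the helper name is illustrative, not the library's actual internals):

import torch.nn.functional as F

def pad_mask_for_memory_tokens(mask, num_mem):
    # memory tokens are prepended and must always be attended to,
    # so extend the mask on the left with True values
    return F.pad(mask, (num_mem, 0), value = True)

# with the repro above, the padded mask matches the internal length:
# pad_mask_for_memory_tokens(m, 20).shape == torch.Size([2, 1044])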