[Bug] ContinuousTransformerWrapper - return_mems doesn't work
Closed this issue · 1 comment
pfeatherstone commented
Repro:
import torch
from x_transformers import ContinuousTransformerWrapper, Decoder

net = ContinuousTransformerWrapper(
    dim_out = 259,
    max_seq_len = 0,
    num_memory_tokens = 20,
    max_mem_len = 100,
    attn_layers = Decoder(
        dim = 512,
        depth = 6,
        heads = 4,
        rotary_pos_emb = True,
        shift_tokens = 1,
        attn_flash = True,
        attn_onnxable = True,
        use_rmsnorm = True,
        sandwich_norm = True
    )
)

x = torch.randn(1, 1024, 512)
m = torch.randn(1, 1024) > 0
mems = [torch.zeros(1, 100, 512) for _ in range(6)]

logits, mems = net(x, mask=m, mems=mems, return_mems=True)
print(logits.shape)
print([m.shape for m in mems])
If you set mems=None instead, it works.
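As a quick illustration of that workaround, reusing the same net, x and m from the repro above:

# works: start without pre-allocated memories and let the wrapper return them
logits, mems = net(x, mask=m, mems=None, return_mems=True)
print([m.shape for m in mems])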
lucidrains commented
@pfeatherstone yea i see
the mems are typically used in Transformer-XL-like recurrence with a causal mask, and don't require the key padding mask you are passing in
however, i fixed it just to make it complete
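For context, here is a minimal sketch of the Transformer-XL-style recurrence those mems are intended for. It reuses the net from the repro above; the long input and the segment length are made up for illustration.

# hypothetical example: process a long sequence in fixed-size segments,
# feeding the mems returned for one segment into the call for the next
long_x = torch.randn(1, 4096, 512)
mems = None  # no memories before the first segment
for segment in long_x.split(1024, dim=1):
    out, mems = net(segment, mems=mems, return_mems=True)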