Bugs in generation with cache and seq_start_pos
LouChao98 opened this issue · 8 comments
1. `F.scaled_dot_product_attention` with `is_causal=True` aligns the causal mask to the top-left corner (described here: https://github.com/Dao-AILab/flash-attention#changes-in-v21-compared-to-v20; PyTorch implements the v2.0-style mask), so when the cache is enabled and len(q) = 1 and len(k) = n, `causal=True` improperly masks out the cached keys. A workaround could be `is_causal = causal and (q.shape[2] >= k.shape[2])` when calling `F.scaled_dot_product_attention` (this may break code that depends on the original behaviour). A sketch of the guard is below.
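A minimal sketch of the workaround, assuming `q`, `k`, `v` are shaped `(batch, heads, seq, dim)` as expected by `F.scaled_dot_product_attention`; the function name and variables are illustrative, not the ones in x-transformers:

```python
import torch
import torch.nn.functional as F

def sdpa_cache_safe(q, k, v, causal = True):
    # PyTorch's is_causal aligns the mask to the top-left corner, so with a
    # length-1 query and length-n keys (single-step decoding with a cache)
    # every cached position except the first would be masked out.
    # Only request the built-in causal mask when len(q) >= len(k).
    is_causal = causal and (q.shape[2] >= k.shape[2])
    return F.scaled_dot_product_attention(q, k, v, is_causal = is_causal)

# single-step decoding against a cache of 7 keys / values
q = torch.randn(1, 8, 1, 64)
k = torch.randn(1, 8, 7, 64)
v = torch.randn(1, 8, 7, 64)
out = sdpa_cache_safe(q, k, v, causal = True)  # the new token attends to all 7 cached positions
```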
2. At x-transformers/x_transformers/x_transformers.py, Line 1481 in 013e60a, `seq_start_pos` is None.
- The left_pad is fed into the self-attn block as `context_mask`:
x-transformers/x_transformers/x_transformers.py, Lines 1194 to 1201 in 013e60a
x-transformers/x_transformers/x_transformers.py, Lines 1266 to 1267 in 013e60a
But in `Attention`, `context_mask` is ignored due to the absence of context (illustrated below).
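A simplified illustration of that failure mode (this is not the actual x-transformers code; the routing function below is an assumption about how `context_mask` is typically consulted only for cross attention):

```python
import torch

def select_kv_mask(mask = None, context = None, context_mask = None):
    # hypothetical routing: context_mask only matters when cross-attention context exists
    if context is not None:
        return context_mask
    return mask  # self-attention: context_mask is silently dropped

seq_len = 6
seq_start_pos = torch.tensor([2])                                # first 2 tokens are left padding
left_pad_mask = torch.arange(seq_len) >= seq_start_pos[:, None]  # (1, 6), False at the pads

# passing the left-pad mask as context_mask in a self-attention block loses it entirely
kv_mask = select_kv_mask(mask = None, context = None, context_mask = left_pad_mask)
print(kv_mask)  # None, so padding tokens are attended to as if they were real tokens
```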
3. The bias slice at
x-transformers/x_transformers/x_transformers.py, Lines 383 to 384 in 013e60a
with `i=1, j=n` will get `[0, -1, -2, ...]`, but we expect `[..., -2, -1, 0]`. A workaround could be `return self.bias[..., -i:, -j:]` (but like 1., I am not sure how this change will affect previous code). A toy example is below.
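A toy example of the indexing issue (the distance-based bias here is illustrative, not the exact buffer x-transformers precomputes): with a single decoding query over n cached keys, the slice must come from the bottom-right of the bias, not the top-left.

```python
import torch

n = 5
# toy relative-position bias: entry [qi, kj] = -|kj - qi|
pos = torch.arange(n)
bias = -(pos[None, :] - pos[:, None]).abs()   # (n, n)
bias = bias[None, None]                       # (1, heads = 1, n, n)

i, j = 1, n  # single-step decoding with a cache: 1 query, n keys
print(bias[..., :i, :j])    # [[ 0, -1, -2, -3, -4]]  treats the new token as the first position
print(bias[..., -i:, -j:])  # [[-4, -3, -2, -1,  0]]  expected: the new token is the last position
```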
4. At x-transformers/x_transformers/x_transformers.py, Line 1195 in 013e60a: `x.shape[-1]` should be `x.shape[-2]`. `x.shape[-1]` is the hidden size. (Sketch below.)
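A short sketch of why the index matters, assuming `x` is the usual `(batch, seq, dim)` token tensor and the line builds a per-position index to compare against `seq_start_pos` (variable names are illustrative):

```python
import torch

batch, seq_len, dim = 2, 6, 512
x = torch.randn(batch, seq_len, dim)
seq_start_pos = torch.tensor([0, 2])   # second sequence is left-padded by 2 tokens

# buggy: x.shape[-1] is the hidden size (512), not the sequence length
wrong_arange = torch.arange(x.shape[-1], device = x.device)   # 512 positions

# fixed: x.shape[-2] is the sequence length
seq_arange = torch.arange(x.shape[-2], device = x.device)     # 6 positions
left_pad_mask = seq_arange >= seq_start_pos[:, None]          # (batch, seq), False at left pads
print(left_pad_mask)
```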
After fixing the above bugs, the code passes my test cases: overfitting BLOOM and Llama on 2 sentences and regenerating them from a prefix.
@LouChao98 thank you! i believe i have addressed all the issues, but there could still be one or two remaining bugs
let me know if 1.22.18 breaks parity with what you have already done locally
@LouChao98 i'm curious, but are you a phd student or a research scientist / engineer?
The mask caused by `seq_start_pos` will eventually cause whole rows to be masked out (i.e., the rows corresponding to padding tokens). `F.scaled_dot_product_attention` will produce NaN at these positions. There are PyTorch issues about this:
pytorch/pytorch#103749
pytorch/pytorch#41508
Until PyTorch fixes this, we need to allow padding tokens to attend to something (while still preventing them from being attended to). A workaround could be adding `mask.masked_fill_(~mask.any(-1, keepdim=True), True)` between L194 and L195:
x-transformers/x_transformers/attend.py, Lines 192 to 195 in 4465de5
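A standalone sketch of both the NaN and the proposed workaround (not the actual attend.py code; the mask here is boolean with True meaning "may attend"):

```python
import torch
import torch.nn.functional as F

b, h, i, j, d = 1, 1, 3, 3, 8
q = torch.randn(b, h, i, d)
k = torch.randn(b, h, j, d)
v = torch.randn(b, h, j, d)

# boolean mask, True = may attend; the first row (a padding query) attends to nothing
mask = torch.tensor([[False, False, False],
                     [True,  True,  False],
                     [True,  True,  True ]]).view(1, 1, i, j)

out = F.scaled_dot_product_attention(q, k, v, attn_mask = mask)
print(out[0, 0, 0])  # NaNs for the fully-masked row on affected backends (see the linked issues)

# workaround: let fully-masked (padding) rows attend to something; real tokens
# still never attend to the padding positions, since the column mask is unchanged
mask.masked_fill_(~mask.any(dim = -1, keepdim = True), True)
out = F.scaled_dot_product_attention(q, k, v, attn_mask = mask)
print(out.isnan().any())  # tensor(False)
```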
I am a phd student
yes you are right! will fix tomorrow morning
Also I thought so! You are very sharp
@LouChao98 made the change, it should ensure at least one token is attended to, and mask the output too
let me know if you see any further issues
thank you!