lucidrains/x-transformers

Bugs in generation with cache and seq_start_pos

LouChao98 opened this issue · 8 comments

  1. F.scaled_dot_product_attention aligns its causal mask to the top-left corner (described here https://github.com/Dao-AILab/flash-attention#changes-in-v21-compared-to-v20; PyTorch implements a v2.0-style mask), so when the cache is enabled and len(q)=1 and len(k)=n, causal=True improperly masks out the cache. A workaround could be is_causal = causal and (q.shape[2] >= k.shape[2]) when calling F.scaled_dot_product_attention (this may break code that depends on the original behaviour); see the sketch after this list.
  2. pos_emb = self.pos_emb(x, pos = pos, offset = -seq_start_pos) if not external_pos_emb else pos
    We need to check whether seq_start_pos is None here.
  3. The left-pad mask is fed into the self-attention block as context_mask:
    if exists(seq_start_pos):
        seq_arange = torch.arange(x.shape[-1], device = x.device, dtype = torch.long)
        left_pad_mask = seq_arange >= seq_start_pos[..., None]

        if exists(self_attn_kv_mask):
            self_attn_kv_mask = self_attn_kv_mask & left_pad_mask
        else:
            self_attn_kv_mask = left_pad_mask

    if layer_type == 'a':
        out, inter = block(x, mask = mask, context_mask = self_attn_kv_mask, attn_mask = attn_mask, rel_pos = self.rel_pos, rotary_pos_emb = rotary_pos_emb, prev_attn = prev_attn, cache = next(iter_attn_cache, None), mem = layer_mem, return_intermediates = True)
    But in Attention, context_mask is ignored due to the absence of a cross-attention context:
    input_mask = context_mask if has_context else mask
  4. if exists(self.bias) and self.bias.shape[-1] >= j and self.bias.shape[-2] >= i:
        return self.bias[..., :i, :j]
    The slice is incorrect when using the cache: with i=1, j=n it picks relative offsets [0, -1, -2, ...] but we expect [..., -2, -1, 0]. A workaround could be return self.bias[..., -i:, -j:] (but, like 1., I am not sure how this change will affect existing code); see the sketch after this list.
  5. seq_arange = torch.arange(x.shape[-1], device = x.device, dtype = torch.long)
    x.shape[-1] should be x.shape[-2]; x.shape[-1] is the hidden size, not the sequence length.
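
For reference, here is a standalone sketch of the cached-decoding case behind 1. and 4. (the shapes b, h, n, d and the bias tensor are made up for illustration; this is not the x-transformers code itself):

import torch
import torch.nn.functional as F

# hypothetical cached decoding step: a single new query attends to n cached key/values
b, h, n, d = 1, 2, 5, 16
q = torch.randn(b, h, 1, d)   # len(q) = 1
k = torch.randn(b, h, n, d)   # len(k) = n
v = torch.randn(b, h, n, d)

causal = True

# with PyTorch's top-left aligned causal mask, is_causal = True would let the single
# query row attend only to the first key, hiding the rest of the cache, so only pass
# is_causal = True when the queries cover the full key length
is_causal = causal and (q.shape[2] >= k.shape[2])
out = F.scaled_dot_product_attention(q, k, v, is_causal = is_causal)

# hypothetical precomputed relative position bias of shape (heads, max_len, max_len)
bias = torch.randn(h, 8, 8)
i, j = q.shape[2], k.shape[2]   # i = 1, j = n during cached decoding

wrong = bias[..., :i, :j]       # row for the first position: offsets [0, -1, -2, ...]
right = bias[..., -i:, -j:]     # row for the last position: offsets [..., -2, -1, 0]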

After fixing the above bugs, the code passes my test cases: overfitting BLOOM and Llama on 2 sentences and regenerating them from a prefix.

@LouChao98 thank you! i believe i have addressed all the issues, but there could still be one or two remaining bugs

let me know if 1.22.18 breaks parity with what you have already done locally

@LouChao98 i'm curious, but are you a phd student or a research scientist / engineer?

The mask caused by seq_start_pos can end up masking out whole rows (i.e. the rows corresponding to padding tokens). F.scaled_dot_product_attention will produce NaN at these positions. There are open issues on this in PyTorch:
pytorch/pytorch#103749
pytorch/pytorch#41508
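
A minimal standalone reproduction (shapes are arbitrary; should print NaNs with the default math backend):

import torch
import torch.nn.functional as F

q = torch.randn(1, 1, 4, 8)
k = torch.randn(1, 1, 4, 8)
v = torch.randn(1, 1, 4, 8)

# boolean mask where True means "may attend"; the second query row is fully
# masked out, like a left-padded position after combining with the causal mask
mask = torch.ones(1, 1, 4, 4, dtype = torch.bool)
mask[:, :, 1, :] = False

out = F.scaled_dot_product_attention(q, k, v, attn_mask = mask)
print(out[0, 0, 1])   # NaN: softmax over a row of all -inf is undefined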

Until PyTorch fixes this, we need to let padding tokens attend to something (while still preventing them from being attended to). A workaround could be adding mask.masked_fill_(~mask.any(-1, keepdim=True), True) between L194 and L195:

if exists(mask) and causal:
    causal_mask = self.create_causal_mask(q_len, k_len, device = device)
    mask = mask & ~causal_mask
    causal = False
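
i.e. something like the following (a sketch; assuming L194/L195 are the last two lines above, the exact placement in x-transformers may differ):

if exists(mask) and causal:
    causal_mask = self.create_causal_mask(q_len, k_len, device = device)
    mask = mask & ~causal_mask

    # workaround: rows left with no attendable position (padding tokens) are set
    # back to all True so softmax never sees an all -inf row; their outputs still
    # need to be masked out downstream
    mask.masked_fill_(~mask.any(dim = -1, keepdim = True), True)

    causal = False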

I am a phd student

yes you are right! will fix tomorrow morning

Also I thought so! You are very sharp

@LouChao98 made the change, it should ensure at least one token is attended to, and mask the output too

let me know if you see any further issues

thank you!