Dao-AILab/flash-attention

Are there any plans for supporting an explicit attention mask?

Avelina9X opened this issue · 5 comments

I've noticed that the Triton implementation supports an explicit attention bias, which can be used to emulate arbitrary mask shapes by filling masked positions with large negative values. However, is there any planned support for explicit (boolean) masks in the CUDA implementation?
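For reference, here is a minimal sketch of that workaround: converting a boolean keep-mask into an additive bias that can be fed to a kernel which accepts a bias tensor. The function name and the expected bias shape (broadcastable `(batch, nheads, seqlen_q, seqlen_k)`) are assumptions for illustration, not part of the library's API.

```python
import torch

def bool_mask_to_bias(keep_mask: torch.Tensor, dtype: torch.dtype = torch.float16) -> torch.Tensor:
    # keep_mask: boolean tensor, True = attend, False = masked out.
    # Masked positions get a very large negative value so their softmax weight
    # is effectively zero; allowed positions contribute a bias of zero.
    bias = torch.zeros(keep_mask.shape, dtype=dtype, device=keep_mask.device)
    bias.masked_fill_(~keep_mask, torch.finfo(dtype).min)
    return bias
```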

I've noticed some requests for features like off-diagonal attention, but an explicit attention mask would facilitate this and any other arbitrary masking scheme (such as XLNet, attention sinks, or landmark attention) without needing to hardcode each scheme and enable it via an argument or a separate Python interface.

It seems like the PyTorch attention implementation supports custom attention masks and also uses FlashAttention-2: https://twitter.com/StasBekman/status/1736083447658225665. Though I'm not sure whether passing in an attention mask causes the op to dispatch to a non-FA2 kernel.
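One way to check this yourself is to restrict SDPA to the flash backend and pass an explicit mask; if the flash kernel cannot handle the mask, SDPA errors out instead of silently falling back to another backend. This is a rough sketch assuming a CUDA device and a PyTorch 2.x version where the `torch.backends.cuda.sdp_kernel` context manager is still available.

```python
import torch
import torch.nn.functional as F

q = k = v = torch.randn(2, 8, 1024, 64, device="cuda", dtype=torch.float16)
mask = torch.ones(1024, 1024, device="cuda", dtype=torch.bool).tril()

# Allow only the flash backend, then pass an explicit attn_mask.
with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
    try:
        F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
        print("flash backend accepted the explicit mask")
    except RuntimeError as e:
        print("flash backend could not run with an explicit mask:", e)
```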

If there's an attn mask, PyTorch does not dispatch to the FA2 kernel, but rather to the kernel from xformers.

Thanks for the info @tridao! Is support for arbitrary attention masks on your roadmap? This would be incredibly useful for some encoder-decoder and prefixLM models. Mandatory thank you for your amazing work!
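As a concrete example of why an explicit mask would help here, this is a small sketch of a prefix-LM keep-mask: bidirectional attention within the prefix, causal attention afterwards. The function `prefix_lm_mask` and its parameters are hypothetical names used only for illustration.

```python
import torch

def prefix_lm_mask(seqlen: int, prefix_len: int, device: str = "cuda") -> torch.Tensor:
    # Start from a standard causal (lower-triangular) mask...
    mask = torch.ones(seqlen, seqlen, dtype=torch.bool, device=device).tril()
    # ...then let every position attend to the full prefix, making the prefix
    # block bidirectional while keeping the suffix causal.
    mask[:, :prefix_len] = True
    return mask  # (seqlen, seqlen), True = attend
```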

If there's an attn mask, PyTorch does not dispatch to the FA2 kernel, but rather to the kernel from xformers.

Thanks for this valuable tip. No wonder torch.nn.functional.scaled_dot_product_attention does not bring any speed-up in my case.
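If anyone wants to measure this on their own workload, here is a rough timing sketch (assuming a CUDA device) comparing SDPA without a mask against SDPA with an explicit boolean mask; the helper and the shapes are illustrative, not from the library.

```python
import torch
import torch.nn.functional as F

def time_sdpa(q, k, v, attn_mask=None, iters=50):
    # CUDA-event timing with a short warmup; returns average ms per call.
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    for _ in range(5):
        F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
    torch.cuda.synchronize()
    start.record()
    for _ in range(iters):
        F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters

q = k = v = torch.randn(4, 16, 2048, 64, device="cuda", dtype=torch.float16)
causal_mask = torch.ones(2048, 2048, device="cuda", dtype=torch.bool).tril()
print("no mask      :", time_sdpa(q, k, v), "ms")
print("explicit mask:", time_sdpa(q, k, v, attn_mask=causal_mask), "ms")
```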

I'm looking for bias/mask support too, in FA2 and ideally FA3. Is there a roadmap for this? Thank you~