Dao-AILab/flash-attention

Any plans to support tree attention mask?

KexinFeng opened this issue · 6 comments

Tree attention masks are already supported in huggingface/transformers: huggingface/transformers#27539
This would be very helpful for speculative decoding applications. More specifically, in flash_attn/flash_attn_interface.py#flash_attn_with_kvcache, the tree attention mask would need to be specified and passed in as an argument.
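As a rough illustration of what such an argument would encode, here is a minimal pure-Python sketch that builds a boolean tree mask for one sequence. It assumes a hypothetical layout: the KV cache holds `cache_len` accepted tokens, the Q draft tokens follow in topological order, and `parents[i]` is the index of draft token i's parent (or -1 when the parent is the last cached token). Neither `build_tree_mask` nor this layout is part of the flash-attn or transformers API.

```python
from typing import List


def build_tree_mask(parents: List[int], cache_len: int) -> List[List[bool]]:
    """Return a [Q][cache_len + Q] boolean mask where True means "may attend".

    Hypothetical layout: columns 0..cache_len-1 are accepted tokens in the KV
    cache, columns cache_len..cache_len+Q-1 are the Q draft tokens.
    parents[i] is the draft index of token i's parent, or -1 if its parent is
    the last cached token. Tokens must be in topological order (parents[i] < i).
    """
    q = len(parents)
    mask = [[False] * (cache_len + q) for _ in range(q)]
    for i in range(q):
        if not -1 <= parents[i] < i:
            raise ValueError(f"parents[{i}] must be -1 or an earlier draft index")
        # Every draft token may attend to the whole accepted prefix.
        for j in range(cache_len):
            mask[i][j] = True
        # Walk up the tree: a token attends to itself and all of its ancestors.
        node = i
        while node != -1:
            mask[i][cache_len + node] = True
            node = parents[node]
    return mask


# Token 0 follows the prefix, tokens 1 and 2 are alternative continuations
# of token 0, and token 3 continues token 1.
for row in build_tree_mask([-1, 0, 0, 1], cache_len=2):
    print("".join("1" if allowed else "0" for allowed in row))
# prints:
# 111000
# 111100
# 111010
# 111101
```

Siblings (tokens 1 and 2) cannot see each other, and token 3 sees its ancestors 0 and 1 but not the competing branch 2. This is what a plain causal mask cannot express.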

Do you have any plans to support it in the near future?

Thanks

Related questions: #840, #918

Sure, we'll just need someone to contribute :D

I'm keen to try supporting the generic mask case, e.g. a [B, Q, K] bool tensor, with conditional execution. Ideally this covers a lot of masking cases, but I suspect optimised kernels would work better for more structured masks (like a tree).
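One way to picture conditional execution over a generic bool mask is at the tile level: each (block_q by block_k) tile of one [Q, K] slice is either fully masked (skip it), fully unmasked (no masking needed), or mixed (apply the mask elementwise). The helper `classify_blocks` and the tile sizes below are hypothetical; a real kernel would make this decision per batch element (and per head, if the mask has one) on the GPU rather than in Python.

```python
from typing import List


def classify_blocks(mask: List[List[bool]], block_q: int, block_k: int) -> List[List[str]]:
    """Label each tile of a [Q][K] bool mask (True = may attend).

    'skip'    : every entry is False, so the tile can be skipped entirely.
    'full'    : every entry is True, so no masking is needed inside the tile.
    'partial' : mixed, so the mask must be applied elementwise.
    """
    q_len, k_len = len(mask), len(mask[0])
    labels = []
    for qb in range(0, q_len, block_q):
        row = []
        for kb in range(0, k_len, block_k):
            # Collect the tile's entries; edge tiles may be smaller than a full block.
            tile = [mask[i][j]
                    for i in range(qb, min(qb + block_q, q_len))
                    for j in range(kb, min(kb + block_k, k_len))]
            if not any(tile):
                row.append("skip")
            elif all(tile):
                row.append("full")
            else:
                row.append("partial")
        labels.append(row)
    return labels


causal = [[j <= i for j in range(4)] for i in range(4)]
print(classify_blocks(causal, block_q=2, block_k=2))
# prints: [['partial', 'skip'], ['full', 'partial']]
```

For a structured mask such as causal, these labels follow directly from the tile indices; for an arbitrary bool mask they have to be computed by reading the mask itself.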

I don't see much difference between a generic mask and a structured mask. For a tree mask, the mask argument would also be [B, Q, K]. The 4D attention mask mentioned above is simply [b, h, q, k], where h is the number of heads.

If you are able to implement a generic mask, the structured mask will be covered as well.

What I mean is that for a structured mask you don't necessarily have to create a bool tensor. In the causal case, the kernel can hardcode the rule that query i ignores every key j with j > i + k_cache, which saves a little memory. If the mask is structured, the locations you'll visit are predictable.
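A small sketch of that memory point, assuming `k_cache` is the number of keys already in the cache and that new query `i` (0-based) sits at absolute position `k_cache + i`. The predicate `causal_allowed` needs no storage, while `materialize_causal` builds the equivalent bool tensor. Both names are hypothetical.

```python
from typing import List


def causal_allowed(i: int, j: int, k_cache: int) -> bool:
    """On-the-fly causal rule: new query i sits at absolute position k_cache + i,
    so it must ignore every key j with j > i + k_cache."""
    return not (j > i + k_cache)


def materialize_causal(q_len: int, k_cache: int) -> List[List[bool]]:
    """The same rule stored as a [q_len][k_cache + q_len] bool tensor."""
    k_len = k_cache + q_len
    return [[causal_allowed(i, j, k_cache) for j in range(k_len)] for i in range(q_len)]


mask = materialize_causal(q_len=2, k_cache=3)
# The predicate needs O(1) memory; the tensor holds q_len * (k_cache + q_len) entries.
print(sum(len(row) for row in mask))  # 10
```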

I see. Yes, in the causal case the bool mask tensor argument is indeed not required. For the tree attention mask, however, this argument is unavoidable. But it probably doesn't add much implementation complexity, since the causal mask will internally be converted to such a tensor anyway. @thorinf Looking forward to your PR!

Hello, sorry for the naive question, but:

  1. Why do you need structured masking? Can't you do something similar with attention biases?
  2. Are you hoping that you might be able to skip blocks that are entirely masked? or will you still compute attention over the full matrix?
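For context on the first question: semantically, a boolean mask is equivalent to an additive attention bias of 0 where attention is allowed and -inf where it is masked. The single-head, pure-Python reference below (all helper names hypothetical) checks that equivalence numerically. It assumes every query row allows at least one key, which holds for causal and tree masks. It says nothing about performance, which is what the second question is about.

```python
import math
from typing import List

Matrix = List[List[float]]
NEG_INF = float("-inf")


def softmax(xs: List[float]) -> List[float]:
    # Subtract the max for stability; exp(-inf) evaluates to 0.0 for masked entries.
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    s = sum(exps)
    return [e / s for e in exps]


def attention_with_mask(q: Matrix, k: Matrix, v: Matrix, mask: List[List[bool]]) -> Matrix:
    """Reference: take the softmax only over the keys the mask allows."""
    d = len(q[0])
    out = []
    for i, qi in enumerate(q):
        allowed = [j for j in range(len(k)) if mask[i][j]]
        scores = [sum(a * b for a, b in zip(qi, k[j])) / math.sqrt(d) for j in allowed]
        p = softmax(scores)
        out.append([sum(pj * v[j][c] for pj, j in zip(p, allowed)) for c in range(len(v[0]))])
    return out


def mask_to_bias(mask: List[List[bool]]) -> Matrix:
    """True -> 0.0 (attend), False -> -inf (masked)."""
    return [[0.0 if allowed else NEG_INF for allowed in row] for row in mask]


def attention_with_bias(q: Matrix, k: Matrix, v: Matrix, bias: Matrix) -> Matrix:
    """Dense attention where bias[i][j] is added to every scaled score."""
    d = len(q[0])
    out = []
    for i, qi in enumerate(q):
        scores = [sum(a * b for a, b in zip(qi, kj)) / math.sqrt(d) + bias[i][j]
                  for j, kj in enumerate(k)]
        p = softmax(scores)
        out.append([sum(p[j] * v[j][c] for j in range(len(v))) for c in range(len(v[0]))])
    return out


q = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
k = [[1.0, 0.5], [0.2, 1.0], [0.5, 0.5]]
v = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
tree = [[True, False, False], [True, True, False], [True, False, True]]
print(attention_with_mask(q, k, v, tree)[0])  # [1.0, 2.0]
```

Note that the bias form still computes a score for every (query, key) pair, whereas a kernel that knows which entries are masked could in principle avoid that work.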

It might help me understand this a bit more :)