Dao-AILab/flash-attention

Adding support for sqrt of softmax scores

snarayan21 opened this issue · 2 comments

Hi, we're using an attention variant that takes the square root of the attention probabilities (the square root of the softmax output) so that the attention mechanism is variance-preserving. Without the square root, attention does not preserve the variance of the values: the softmax weights sum to 1, so their squares sum to less than 1 and the weighted sum of value vectors shrinks the variance, whereas square-rooted weights have squares that still sum to 1.
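For reference, here's a minimal (non-flash) PyTorch sketch of what I mean by the variant; the function name is just for illustration, not from our codebase:

```python
import math
import torch
import torch.nn.functional as F

def sqrt_softmax_attention(q, k, v):
    """Naive reference for square-root-of-softmax attention.

    Since the softmax weights sum to 1, their square roots have squares that
    still sum to 1, so the weighted sum over the value vectors preserves
    variance (assuming roughly i.i.d. values).
    """
    d = q.shape[-1]
    scores = torch.einsum("...qd,...kd->...qk", q, k) / math.sqrt(d)
    probs = F.softmax(scores, dim=-1)
    weights = probs.sqrt()  # the only change vs. standard attention
    return torch.einsum("...qk,...kd->...qd", weights, v)
```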

I've been able to implement this in Triton for the original FlashAttention (albeit on an older Triton version, modified to support ALiBi), but I could use some help implementing it as an option in FA2 directly. The speedup from FA2 is significant and we'd like the best performance for model training. Attached is the Triton implementation, which highlights the necessary changes; a rough sketch of the change is also included below.
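To give a sense of what would need to change in the FA2 online-softmax accumulation (this is just a blockwise PyTorch sketch under my own naming, not the attached Triton code): the normalizer `l` stays the plain softmax denominator, but the output accumulator uses half-exponents `exp((s - m)/2)`, its running-max correction becomes `exp((m_old - m_new)/2)`, and the final rescale is `1/sqrt(l)` instead of `1/l`:

```python
import torch

def sqrt_softmax_attention_blockwise(q, k, v, block=128):
    """Blockwise, numerically stable sketch of sqrt-of-softmax attention,
    mimicking the FlashAttention-style online accumulation."""
    scale = q.shape[-1] ** -0.5
    m = torch.full(q.shape[:-1], float("-inf"), device=q.device, dtype=q.dtype)
    l = torch.zeros_like(m)          # running softmax normalizer (unchanged)
    acc = torch.zeros_like(q)        # running output accumulator (half-exponents)
    for start in range(0, k.shape[-2], block):
        kb = k[..., start:start + block, :]
        vb = v[..., start:start + block, :]
        s = torch.einsum("...qd,...kd->...qk", q, kb) * scale
        m_new = torch.maximum(m, s.amax(dim=-1))
        # normalizer is still the ordinary softmax denominator
        l = l * torch.exp(m - m_new) + torch.exp(s - m_new.unsqueeze(-1)).sum(dim=-1)
        # accumulator uses exp((s - m)/2), so its max-correction uses half too
        acc = acc * torch.exp((m - m_new) / 2).unsqueeze(-1) \
            + torch.einsum("...qk,...kd->...qd",
                           torch.exp((s - m_new.unsqueeze(-1)) / 2), vb)
        m = m_new
    # final rescale by 1/sqrt(l) instead of 1/l
    return acc / l.sqrt().unsqueeze(-1)
```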

I'd appreciate any help with this!

That's a cool idea! Unfortunately I don't have cycles to do code review.

Yeah, apologies. I'll try filing something smaller and much more specific later. Thanks!