Would you like flash attention?
Birch-san opened this issue · 9 comments
I implemented support for flash attention via xformers:
Sygil-Dev/muse-maskgit-pytorch@main...Birch-san:muse-maskgit-pytorch:sdp-attn
It gets the same result, confirmed via `allclose()` (I had to make the absolute tolerance a bit more forgiving to pass that check, but it's still a very small tolerance).
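A minimal sketch of that kind of equivalence check (the shapes, dtype, and tolerance here are illustrative assumptions, not the values from my diff):

```python
import torch
import xformers.ops

# toy shapes; the real model's batch/head/seq/dim sizes will differ
b, h, n, d = 2, 8, 64, 32
scale = d ** -0.5
q, k, v = (torch.randn(b, h, n, d, device='cuda', dtype=torch.float16) for _ in range(3))

# reference: the existing einsum attention
sim = torch.einsum('b h i d, b h j d -> b h i j', q, k) * scale
ref = torch.einsum('b h i j, b h j d -> b h i d', sim.softmax(dim=-1), v)

# xformers expects (batch, seq, heads, dim), hence the transposes
out = xformers.ops.memory_efficient_attention(
    q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
).transpose(1, 2)

# a slightly forgiving absolute tolerance, as described above (value is an assumption)
assert torch.allclose(ref, out, atol=1e-3)
```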
I also implemented support for `torch.nn.functional.scaled_dot_product_attention`. it'll use flash attention when the mask is `None`, but I guess your mask is usually defined.

even without flash attention: `scaled_dot_product_attention` should still be faster than the `einsum() * scale` approach, because (IIRC) its math fallback is based on `baddbmm`, which fuses the scale factor into the matmul.

in stable-diffusion inference, for example, we measured end-to-end image generation via `baddbmm` to be ~11% faster than `einsum() * scale` on CUDA (and 18% faster on Mac).
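to illustrate the fusion concretely, here's a hedged sketch (toy shapes; it just shows that `baddbmm` folds the scale, and optionally a mask bias, into the matmul rather than doing a second elementwise pass):

```python
import torch

bh, n, d = 16, 64, 32  # batch * heads flattened: a 3D batched matmul
scale = d ** -0.5
q, k = torch.randn(bh, n, d), torch.randn(bh, n, d)

# unfused: materializes q @ k^T, then rescales it in a separate elementwise kernel
sim_unfused = torch.einsum('b i d, b j d -> b i j', q, k) * scale

# fused: alpha applies the scale inside the matmul; the first argument
# (here a zero bias) is where a mask bias could be added for free via beta
bias = torch.zeros(bh, n, n)
sim_fused = torch.baddbmm(bias, q, k.transpose(1, 2), beta=1, alpha=scale)

assert torch.allclose(sim_unfused, sim_fused, atol=1e-5)
```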
@lucidrains is this a contribution you would be interested in receiving as a PR?
@Birch-san omg, yes! thank you Birch-san! optional support for pytorch 2.0's `scaled_dot_product_attention` would be great
@Birch-san hey, decided to take care of it this morning, happy training! (just set `flash = True` on the transformer)
@lucidrains hmm, does this actually run with flash attention, or does it fall back to math mode? I thought none of the torch sdp kernels supported masks.
In my diff, where I demonstrated it working with torch sdp, I tried disabling the math fallback and it said there was no eligible kernel.
That's why I also demonstrated how to do it via xformers, which supports masks.
@Birch-san yea it should, as they actually just ended up using Tri's flash attention (with the correct hardware), which supports masking
https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html does 'math' mode not support attention masks? i see `attn_mask` in the arguments, so that would be strange
math mode supports masks. but it's not flash attention. it's just a `baddbmm` followed by a softmax and a `bmm`: no memory efficiency, no IO awareness. that's why it's supported on non-CUDA platforms.
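conceptually, that fallback amounts to something like this (a hedged sketch; `math_sdp` is a hypothetical helper, and the whole (n × n) attention matrix stays live in memory at once):

```python
import torch

def math_sdp(q, k, v, mask_bias, scale):
    # q, k, v: (batch * heads, n, d); mask_bias: (batch * heads, n, n) additive bias
    # baddbmm fuses the scale (alpha) and the mask bias into the matmul
    sim = torch.baddbmm(mask_bias, q, k.transpose(1, 2), beta=1, alpha=scale)
    attn = sim.softmax(dim=-1)  # full attention matrix, no tiling
    return torch.bmm(attn, v)   # (batch * heads, n, d)
```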
if you disable math sdp attention with the context manager, you get "No available kernel. Aborting execution.":
that's why I showed how to implement flash attention via xformers instead: to get flash attention with support for masking.
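for example (a hedged sketch against xformers 0.0.20; the zero bias tensor is a stand-in for the model's real mask):

```python
import torch
import xformers.ops

b, n, h, d = 2, 64, 8, 32  # xformers expects (batch, seq, heads, dim)
q, k, v = (torch.randn(b, n, h, d, device='cuda', dtype=torch.float16) for _ in range(3))

# additive attention bias; set masked-out positions to a large negative value
attn_bias = torch.zeros(b, h, n, n, device='cuda', dtype=torch.float16)

out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=attn_bias)
```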
ah, if `math` is not following the tiled memory-efficient algorithm, that is news to me, and maybe i should fall back to the naive implementation here
xformers uses `memory_efficient_attention`, which is basically a variant of flash attention. but pytorch 2.0 should detect and use it with the right hardware? anyways, this is all more confusing than it should be. perhaps a meta-library is in order that can optimally select between all these libraries (including triton's impl)
you can see the implementation of math attention in pytorch's attention.cpp.
it actually looks like it doesn't even use `baddbmm` optimizations (fusing the scale factor and the mask bias), and there's no explicit attempt to describe it as a 3D batched matmul. I wonder if a naïve implementation could actually go faster.
I think there must be some subtlety other than hardware, because for me xformers 0.0.20 supports masks but torch sdp didn't. perhaps xformers hasn't upstreamed all their capabilities into pytorch sdp?
I also did a naïve implementation of memory-efficient attention, with `baddbmm` optimizations:
huggingface/diffusers#1892
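for reference, a heavily simplified sketch of the general idea (not the code from that PR; this chunks over queries only, so each chunk's softmax is exact, whereas the PR also handles chunking along keys):

```python
import torch

def chunked_attention(q, k, v, scale, q_chunk=128):
    # q, k, v: (batch * heads, n, d). each query chunk's softmax is independent
    # of the others, so only a (q_chunk x n_k) slice of the attention matrix
    # is ever materialized at once.
    out_chunks = []
    for i in range(0, q.shape[1], q_chunk):
        q_c = q[:, i:i + q_chunk]
        # beta=0 makes baddbmm skip reading the bias; pass a real mask bias with beta=1 instead
        bias = torch.empty(q.shape[0], q_c.shape[1], k.shape[1], device=q.device, dtype=q.dtype)
        sim = torch.baddbmm(bias, q_c, k.transpose(1, 2), beta=0, alpha=scale)
        out_chunks.append(torch.bmm(sim.softmax(dim=-1), v))
    return torch.cat(out_chunks, dim=1)
```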
@Birch-san yea, makes sense, decided to fall back to my own naive implementation
if they don't have masking support for the mem-efficient version, they should def upstream the xformers version. masking is very basic.