lucidrains/muse-maskgit-pytorch

Would you like flash attention?

Birch-san opened this issue · 9 comments

I implemented support for flash attention via xformers:
Sygil-Dev/muse-maskgit-pytorch@main...Birch-san:muse-maskgit-pytorch:sdp-attn

It gets the same result, confirmed via allclose() (I had to make the absolute tolerance a bit more forgiving to pass that check, but it's still very small).
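for reference, the xformers path looks roughly like this (a sketch rather than the exact diff; shapes and tolerance are illustrative):

```python
import torch
import xformers.ops as xops

b, n, h, d = 2, 256, 8, 64
q, k, v = (torch.randn(b, n, h, d, device='cuda', dtype=torch.float16) for _ in range(3))
# additive attention bias, broadcastable to (batch, heads, q_len, k_len)
bias = torch.zeros(b, h, n, n, device='cuda', dtype=torch.float16)

# reference: plain einsum attention in xformers' (batch, seq, heads, head_dim) layout
scale = d ** -0.5
sim = torch.einsum('b i h d, b j h d -> b h i j', q, k) * scale + bias
ref = torch.einsum('b h i j, b j h d -> b i h d', sim.softmax(dim=-1), v)

out = xops.memory_efficient_attention(q, k, v, attn_bias=bias)
assert torch.allclose(out, ref, atol=1e-2)  # absolute tolerance loosened a bit, as noted above
```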

I also implemented support for torch.nn.functional.scaled_dot_product_attention. it'll use flash attention when the mask is None, but I guess your mask is usually defined.
even without flash attention, scaled_dot_product_attention should still be faster than the einsum() * scale approach, because (IIRC) its math fallback is based on baddbmm, which fuses the scale factor into the matmul.
in Stable Diffusion inference, for example, we measured end-to-end image generation via baddbmm to be ~11% faster than einsum() * scale on CUDA (and 18% faster on Mac).
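concretely, the difference is just this (a sketch with arbitrary shapes):

```python
import torch

bh, n, d = 16, 256, 64  # batch * heads flattened into one dim for bmm
q = torch.randn(bh, n, d, device='cuda', dtype=torch.float16)
k = torch.randn(bh, n, d, device='cuda', dtype=torch.float16)
scale = d ** -0.5

# einsum, then a separate elementwise kernel for the scale
sim_a = torch.einsum('b i d, b j d -> b i j', q, k) * scale

# baddbmm with beta=0 ignores the input tensor; alpha fuses the scale into the matmul
sim_b = torch.baddbmm(
    torch.empty(bh, n, n, device=q.device, dtype=q.dtype),
    q, k.transpose(1, 2),
    beta=0, alpha=scale,
)
assert torch.allclose(sim_a, sim_b, atol=1e-2, rtol=1e-2)
```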

@lucidrains is this a contribution you would be interested in receiving as a PR?

@Birch-san omg, yes! thank you Birch-san! optional support for pytorch 2.0's scaled_dot_product_attention would be great

@Birch-san hey, decided to take care of it this morning, happy training! (just set flash = True on the transformer)
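for example (hyperparameters copied loosely from the repo README; treat them as illustrative):

```python
from muse_maskgit_pytorch import MaskGitTransformer

transformer = MaskGitTransformer(
    num_tokens = 65536,   # must match the VAE codebook size
    seq_len = 256,
    dim = 512,
    depth = 8,
    dim_head = 64,
    heads = 8,
    ff_mult = 4,
    flash = True          # <- opt in to F.scaled_dot_product_attention
)
```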

@lucidrains
hmm does this actually run with flash attention, or does it fall back to math mode? I thought none of the torch sdp kernels supported masks.

In my diff, where I demonstrated it working with torch sdp, I tried disabling the math fallback and it said there was no eligible kernel.
That’s why I also demonstrated how to do it via xformers, which supports masks.

@Birch-san yea it should, as they actually just ended up using Tri's flash attention (with the correct hardware), which supports masking

https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html does 'math' mode not support attention masks? i see attn_mask in the arguments, so that would be strange

math mode supports masks. but it's not flash attention. it's just a baddbmm followed by a softmax and a bmm. no memory efficiency, no IO awareness. that's why it's supported on non-CUDA platforms.
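i.e. roughly this, with the whole n × n score matrix materialized at once (a sketch of the fallback as described, not pytorch's actual attention.cpp code):

```python
import torch

def math_sdp(q, k, v, mask_bias, scale):
    # q, k, v: (batch * heads, n, head_dim); mask_bias: additive, (batch * heads, n, n)
    # the full n x n score matrix lives in memory at once: O(n^2), no tiling, no IO awareness
    scores = torch.baddbmm(mask_bias, q, k.transpose(1, 2), beta=1, alpha=scale)
    return torch.bmm(scores.softmax(dim=-1), v)
```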

if you disable math sdp attn with the context manager, you get "No available kernel. Aborting execution."
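roughly (a repro sketch with made-up shapes):

```python
import torch
import torch.nn.functional as F

q, k, v = (torch.randn(2, 8, 256, 64, device='cuda', dtype=torch.float16) for _ in range(3))
mask = torch.ones(2, 8, 256, 256, device='cuda', dtype=torch.bool)

# forbid the math fallback, leaving only the flash / mem-efficient kernels
with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_mem_efficient=True, enable_math=False):
    # with attn_mask set, neither remaining kernel was eligible for me:
    # RuntimeError: No available kernel. Aborting execution.
    out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
```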

that's why I showed how to implement flash attention via xformers instead: to get flash attention with support for masking.

ah, if math is not following the tiled memory efficient algorithm, that is news to me, and maybe i should fall back to the naive implementation here

xformers uses `memory_efficient_attention`, which is basically a variant of flash attention. but pytorch 2.0 should detect and use it with the right hardware?


anyways, this is all more confusing than it should be. perhaps a meta-library is in order that can optimally select between all these libraries (including triton's impl)

you can see the implementation of math attention in pytorch's attention.cpp.
it actually looks like it doesn't even use the baddbmm optimizations (fusing the scale factor and the mask bias into the matmul), and there's no explicit attempt to express it as a 3D batched matmul. I wonder if a naïve implementation could actually go faster.
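for example, folding heads into the batch dimension turns the whole thing into one 3D batched matmul (a sketch):

```python
import torch

b, h, n, d = 2, 8, 256, 64
q = torch.randn(b, h, n, d)
k = torch.randn(b, h, n, d)

# (batch, heads, n, d) -> (batch * heads, n, d): one bmm / baddbmm covers every head
scores = torch.bmm(q.reshape(b * h, n, d), k.reshape(b * h, n, d).transpose(1, 2))
scores = scores.view(b, h, n, n)
```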

I think there must be some subtlety other than hardware, because for me xformers 0.0.20 supports masks, but we can see torch sdp didn't. perhaps xformers hasn't upstreamed all of its capabilities into pytorch sdp?

I also did a naïve implementation of memory-efficient attention, with baddbmm optimizations:
huggingface/diffusers#1892
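the shape of it (a minimal query-chunked sketch with the baddbmm fusion, not the code from that PR):

```python
import torch

def chunked_attention(q, k, v, bias, chunk=64):
    # q, k, v: (batch * heads, n, head_dim); bias: additive mask, (batch * heads, n, n)
    # only a (chunk, n) block of scores is live at any time, instead of the full (n, n)
    scale = q.shape[-1] ** -0.5
    out = torch.empty_like(q)
    for i in range(0, q.shape[1], chunk):
        # baddbmm adds the mask bias and fuses the scale into the matmul
        scores = torch.baddbmm(bias[:, i:i + chunk], q[:, i:i + chunk], k.transpose(1, 2),
                               beta=1, alpha=scale)
        out[:, i:i + chunk] = torch.bmm(scores.softmax(dim=-1), v)
    return out
```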

@Birch-san yea, makes sense, decided to fall back to my own naive implementation

if they don't have masking support for the mem-efficient version, they should def upstream the xformers version. masking is very basic.