ROCm/pytorch

Add support for memory efficient attention for AMD/ROCm

Looong01 opened this issue · 1 comment

🚀 The feature, motivation and pitch

Enable support for the Flash Attention and memory-efficient attention SDPA kernels on AMD GPUs.

At present, using these kernels produces the warning below with the latest nightlies (torch==2.4.0.dev20240413+rocm6.0, pytorch-triton-rocm 3.0.0+0a22a91d04):

/site-packages/diffusers/models/attention_processor.py:1117: UserWarning: 1Torch was not compiled with memory efficient attention. (Triggered internally at ../aten/src/ATen/native/transformers/hip/sdp_utils.cpp:505.)

Alternatives

Users cannot use memory-efficient attention through the native PyTorch APIs.

Additional context

No response

Hi,

Not sure what the status is, but it looks like AMD has been working on it: pytorch#114309