Dao-AILab/flash-attention

Is fwd_kvcache compatible with torch.compile in 2.7.2.post1?

vince62s opened this issue · 6 comments

I'm getting the warning below, followed by many recompiles, because I use dynamic shapes (with dynamic=True in torch.compile):

/lib/python3.11/site-packages/torch/_dynamo/variables/functions.py:725: UserWarning: Graph break due to unsupported builtin flash_attn_2_cuda.PyCapsule.fwd_kvcache. This function is either a Python builtin (e.g. _warnings.warn) or a third-party C/C++ Python extension (perhaps created with pybind). If it is a Python builtin, please file an issue on GitHub so the PyTorch team can add support for it and see the next case for a workaround. If it is a third-party C/C++ Python extension, please either wrap it into a PyTorch-understood custom operator (see https://pytorch.org/tutorials/advanced/custom_ops_landing_page.html for more details) or, if it is traceable, use torch.compiler.allow_in_graph.

Sure, would love to see some PR fixing this

@ani300 in case you know how to fix this.

By the way, there is a slight speed regression for inference with kvcache between 2.5.9.post1 and 2.6.1 (about 4% in my case).

Can you send a short script to reproduce the speed regression? For example: with this input, 2.5.9.post1 takes XXX seconds and 2.6.1 takes YYY seconds.

@vince62s I probably forgot to add the torch.compile() wrapping for this function when I did the rest. I can probably take a stab at it later in the week, as I'm wrapped up with a work deadline until Wednesday

I did not rebuild 2.6.1 because it takes too long to compile, but 2.6.1 and 2.7.2.post1 are similar in speed, so the script below compares 2.5.9.post1 against 2.7.2.post1.

import torch
import time
from flash_attn.flash_attn_interface import flash_attn_with_kvcache

def test_flash_attn_with_kvcache():
    # Define tensor dimensions
    batch_size = 32
    num_heads = 16
    head_dim = 64
    seqlen_q = 1
    cache_len = 1024

    torch.cuda.synchronize()
    starttime = time.time()
    for _ in range(100000):
        # Generate random tensors for the query and the new key/value
        # (note: tensor creation happens inside the timed loop)
        q = torch.randn(batch_size, seqlen_q, num_heads, head_dim, dtype=torch.float16, device="cuda")
        k = torch.randn(batch_size, seqlen_q, num_heads, head_dim, dtype=torch.float16, device="cuda")
        v = torch.randn(batch_size, seqlen_q, num_heads, head_dim, dtype=torch.float16, device="cuda")
        # Random key/value cache
        k_cache = torch.randn(batch_size, cache_len, num_heads, head_dim, dtype=torch.float16, device="cuda")
        v_cache = torch.randn(batch_size, cache_len, num_heads, head_dim, dtype=torch.float16, device="cuda")

        # Non-causal case (causal defaults to False)
        attn_output = flash_attn_with_kvcache(q, k_cache, v_cache, k, v, cache_seqlens=cache_len)
    torch.cuda.synchronize()
    print(time.time() - starttime)

# Run the test
if __name__ == "__main__":
    test_flash_attn_with_kvcache()

With 2.5.9.post1: 28.4653 sec
With 2.7.2.post1: 29.2737 sec

That's about 3%, but my real-world use case shows 4% (maybe because I also use rotary cos/sin).
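As a quick check of the arithmetic, the relative slowdown from the two timings above can be computed as follows. The helper name `relative_slowdown` is just for illustration; the timings are the ones reported in this comment.

```python
def relative_slowdown(baseline_s: float, candidate_s: float) -> float:
    """Return how much slower the candidate is than the baseline, in percent."""
    return (candidate_s - baseline_s) / baseline_s * 100.0


# Timings reported above for 100000 iterations.
slowdown = relative_slowdown(28.4653, 29.2737)
print(f"{slowdown:.2f}% slower")  # prints "2.84% slower"
```

Note that the timed loop also includes random tensor generation, so the slowdown of the attention call alone may differ from this figure.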