Is fwd_kvcache compatible with torch.compile in 2.7.2.post1?
vince62s opened this issue · 6 comments
Getting this warning, and then many subsequent recompiles, because I am using dynamic shapes (and dynamic=True in torch.compile):
/lib/python3.11/site-packages/torch/_dynamo/variables/functions.py:725: UserWarning: Graph break due to unsupported builtin flash_attn_2_cuda.PyCapsule.fwd_kvcache. This function is either a Python builtin (e.g. _warnings.warn) or a third-party C/C++ Python extension (perhaps created with pybind). If it is a Python builtin, please file an issue on GitHub so the PyTorch team can add support for it and see the next case for a workaround. If it is a third-party C/C++ Python extension, please either wrap it into a PyTorch-understood custom operator (see https://pytorch.org/tutorials/advanced/custom_ops_landing_page.html for more details) or, if it is traceable, use torch.compiler.allow_in_graph.
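For reference, a minimal sketch of the torch.compiler.allow_in_graph workaround the warning mentions, under the assumptions that the Python-level flash_attn_with_kvcache wrapper is what gets traced and that it is in fact traceable (the warning only offers this option "if it is traceable"); the decode_step wrapper name is hypothetical, and whether this avoids the dynamic-shape recompiles is exactly what this issue is asking.

import torch
from flash_attn import flash_attn_with_kvcache

# Ask Dynamo to keep this call in the graph instead of graph-breaking on the
# underlying flash_attn_2_cuda.PyCapsule.fwd_kvcache extension call.
torch.compiler.allow_in_graph(flash_attn_with_kvcache)

# Hypothetical single-token decode step, compiled with dynamic shapes as in the report.
@torch.compile(dynamic=True)
def decode_step(q, k_cache, v_cache, k, v, cache_seqlens):
    return flash_attn_with_kvcache(q, k_cache, v_cache, k, v, cache_seqlens=cache_seqlens)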
Sure, would love to see some PR fixing this
By the way, there is a slight speed regression for inference with kvcache between 2.5.9.post1 and 2.6.1 (about 4% in my case).
Can you send a short script to reproduce the speed regression? e.g. with this input, 2.5.9.post1 gets XXX seconds and 2.6.1 gets YYY seconds
@vince62s I probably forgot to add the torch.compile() wrapping for this function when I did the rest. I can probably take a stab at it later in the week, as I'm wrapped up with a work deadline until Wednesday
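For reference, the warning's other suggestion (wrapping the call as a PyTorch-understood custom operator) could look roughly like the sketch below. This is only an illustration, not the library's actual registration code: the op name myops::fa_kvcache_decode and the read-only signature (no in-place cache update) are assumptions, it simply wraps the Python-level flash_attn_with_kvcache, and it requires a recent PyTorch that provides torch.library.custom_op.

import torch
from flash_attn import flash_attn_with_kvcache

# Hypothetical custom-op wrapper (read-only variant: new k/v are assumed to be
# already written into the cache, so no inputs are mutated).
@torch.library.custom_op("myops::fa_kvcache_decode", mutates_args=())
def fa_kvcache_decode(q: torch.Tensor, k_cache: torch.Tensor, v_cache: torch.Tensor,
                      cache_seqlens: torch.Tensor) -> torch.Tensor:
    return flash_attn_with_kvcache(q, k_cache, v_cache, cache_seqlens=cache_seqlens)

# Fake (meta) implementation so torch.compile can trace shapes without running the kernel.
@fa_kvcache_decode.register_fake
def _(q, k_cache, v_cache, cache_seqlens):
    # Output has the same shape and dtype as q in the decode case.
    return torch.empty_like(q)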
I was lazy and did not rebuild 2.6.1 (it takes too long to compile), but 2.6.1 and 2.7.2.post1 are similar in speed. Here is the script:
import torch
import time
from flash_attn.flash_attn_interface import flash_attn_with_kvcache


def test_flash_attn_with_kvcache():
    # Define tensor dimensions
    batch_size = 32
    num_heads = 16
    head_dim = 64
    seqlen_q = 1
    cache_len = 1024

    torch.cuda.synchronize()
    starttime = time.time()
    for i in range(100000):
        # Generate random tensors for the query, the new key/value, and the key/value cache
        q = torch.randn(batch_size, seqlen_q, num_heads, head_dim, dtype=torch.float16, device="cuda")
        k = torch.randn(batch_size, seqlen_q, num_heads, head_dim, dtype=torch.float16, device="cuda")
        v = torch.randn(batch_size, seqlen_q, num_heads, head_dim, dtype=torch.float16, device="cuda")
        k_cache = torch.randn(batch_size, cache_len, num_heads, head_dim, dtype=torch.float16, device="cuda")
        v_cache = torch.randn(batch_size, cache_len, num_heads, head_dim, dtype=torch.float16, device="cuda")
        # Test for non-causal case
        attn_output = flash_attn_with_kvcache(q, k_cache, v_cache, k, v, cache_seqlens=cache_len)
    torch.cuda.synchronize()
    print(time.time() - starttime)


# Run the test
if __name__ == "__main__":
    test_flash_attn_with_kvcache()
With 2.5.9.post1: 28.4653 sec
With 2.7.2.post1: 29.2737 sec
That's about 3%, but my real-world use case shows 4% (maybe because I also use rotary cos/sin).