Dao-AILab/flash-attention

Any plan to support paged attention?

donglinz opened this issue · 13 comments

First of all, thank you for the great work!

Is there any plan to support paged kv cache in non-contiguous memory? For instance, in flash_attn_with_kvcache?

tridao commented

It's not the highest priority at the moment. Does the implementation from vLLM not work well?

> It's not the highest priority at the moment. Does the implementation from vLLM not work well?

Their scope is slightly different from flash_attn_with_kvcache: as far as I can tell, the vLLM kernel only supports decoding (one token per sequence in the batch). In many scenarios, such as speculative decoding, flash_attn_with_kvcache is preferable because it can compute multiple tokens per sequence in parallel.
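
To make the multi-token case concrete, here is a minimal NumPy sketch of the reference semantics (not the flash-attn kernel). It assumes one sequence, an equal number of query and KV heads, and a causal mask aligned to the bottom-right corner of the attention matrix; the function name `attend_with_cache` is made up for illustration.

```python
import numpy as np

def attend_with_cache(q, k_cache, v_cache, cache_len):
    """Reference attention for one sequence (hypothetical helper).

    q:        [n_new, H, D]  queries for the n_new most recent tokens
    k_cache:  [L_max, H, D]  keys; only the first cache_len rows are valid
    v_cache:  [L_max, H, D]  values, same layout
    Assumes cache_len >= n_new and that the cache already holds the new tokens.
    Query i may see cache positions 0 .. cache_len - n_new + i.
    """
    n_new, H, D = q.shape
    k = k_cache[:cache_len]
    v = v_cache[:cache_len]
    scores = np.einsum("qhd,khd->hqk", q, k) / np.sqrt(D)
    q_pos = np.arange(cache_len - n_new, cache_len)[:, None]  # absolute position of each query
    k_pos = np.arange(cache_len)[None, :]
    scores = np.where(k_pos <= q_pos, scores, -np.inf)        # bottom-right aligned causal mask
    scores -= scores.max(axis=-1, keepdims=True)              # numerically stable softmax
    probs = np.exp(scores)
    probs /= probs.sum(axis=-1, keepdims=True)
    return np.einsum("hqk,khd->qhd", probs, v)

rng = np.random.default_rng(0)
k_cache = rng.standard_normal((16, 2, 8))
v_cache = rng.standard_normal((16, 2, 8))
q = rng.standard_normal((3, 2, 8))   # e.g. three speculated tokens scored in one call
out = attend_with_cache(q, k_cache, v_cache, cache_len=10)
```

Scoring the three tokens in one call gives the same result as three single-token decode steps with cache lengths 8, 9 and 10, which is why a multi-token kernel is attractive for speculative decoding.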

tridao commented

Is vLLM planning to implement a version that can support more than 1 token?

tridao commented

Does it make sense to have paged KV cache as a standalone function without all the cache management kernels (in vLLM)? How would one use paged KV cache without a cache manager to copy / update the pages?

> Is vLLM planning to implement a version that can support more than 1 token?

I have no information on that, but I can ask them in the vLLM repo. See vllm-project/vllm#1598 for reference.

> Does it make sense to have paged KV cache as a standalone function without all the cache management kernels (in vLLM)? How would one use paged KV cache without a cache manager to copy / update the pages?

Yes, it depends on the cache manager, since flash-attention and vLLM use different KV cache formats ([B, L, H, D] vs. [n_blocks, H, D//x, block_size, x]).
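
For illustration, the two shapes above can be related with a reshape and a transpose. The sketch below is only a guess at the correspondence, based on the shapes quoted here: it handles one sequence ([L, H, D]), treats `x` as a generic packing factor along the head dimension, and uses made-up helper names. It is not taken from either code base.

```python
import numpy as np

def to_blocked_key_layout(k, block_size, x):
    """Repack keys of one sequence from [L, H, D] into
    [n_blocks, H, D // x, block_size, x], zero-padding the last block."""
    L, H, D = k.shape
    assert D % x == 0, "head dim must be divisible by x"
    n_blocks = -(-L // block_size)  # ceil division
    padded = np.zeros((n_blocks * block_size, H, D), dtype=k.dtype)
    padded[:L] = k
    # [n_blocks, block_size, H, D // x, x] -> [n_blocks, H, D // x, block_size, x]
    blocked = padded.reshape(n_blocks, block_size, H, D // x, x)
    return np.ascontiguousarray(blocked.transpose(0, 2, 3, 1, 4))

def from_blocked_key_layout(blocked, L):
    """Inverse of to_blocked_key_layout; returns [L, H, D]."""
    n_blocks, H, d_over_x, block_size, x = blocked.shape
    k = blocked.transpose(0, 3, 1, 2, 4).reshape(n_blocks * block_size, H, d_over_x * x)
    return k[:L]

k = np.arange(10 * 2 * 8, dtype=np.float32).reshape(10, 2, 8)
blocked = to_blocked_key_layout(k, block_size=4, x=4)
print(blocked.shape)  # (3, 2, 2, 4, 4)
```

Under this mapping, `blocked[b, h, j, t, i]` holds `k[b * block_size + t, h, j * x + i]`, so a kernel written for one layout cannot read the other without a repack.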

But I think it should be fine, as the biggest obstacle on my side is that I cannot find a set of kernels that supports both paged prefill and paged decode. The cache manager is not a big issue for me because it can be implemented in ~100 lines of Python code. As a user, as long as I have the kernels, I would gladly implement a cache manager myself that fits the kernel format.
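
As a rough idea of what such a cache manager might look like, here is a minimal pure-Python sketch. It only does the bookkeeping (which physical block and offset each token lands in); the class and method names are hypothetical, and a real manager would also need to copy tensors, handle preemption, and so on.

```python
class BlockAllocator:
    """Hands out fixed-size physical blocks and tracks a block table per sequence."""

    def __init__(self, num_blocks, block_size):
        self.block_size = block_size
        self.free = list(range(num_blocks - 1, -1, -1))  # pop() yields 0, 1, 2, ...
        self.block_tables = {}   # seq_id -> list of physical block ids, in logical order
        self.seq_lens = {}       # seq_id -> number of tokens stored so far

    def append_tokens(self, seq_id, n_tokens):
        """Reserve room for n_tokens more tokens; return their (block, offset) slots."""
        table = self.block_tables.setdefault(seq_id, [])
        start = self.seq_lens.get(seq_id, 0)
        # Blocks needed in total (ceil division) minus blocks already owned
        needed = -(-(start + n_tokens) // self.block_size) - len(table)
        if needed > len(self.free):
            raise MemoryError("out of KV cache blocks")
        for _ in range(needed):
            table.append(self.free.pop())
        self.seq_lens[seq_id] = start + n_tokens
        return [(table[p // self.block_size], p % self.block_size)
                for p in range(start, start + n_tokens)]

    def free_sequence(self, seq_id):
        """Return all blocks of a finished sequence to the free list."""
        self.free.extend(reversed(self.block_tables.pop(seq_id, [])))
        self.seq_lens.pop(seq_id, None)

alloc = BlockAllocator(num_blocks=4, block_size=4)
slots_a = alloc.append_tokens("a", 5)   # prefill of 5 tokens spans two blocks
slots_b = alloc.append_tokens("b", 4)   # a second sequence gets its own block
next_a = alloc.append_tokens("a", 1)    # one decode step for "a" reuses its last block
```

The per-sequence block table is the piece a paged kernel would consume; everything else is ordinary free-list management.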

tridao commented

I'm not sure I understand what paged prefill means; can you say more?
During prefill, the KV cache is calculated as the output of the (nn.Linear) K_proj and V_proj. This is a contiguous memory blob. This contiguous memory blob can then be used for attention during prefill as usual (e.g. by calling flash_attn).
I assume vLLM would then copy this contiguous memory blob to different blocks in preparation for decoding?
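
The copy step described here could be sketched as follows. This is only an illustration in NumPy: it assumes a page pool laid out as [num_blocks, block_size, H, D] and a list of physical page ids already chosen for the sequence, and the helper name is made up.

```python
import numpy as np

def write_prefill_to_pages(k, k_pages, block_ids):
    """Copy contiguous keys k [L, H, D] into pages k_pages [num_blocks, block_size, H, D].

    block_ids lists the physical pages assigned to this sequence, in logical order.
    The same routine would be applied to the values.
    """
    L = k.shape[0]
    block_size = k_pages.shape[1]
    assert len(block_ids) * block_size >= L, "not enough pages for this sequence"
    for logical, physical in enumerate(block_ids):
        start = logical * block_size
        chunk = k[start:start + block_size]   # the last chunk may be shorter than a page
        if chunk.shape[0] == 0:
            break
        k_pages[physical, :chunk.shape[0]] = chunk

rng = np.random.default_rng(0)
k = rng.standard_normal((10, 2, 4))     # contiguous prefill output, 10 tokens
k_pages = np.zeros((8, 4, 2, 4))        # pool of 8 pages, 4 tokens each
write_prefill_to_pages(k, k_pages, block_ids=[5, 0, 3])
```

After the call, tokens 0-3 live in page 5, tokens 4-7 in page 0, and tokens 8-9 in the first two slots of page 3, so the sequence is no longer contiguous in memory.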

tdene commented

@donglinz did you ever find a solution? I noticed that you closed vllm-project/vllm#1598.

For the prefill, no cache is used. I just replaced xformers with FA, since xformers does not support MQA/GQA, and found that the latency of the attention calculation (softmax(Q @ K^T * softmax_scale) @ V) is reduced by 2x or more. More details can be found in vllm-project/vllm#1880.
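
For readers unfamiliar with MQA/GQA, the formula above with shared KV heads can be written out as a small NumPy reference. This is a sketch for one sequence with no causal mask, assuming query heads are grouped so that consecutive query heads share one KV head; `gqa_attention` is a made-up name, not a library function.

```python
import numpy as np

def gqa_attention(q, k, v, softmax_scale=None):
    """softmax(Q @ K^T * softmax_scale) @ V with grouped-query attention.

    q: [Lq, Hq, D]; k, v: [Lk, Hkv, D], where Hq is a multiple of Hkv.
    Hkv == 1 is MQA; Hkv == Hq is ordinary multi-head attention.
    """
    Lq, Hq, D = q.shape
    Hkv = k.shape[1]
    assert Hq % Hkv == 0, "query heads must be a multiple of KV heads"
    group = Hq // Hkv
    # Expand KV heads so query head h reads KV head h // group
    k = np.repeat(k, group, axis=1)
    v = np.repeat(v, group, axis=1)
    scale = softmax_scale if softmax_scale is not None else D ** -0.5
    scores = np.einsum("qhd,khd->hqk", q, k) * scale
    scores -= scores.max(axis=-1, keepdims=True)   # numerically stable softmax
    probs = np.exp(scores)
    probs /= probs.sum(axis=-1, keepdims=True)
    return np.einsum("hqk,khd->qhd", probs, v)

rng = np.random.default_rng(0)
q = rng.standard_normal((3, 4, 8))   # 4 query heads
k = rng.standard_normal((5, 2, 8))   # 2 KV heads, each shared by 2 query heads
v = rng.standard_normal((5, 2, 8))
out = gqa_attention(q, k, v)
```

A fused kernel avoids materializing the repeated K and V; the explicit `np.repeat` here is only to keep the reference easy to read.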

For the decode stage, we should either rewrite the paged attention kernel in vLLM or modify the FlashAttention kernel to support paged KV cache. I have not evaluated the workload yet.

tridao commented

flash-attn now supports paged KV cache as of v2.5.0
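
For anyone wondering what "paged" means for the computation itself, here is a NumPy reference for a single decode step over a paged cache. It is a sketch of the semantics only, under the assumption that pages are stored as [num_blocks, block_size, H, D] and that a per-sequence block table lists page ids in logical order; the function name is hypothetical and this is not how the CUDA kernel is written.

```python
import numpy as np

def paged_decode_reference(q, k_pages, v_pages, block_table, seq_len):
    """Attention of one query token over a paged KV cache, for one sequence.

    q:           [H, D]
    k_pages:     [num_blocks, block_size, H, D]  shared pool of pages
    v_pages:     same shape as k_pages
    block_table: physical page ids for this sequence, in logical order
    seq_len:     number of valid cached tokens
    """
    H, D = q.shape
    # Gather the scattered pages into one logical [seq_len, H, D] view
    k = k_pages[block_table].reshape(-1, H, D)[:seq_len]
    v = v_pages[block_table].reshape(-1, H, D)[:seq_len]
    scores = np.einsum("hd,khd->hk", q, k) / np.sqrt(D)
    scores -= scores.max(axis=-1, keepdims=True)   # numerically stable softmax
    probs = np.exp(scores)
    probs /= probs.sum(axis=-1, keepdims=True)
    return np.einsum("hk,khd->hd", probs, v)

rng = np.random.default_rng(0)
padded_k = np.zeros((12, 2, 4)); padded_k[:10] = rng.standard_normal((10, 2, 4))
padded_v = np.zeros((12, 2, 4)); padded_v[:10] = rng.standard_normal((10, 2, 4))
q = rng.standard_normal((2, 4))

# Scatter the 10-token sequence over pages 5, 0 and 3 of an 8-page pool
block_table = np.array([5, 0, 3])
k_pages = np.zeros((8, 4, 2, 4)); k_pages[block_table] = padded_k.reshape(3, 4, 2, 4)
v_pages = np.zeros((8, 4, 2, 4)); v_pages[block_table] = padded_v.reshape(3, 4, 2, 4)

out_paged = paged_decode_reference(q, k_pages, v_pages, block_table, seq_len=10)
# Same computation with the whole sequence stored in a single 12-token page
out_contig = paged_decode_reference(q, padded_k[None], padded_v[None], np.array([0]), seq_len=10)
```

The two outputs agree, which is the property a paged kernel has to preserve: the block table changes where the data lives, not the result.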

> flash-attn now supports paged KV cache as of v2.5.0

@tridao I still wonder:

> How would one use paged KV cache without a cache manager to copy / update the pages?

You'd need to implement your own cache manager

For those who are interested, here's a simple cache manager: https://github.com/tspeterkim/paged-attention-minimal/