Dao-AILab/flash-attention

Inappropriate Number of Splits Predicted by the determine_num_splits Function in flash_attn_with_kvcache (Non-paged KV Cache)

izhuhaoran opened this issue · 4 comments

I have been working with vLLM and the flash-attention library and have encountered an issue where the determine_num_splits function seems to predict an inappropriate number of splits when using a non-paged KV cache with the flash_attn_with_kvcache function.

The num_splits value calculated by the determine_num_splits function depends on the following parameters: batch_size: int, num_heads: int, head_size: int, seqlen_q: int, seqlen_k: int, num_SMs: int. During the decoding stage, seqlen_q is 1, which makes the value of seqlen_k quite crucial. In my tests, where the maximum length per sequence is 131072, I preallocate a non-paged cache of length 131072 for each sequence and then pass the actual cache length of each sequence via the cache_seqlens parameter of flash_attn_with_kvcache. However, determine_num_splits receives seqlen_k = kcache.size(1), which in non-paged mode can equal 131072. This value ignores the current sequence lengths, leading to an inappropriate num_splits and, consequently, a performance loss.
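For reference, here is a rough Python sketch of the kind of occupancy-based heuristic involved. It is my paraphrase of the split-KV heuristic in flash-attention (num_splits_heuristic in csrc/flash_attn/flash_api.cpp), not the exact code; the tile sizes block_m/block_n are assumptions.

import math

def determine_num_splits(batch_size, num_heads, head_size, seqlen_q, seqlen_k,
                         num_SMs, max_splits=128):
    """Sketch of an occupancy-based split-KV heuristic (tile sizes are assumptions)."""
    # head_size would normally choose the key tile size; fixed here for simplicity.
    block_m, block_n = 64, 128
    num_m_blocks = math.ceil(seqlen_q / block_m)
    num_n_blocks = math.ceil(seqlen_k / block_n)    # driven by kcache.size(1) in non-paged mode
    batch_nheads_mblocks = batch_size * num_heads * num_m_blocks

    # Enough work to nearly fill the SMs already -> no split.
    if batch_nheads_mblocks >= 0.8 * num_SMs:
        return 1

    max_splits = min(max_splits, num_SMs, num_n_blocks)
    efficiency, best_eff = [], 0.0
    for ns in range(1, max_splits + 1):
        # A split count is only worth considering if it changes the work per split.
        eligible = ns == 1 or math.ceil(num_n_blocks / ns) != math.ceil(num_n_blocks / (ns - 1))
        if not eligible:
            efficiency.append(0.0)
            continue
        n_waves = batch_nheads_mblocks * ns / num_SMs
        eff = n_waves / math.ceil(n_waves)          # how full the last wave of CTAs is
        best_eff = max(best_eff, eff)
        efficiency.append(eff)

    # Pick the smallest split count within 85% of the best efficiency.
    for ns in range(1, max_splits + 1):
        if efficiency[ns - 1] >= 0.85 * best_eff:
            return ns
    return 1

With seqlen_q = 1 at decode time there is only one query block per (batch, head), and num_n_blocks is derived from seqlen_k = kcache.size(1), so a preallocated cache length of 131072 inflates the key-block count the heuristic reasons about; that is the mismatch described above.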

My test script (using default values for batch_size and the other parameters) is attached below; the results and my reasoning follow:
benchmark_flash_attention.txt

  • When num_splits defaults to 0, the paged kernel's execution time is 72.252 us, while the non-paged kernel's is 129.314 us, which indicates a significant performance gap.

  • When adjusting num_splits for non-paged mode with the following code (and removing the num_splits < 128 torch check):

num_splits = determine_num_splits(num_seqs, num_kv_heads, head_size, 1, seq_len, num_SMs)
num_splits = cache_max_seq_len // (seq_len // num_splits)

# The above code considers that after calculating `num_splits` based on `seq_len`, 
# it proportionally determines how many chunks the entire `cache_max_seq_len` should be divided into.

The non-paged kernel's performance improves to 78.535 us, much closer to the paged kernel's. (Possibly due to kernel limitations, the calculated num_splits is still inexact, which reduces how much weight this particular number should carry.)
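To make the arithmetic of that adjustment concrete (the numbers below are hypothetical, not taken from the benchmark output):

seq_len, cache_max_seq_len = 4096, 131072       # hypothetical actual length vs. allocated length
num_splits = 8                                  # assume the heuristic picked 8 splits for seq_len
chunk = seq_len // num_splits                   # 512 keys of real data per split
num_splits = cache_max_seq_len // chunk         # 131072 // 512 = 256 splits over the whole cache
# 256 exceeds 128, which is why the num_splits < 128 check had to be removed.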

I am aware that my approach above to determining num_splits is merely a simple attempt and not necessarily representative. Still, I think the num_splits calculation in its current state may not be suitable for non-paged decoding scenarios where the actual sequence lengths are significantly smaller than the allocated cache size.

Thank you for your attention to this issue. I look forward to your feedback on how this could be addressed to improve the project's performance.

It's just a heuristic to determine num_splits. In this case it doesn't work super well.
We can't use the info in cache_seqlens since that's on GPU and doing a sync from GPU -> CPU (to determine num_splits) would be very expensive.
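A minimal sketch of the sync that would be required, assuming cache_seqlens is an int32 tensor on the GPU as flash_attn_with_kvcache expects:

import torch

cache_seqlens = torch.randint(1, 131072, (8,), dtype=torch.int32, device="cuda")
# Using the real lengths to pick num_splits would mean copying them to the host,
# which blocks the CPU until the GPU has produced the value:
max_len = int(cache_seqlens.max().item())       # .item() forces a GPU -> CPU copy and sync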

Got it, thank you for the quick feedback.

Are you suggesting a different heuristic based on cache_max_seq_len?
When would that be better / worse than the current heuristic?

Based on my current testing, when cache_max_seq_len is significantly larger than seq_len, the non-paged KV-cache attention kernel with the current heuristic shows a considerable performance gap compared to the paged KV-cache attention kernel.

However, when testing num_splits values ranging from 1 to 128, there exist good configurations that yield performance comparable to the paged attention kernel.

This may suggest that a different heuristic, perhaps leveraging information such as cache_max_seq_len, is needed under these conditions. Unfortunately, I don't currently know how to design such a heuristic.