Dao-AILab/flash-attention

Numerical difference between flash_attn_varlen_kvpacked_func and vanilla x-attention implementation

rafaelvalle opened this issue · 7 comments

Thank you for your work on flash-attention.
I noticed numerical differences between flash_attn_varlen_kvpacked_func and the vanilla cross-attention implementation below.
In autoregressive normalizing flows, this difference is large enough to produce a high invertibility error: the invertibility tests show an average error on the order of 1e-10 for the vanilla implementation but approximately 1 for flash-attention-2.

Below is minimal code that runs both flash_attn_varlen_kvpacked_func and the vanilla setup, both with scale=1, p_dropout=0.0, and deterministic=True.

Do you have an explanation for the source of the differences?

import torch
from torch import nn
from torch.cuda import amp
import torch.nn.functional as F
from flash_attn import flash_attn_varlen_kvpacked_func


def attn_flash(query: torch.Tensor, memory: torch.Tensor) -> torch.Tensor:
    """
    query <float tensor> (B, Tq, d_in)
    memory <float tensor> (B, Tkv, d_in)
    y <float tensor> (B, Tq, d_out)
    """
    Bq, Tq, _ = query.shape
    Bkv, Tkv, _ = memory.shape
    q = q_net(query).reshape(Bq, Tq, n_heads, d_head)
    kv = kv_net(memory).reshape(Bkv, Tkv, 2, n_heads, d_head)
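    # flatten (batch, time) into one packed token dimension: the layout the varlen API expects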
    q = q.reshape(-1, n_heads, d_head)
    kv = kv.reshape(-1, 2, n_heads, d_head)
    # one length per sequence in the batch; cu_seqlens_* are the cumulative
    # offsets of each sequence in the packed token dimension (int32)
    lengths_q = torch.full((Bq,), Tq, dtype=torch.long, device=query.device)
    lengths_k = torch.full((Bkv,), Tkv, dtype=torch.long, device=memory.device)
    cu_seqlens_q = F.pad(lengths_q.cumsum(0), (1, 0), value=0).to(torch.int32)
    cu_seqlens_k = F.pad(lengths_k.cumsum(0), (1, 0), value=0).to(torch.int32)
    max_seqlen_q = int(lengths_q.max())
    max_seqlen_k = int(lengths_k.max())
    y = flash_attn_varlen_kvpacked_func(
        q.bfloat16(), kv.bfloat16(), cu_seqlens_q, cu_seqlens_k, max_seqlen_q,
        max_seqlen_k, dropout_p=0.0, causal=False, alibi_slopes=None,
        softmax_scale=1.0, deterministic=True)
    y = y.reshape(Bq, Tq, -1)
    return y


def attn_naive(query: torch.Tensor, memory: torch.Tensor) -> torch.Tensor:
    """
    query <float tensor> (B, Tq, d_in)
    memory <float tensor> (B, Tkv, d_in)
    y <float tensor> (B, Tq, d_out)
    """
    B, T, _ = query.shape
    Tkv = memory.shape[1]
    q = q_net(query)
    k, v = kv_net(memory).chunk(2, dim=2)

    # (B, T, nh * dh) -> (B, nh, T, dh)
    q = q.view(B, T, n_heads, d_head).transpose(1, 2)
    k = k.view(B, Tkv, n_heads, d_head).transpose(1, 2)
    v = v.view(B, Tkv, n_heads, d_head).transpose(1, 2)

    # scores without 1/sqrt(d) scaling, matching softmax_scale=1.0 in the flash call
    attn_score = torch.matmul(q, k.transpose(2, 3))
    attn_prob = F.softmax(attn_score, dim=-1)
    y = torch.matmul(attn_prob, v)
    y = y.transpose(1, 2).contiguous().view(B, T, -1)
    return y


torch.manual_seed(1234)
d_in = 4
d_head = 2
n_heads = 8
q_net = nn.Linear(d_in, d_head * n_heads).cuda()
kv_net = nn.Linear(d_in, 2 * d_head * n_heads).cuda()

B, Tq, Tkv = 1, 2, 3
query = torch.randn((B, Tq, d_in), device='cuda')
memory = torch.randn((B, Tkv, d_in), device='cuda')

y_flash = []
with amp.autocast(enabled=True, dtype=torch.bfloat16):
    for t in range(Tq):
        y_flash.append(attn_flash(query[:, :t+1], memory)[:, -1:])
y_flash = torch.cat(y_flash)

y_naive = []
with amp.autocast(enabled=True, dtype=torch.float32):
    for t in range(Tq):
        y_naive.append(attn_naive(query[:, :t+1], memory)[:, -1:])
y_naive = torch.cat(y_naive)

print(y_flash)
print(y_naive)

If your application is very sensitive to numerical error then flash-attn might not be a good fit, mainly because we only support fp16 / bf16 and not fp32.

In the minimal example I’m running both flash attention and vanilla attention in bfloat16. Do you know what difference in the implementations explains the numerical discrepancy?

Then flash-attn should be more accurate than the standard implementation. You want to compare
(flash-attn in bf16 - reference impl in fp32) vs (reference impl in bf16 - reference impl in fp32).
The first error should be smaller than the second error, or at least comparable.
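
A minimal sketch of that comparison, assuming y_flash_bf16, y_naive_bf16, and y_ref_fp32 are the outputs of the two implementations on identical inputs, with the reference run in fp32 (these tensor names are placeholders, not from the script above):

import torch

def error_norms(y: torch.Tensor, y_ref_fp32: torch.Tensor):
    # error of a low-precision output against the fp32 reference,
    # reported as L2 norm and MSE
    diff = y.float() - y_ref_fp32.float()
    return diff.norm().item(), diff.pow(2).mean().item()

# error_norms(y_flash_bf16, y_ref_fp32)  # flash-attn bf16 vs fp32 reference
# error_norms(y_naive_bf16, y_ref_fp32)  # naive bf16 vs fp32 reference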

Interesting that you mention that: with an autoregressive flow with 2 flow steps, I get an error one order of magnitude smaller using naive attention in bfloat16 than with flash-attention-2. The error with fp32 is, as expected, around 7e-10, as you can see below.

In the code I shared above, do you see any misuse of flash-attention-2 in the cross-attention setup?

Flash-Attention-2
L2 norm of error 87.46025085449219
MSE of error 0.18030864000320435

Naive FP32
L2 norm of error 7.284597813850269e-07
MSE of error 7.799476064995758e-10

Naive BFLOAT16
L2 norm of error 9.892854690551758
MSE of error 0.019161909818649292

Naive FP16
NaN

And thank you for Mamba :-)

The error seems too high. You can try flash_attn_func since it's simpler to call (no need to construct cu_seqlens, which can be error-prone).

Try to make the test as simple as possible.
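
For illustration, a minimal sketch of the same cross-attention through flash_attn_func (standard batched layout, no cu_seqlens); it reuses q_net, kv_net, n_heads, and d_head from the script above, and attn_flash_simple is a hypothetical name:

from flash_attn import flash_attn_func

def attn_flash_simple(query, memory):
    # query: (B, Tq, d_in), memory: (B, Tkv, d_in)
    B, Tq, _ = query.shape
    Tkv = memory.shape[1]
    q = q_net(query).reshape(B, Tq, n_heads, d_head)
    k, v = kv_net(memory).reshape(B, Tkv, 2, n_heads, d_head).unbind(dim=2)
    # flash_attn_func expects (B, T, n_heads, d_head) tensors in fp16 / bf16
    y = flash_attn_func(q.bfloat16(), k.bfloat16(), v.bfloat16(),
                        dropout_p=0.0, softmax_scale=1.0, causal=False)
    return y.reshape(B, Tq, -1)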

ad8e commented

The test he wrote is actually correct, and it shows that varlen is working correctly; there is no difference in accuracy between naive and flash. Tested on an A40 with FA 2.4.2.

Flash std_mean of error: (tensor(0.0020, device='cuda:0'), tensor(-0.0008, device='cuda:0'))
Naive std_mean of error: (tensor(0.0021, device='cuda:0'), tensor(-0.0005, device='cuda:0'))

Full code: flashtest.txt

Changes I made from the original test case: removed the explicit .bfloat16() cast so that I could test fp16 as well (which also worked), and added some diagnostics at the bottom. Neither change affected anything.
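
The contents of flashtest.txt are not reproduced in the thread; a hedged guess at the kind of diagnostic that would print the std_mean numbers above, comparing each output against a common reference (y_flash, y_naive, and y_ref are placeholder names):

import torch

def std_mean_of_error(y: torch.Tensor, y_ref: torch.Tensor):
    # torch.std_mean returns (std, mean) of the elementwise error
    return torch.std_mean(y.float() - y_ref.float())

# print("Flash std_mean of error:", std_mean_of_error(y_flash, y_ref))
# print("Naive std_mean of error:", std_mean_of_error(y_naive, y_ref))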

@rafaelvalle unless you are receiving different results, this issue can be closed.