triton-lang/triton

Big results difference when using `tl.store`

yzhangcs opened this issue · 3 comments

Hi, I'm seeing a large difference in results when using `tl.store` (under bfloat16).

# -*- coding: utf-8 -*-

import torch
import triton
import triton.language as tl


@triton.jit
def attention_nostore_fwd_kernel(
    q,
    k,
    v,
    h,
    o,
    s_qh,
    s_qt,
    s_qd,
    s_hh,
    s_ht,
    H,
    T,
    TD,
    scale,
    BT: tl.constexpr,
    BD: tl.constexpr
):
    i_bh = tl.program_id(0)

    # [BD, BD]
    b_h = tl.zeros([BD, BD], dtype=tl.float32)
    for i in range(0, tl.cdiv(T, BT)):
        p_q = tl.make_block_ptr(q + i_bh * s_qh, (T, BD), (s_qt, s_qd), (i * BT, 0), (BT, BD), (1, 0))
        p_k = tl.make_block_ptr(k + i_bh * s_qh, (BD, T), (s_qd, s_qt), (0, i * BT), (BD, BT), (0, 1))
        p_v = tl.make_block_ptr(v + i_bh * s_qh, (T, BD), (s_qt, s_qd), (i * BT, 0), (BT, BD), (1, 0))
        p_h = tl.make_block_ptr(h + i_bh * s_hh, (TD, BD), (s_ht, s_qd), (i * BD, 0), (BD, BD), (1, 0))
        p_o = tl.make_block_ptr(o + i_bh * s_qh, (T, BD), (s_qt, s_qd), (i * BT, 0), (BT, BD), (1, 0))

        # [BT, BD]
        b_q = tl.load(p_q)
        b_q = (b_q * scale).to(b_q.dtype)
        # [BD, BT]
        b_k = tl.load(p_k)
        # [BT, BD]
        b_v = tl.load(p_v)

        # [BT, BT]
        b_s = tl.dot(b_q, b_k, allow_tf32=False)
        # [BT, BD]
        b_o = tl.dot(b_q, b_h.to(b_q.dtype), allow_tf32=False)
        b_o += tl.dot(b_s.to(b_q.dtype), b_v, allow_tf32=False)

        # tl.store(p_h, b_h.to(p_h.dtype.element_ty))
        tl.store(p_o, b_o.to(p_o.dtype.element_ty))

        # [BD, BD]
        b_h = b_h + tl.dot(b_k, b_v, allow_tf32=False)


@triton.jit
def attention_store_fwd_kernel(
    q,
    k,
    v,
    h,
    o,
    s_qh,
    s_qt,
    s_qd,
    s_hh,
    s_ht,
    H,
    T,
    TD,
    scale,
    BT: tl.constexpr,
    BD: tl.constexpr
):
    i_bh = tl.program_id(0)

    # [BD, BD]
    b_h = tl.zeros([BD, BD], dtype=tl.float32)
    for i in range(0, tl.cdiv(T, BT)):
        p_q = tl.make_block_ptr(q + i_bh * s_qh, (T, BD), (s_qt, s_qd), (i * BT, 0), (BT, BD), (1, 0))
        p_k = tl.make_block_ptr(k + i_bh * s_qh, (BD, T), (s_qd, s_qt), (0, i * BT), (BD, BT), (0, 1))
        p_v = tl.make_block_ptr(v + i_bh * s_qh, (T, BD), (s_qt, s_qd), (i * BT, 0), (BT, BD), (1, 0))
        p_h = tl.make_block_ptr(h + i_bh * s_hh, (TD, BD), (s_ht, s_qd), (i * BD, 0), (BD, BD), (1, 0))
        p_o = tl.make_block_ptr(o + i_bh * s_qh, (T, BD), (s_qt, s_qd), (i * BT, 0), (BT, BD), (1, 0))

        # [BT, BD]
        b_q = tl.load(p_q)
        b_q = (b_q * scale).to(b_q.dtype)
        # [BD, BT]
        b_k = tl.load(p_k)
        # [BT, BD]
        b_v = tl.load(p_v)

        # [BT, BT]
        b_s = tl.dot(b_q, b_k, allow_tf32=False)
        # [BT, BD]
        b_o = tl.dot(b_q, b_h.to(b_q.dtype), allow_tf32=False)
        b_o += tl.dot(b_s.to(b_q.dtype), b_v, allow_tf32=False)

        tl.store(p_h, b_h.to(p_h.dtype.element_ty))
        tl.store(p_o, b_o.to(p_o.dtype.element_ty))

        # [BD, BD]
        b_h = b_h + tl.dot(b_k, b_v, allow_tf32=False)


class AttentionFunction(torch.autograd.Function):

    @staticmethod
    def forward(ctx, q, k, v, store=False):
        batch_size, n_heads, seq_len, d_head = q.shape
        scale = d_head ** -0.5
        BD = q.shape[-1]
        BT = 32
        num_stages = 3 if d_head <= 64 else 2
        num_warps = 4

        h = q.new_empty(batch_size, n_heads, triton.cdiv(seq_len, BT) * BD, BD)
        o = torch.empty_like(q)
        grid = (batch_size * n_heads,)
        kernel = attention_store_fwd_kernel if store else attention_nostore_fwd_kernel
        kernel[grid](
            q, k, v, h, o,
            q.stride(1), q.stride(2), q.stride(3), h.stride(1), h.stride(2),
            n_heads, seq_len, h.shape[2], scale,
            BT=BT, BD=BD,
            num_warps=num_warps,
            num_stages=num_stages
        )
        return o


if __name__ == '__main__':
    B, H, T, D = 2, 8, 1024, 128
    dtype = torch.bfloat16
    torch.manual_seed(42)
    # [batch_size, n_heads, seq_len, d_head]
    q = torch.randn((B, H, T, D), dtype=dtype, device='cuda')
    k = torch.randn((B, H, T, D), dtype=dtype, device='cuda')
    v = torch.randn((B, H, T, D), dtype=dtype, device='cuda')

    print('Testing BFloat16...')
    ref = AttentionFunction.apply(q, k, v, True)
    tri = AttentionFunction.apply(q, k, v, False)
    print(ref[0, 0])
    print(tri[0, 0])
    print('Diff:', (ref - tri).abs().max(), '\n\n')

    print('Testing Float...')
    q, k, v = q.float(), k.float(), v.float()
    ref = AttentionFunction.apply(q, k, v, True)
    tri = AttentionFunction.apply(q, k, v, False)
    print(ref[0, 0])
    print(tri[0, 0])
    print('Diff:', (ref - tri).abs().max(), '\n\n')

I have pasted the trimmed-down code here for ease of reproduction.
The only difference between attention_nostore_fwd_kernel and attention_store_fwd_kernel is the line tl.store(p_h, b_h.to(p_h.dtype.element_ty)), which writes the intermediate state to HBM. The output is:

Testing BFloat16...
tensor([[ -3.0000,   4.3125,  -5.2812,  ...,  -3.1094,   4.4062,   1.4141],
        [ -4.1875,   8.4375,   7.3750,  ...,  -4.2188,   0.7227,   4.2188],
        [ -8.3750,   9.2500,  -3.0938,  ..., -10.4375,   3.5312,  -1.4688],
        ...,
        [ 33.0000, -45.5000, -27.5000,  ...,  47.5000,  11.7500, -42.5000],
        [-12.8125, -13.3125, -29.5000,  ..., -28.2500, -17.8750,  -8.5625],
        [ -8.0625,  19.0000, -25.1250,  ...,  44.0000,  31.8750,   0.7148]],
       device='cuda:0', dtype=torch.bfloat16)
tensor([[ -3.0000,   4.3125,  -5.2812,  ...,  -3.1094,   4.4062,   1.4141],
        [ -4.1875,   8.4375,   7.3750,  ...,  -4.2188,   0.7227,   4.2188],
        [ -8.3750,   9.2500,  -3.0938,  ..., -10.4375,   3.5312,  -1.4688],
        ...,
        [ 21.0000, -57.2500,  94.0000,  ...,  -6.8125, -43.5000,  -3.3281],
        [ 91.0000,  29.6250,   0.9414,  ...,  15.3750,  -4.5000,  13.4375],
        [  8.0625, -24.6250,  21.8750,  ...,   1.3672, -21.3750,  96.0000]],
       device='cuda:0', dtype=torch.bfloat16)
Diff: tensor(223., device='cuda:0', dtype=torch.bfloat16) 


Testing Float...
tensor([[ -2.9934,   4.3165,  -5.2756,  ...,  -3.1058,   4.4168,   1.4156],
        [ -4.1945,   8.4496,   7.3791,  ...,  -4.2401,   0.7162,   4.2193],
        [ -8.4092,   9.2836,  -3.0875,  ..., -10.4867,   3.5389,  -1.4640],
        ...,
        [ 32.8595, -45.5166, -27.4276,  ...,  47.6876,  11.7916, -42.2014],
        [-12.8712, -13.3366, -29.3908,  ..., -28.2100, -17.8268,  -8.5858],
        [ -8.0902,  19.0255, -25.1759,  ...,  44.1477,  31.9002,   0.8230]],
       device='cuda:0')
tensor([[ -2.9934,   4.3165,  -5.2756,  ...,  -3.1058,   4.4168,   1.4156],
        [ -4.1945,   8.4496,   7.3791,  ...,  -4.2401,   0.7162,   4.2193],
        [ -8.4092,   9.2836,  -3.0875,  ..., -10.4867,   3.5389,  -1.4640],
        ...,
        [ 32.8595, -45.5166, -27.4276,  ...,  47.6876,  11.7916, -42.2014],
        [-12.8712, -13.3366, -29.3908,  ..., -28.2100, -17.8268,  -8.5858],
        [ -8.0902,  19.0255, -25.1759,  ...,  44.1477,  31.9002,   0.8230]],
       device='cuda:0')
Diff: tensor(0., device='cuda:0') 

The results are consistent under float32.
With this one-line change, however, there is a large, unacceptable difference in the final outputs under bfloat16.
Also, the bfloat16 results do match if the inputs are restricted to a very small range, e.g., divided by 1024; a quick check of this is shown below.
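
For example, such a check could be inserted into the __main__ block above, right after the first bfloat16 test (before q, k, v are converted to float). This snippet is not part of my original script, just an illustration of the scaling observation; dividing by 1024 is exact in bfloat16 since it is a power of two:

print('Testing BFloat16 with down-scaled inputs...')
ref = AttentionFunction.apply(q / 1024, k / 1024, v / 1024, True)
tri = AttentionFunction.apply(q / 1024, k / 1024, v / 1024, False)
print('Diff:', (ref - tri).abs().max())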

I suspect the culprit is the limited precision of bfloat16 (a quick illustration of its resolution is below), but I can't figure out why tl.store causes such a big difference, or how to fix it.
Could you give me some hints?
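
For reference, a tiny standalone check of bfloat16 resolution (not taken from the repro above, only to illustrate the guess):

import torch

# bfloat16 keeps 8 significand bits, so its machine epsilon is 2**-7 and the
# spacing between representable values around 300 is already 2.0.
print(torch.finfo(torch.bfloat16).eps)  # 0.0078125
x = torch.tensor(300.0, dtype=torch.bfloat16)
print(x + 1.0)  # 301 is not representable in bfloat16, so the sum rounds to a neighbouring value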

The environment is Triton 2.1 & A100-SXM4-40GB.

Thanks.

Mine is normal: NVIDIA A100 80GB PCIe, Triton nightly release.

Testing BFloat16...
tensor([[ -3.0000,   4.3125,  -5.2812,  ...,  -3.1094,   4.4062,   1.4141],
        [ -4.1875,   8.4375,   7.3750,  ...,  -4.2188,   0.7227,   4.2188],
        [ -8.3750,   9.2500,  -3.0938,  ..., -10.4375,   3.5312,  -1.4688],
        ...,
        [ 33.0000, -45.5000, -27.5000,  ...,  47.5000,  11.7500, -42.5000],
        [-12.8125, -13.3125, -29.5000,  ..., -28.2500, -17.8750,  -8.5625],
        [ -8.0625,  19.0000, -25.1250,  ...,  44.0000,  31.8750,   0.7148]],
       device='cuda:0', dtype=torch.bfloat16)
tensor([[ -3.0000,   4.3125,  -5.2812,  ...,  -3.1094,   4.4062,   1.4141],
        [ -4.1875,   8.4375,   7.3750,  ...,  -4.2188,   0.7227,   4.2188],
        [ -8.3750,   9.2500,  -3.0938,  ..., -10.4375,   3.5312,  -1.4688],
        ...,
        [ 33.0000, -45.5000, -27.5000,  ...,  47.5000,  11.7500, -42.5000],
        [-12.8125, -13.3125, -29.5000,  ..., -28.2500, -17.8750,  -8.5625],
        [ -8.0625,  19.0000, -25.1250,  ...,  44.0000,  31.8750,   0.7148]],
       device='cuda:0', dtype=torch.bfloat16)
Diff: tensor(0., device='cuda:0', dtype=torch.bfloat16) 


Testing Float...
tensor([[ -2.9934,   4.3165,  -5.2756,  ...,  -3.1058,   4.4168,   1.4156],
        [ -4.1945,   8.4496,   7.3791,  ...,  -4.2401,   0.7162,   4.2193],
        [ -8.4092,   9.2836,  -3.0875,  ..., -10.4867,   3.5389,  -1.4640],
        ...,
        [ 32.8595, -45.5166, -27.4276,  ...,  47.6876,  11.7916, -42.2014],
        [-12.8712, -13.3366, -29.3908,  ..., -28.2100, -17.8268,  -8.5858],
        [ -8.0902,  19.0255, -25.1759,  ...,  44.1477,  31.9002,   0.8230]],
       device='cuda:0')
tensor([[ -2.9934,   4.3165,  -5.2756,  ...,  -3.1058,   4.4168,   1.4156],
        [ -4.1945,   8.4496,   7.3791,  ...,  -4.2401,   0.7162,   4.2193],
        [ -8.4092,   9.2836,  -3.0875,  ..., -10.4867,   3.5389,  -1.4640],
        ...,
        [ 32.8595, -45.5166, -27.4276,  ...,  47.6876,  11.7916, -42.2014],
        [-12.8712, -13.3366, -29.3908,  ..., -28.2100, -17.8268,  -8.5858],
        [ -8.0902,  19.0255, -25.1759,  ...,  44.1477,  31.9002,   0.8230]],
       device='cuda:0')
Diff: tensor(0., device='cuda:0') 

@Jokeren Hi, we (@sustcsonglin and I) found that with Triton 2.2 on an H100 (A100 works fine), running this check still gives very strange results.

>>> python tests/test_fused_chunk.py
DTYPE           STORE   IFCOND  DIFF
torch.float32   False   False   0.0
torch.float32   False   True    0.0
torch.float32   True    False   0.0
torch.float32   True    True    0.0
torch.bfloat16  False   False   218.81393432617188
torch.bfloat16  False   True    0.6739959716796875
torch.bfloat16  True    False   218.81393432617188
torch.bfloat16  True    True    0.6739959716796875

Can you figure out what is happening?

The bug can be bypassed by adding a conditional check at the beginning of the for loop body, as in the IFCOND=True rows above; a sketch is below.
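
A sketch of what that workaround looks like on top of attention_store_fwd_kernel, assuming the same imports as the repro above (the exact condition used in tests/test_fused_chunk.py is not shown in this thread; the guard below is trivially true and is only there to perturb the generated code):

@triton.jit
def attention_store_ifcond_fwd_kernel(
    q, k, v, h, o,
    s_qh, s_qt, s_qd, s_hh, s_ht,
    H, T, TD, scale,
    BT: tl.constexpr,
    BD: tl.constexpr
):
    i_bh = tl.program_id(0)

    # [BD, BD]
    b_h = tl.zeros([BD, BD], dtype=tl.float32)
    NT = tl.cdiv(T, BT)
    for i in range(0, NT):
        # hypothetical conditional check at the beginning of the loop body
        if i < NT:
            p_q = tl.make_block_ptr(q + i_bh * s_qh, (T, BD), (s_qt, s_qd), (i * BT, 0), (BT, BD), (1, 0))
            p_k = tl.make_block_ptr(k + i_bh * s_qh, (BD, T), (s_qd, s_qt), (0, i * BT), (BD, BT), (0, 1))
            p_v = tl.make_block_ptr(v + i_bh * s_qh, (T, BD), (s_qt, s_qd), (i * BT, 0), (BT, BD), (1, 0))
            p_h = tl.make_block_ptr(h + i_bh * s_hh, (TD, BD), (s_ht, s_qd), (i * BD, 0), (BD, BD), (1, 0))
            p_o = tl.make_block_ptr(o + i_bh * s_qh, (T, BD), (s_qt, s_qd), (i * BT, 0), (BT, BD), (1, 0))

            b_q = tl.load(p_q)
            b_q = (b_q * scale).to(b_q.dtype)
            b_k = tl.load(p_k)
            b_v = tl.load(p_v)

            b_s = tl.dot(b_q, b_k, allow_tf32=False)
            b_o = tl.dot(b_q, b_h.to(b_q.dtype), allow_tf32=False)
            b_o += tl.dot(b_s.to(b_q.dtype), b_v, allow_tf32=False)

            tl.store(p_h, b_h.to(p_h.dtype.element_ty))
            tl.store(p_o, b_o.to(p_o.dtype.element_ty))

            # [BD, BD]
            b_h = b_h + tl.dot(b_k, b_v, allow_tf32=False)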

Can you try with triton/main?