Big results difference when using `tl.store`
yzhangcs opened this issue · 3 comments
Hi, I found a big difference in results when using `tl.store` (under bfloat16).
```python
# -*- coding: utf-8 -*-

import torch
import triton
import triton.language as tl


@triton.jit
def attention_nostore_fwd_kernel(
    q,
    k,
    v,
    h,
    o,
    s_qh,
    s_qt,
    s_qd,
    s_hh,
    s_ht,
    H,
    T,
    TD,
    scale,
    BT: tl.constexpr,
    BD: tl.constexpr
):
    i_bh = tl.program_id(0)
    # [BD, BD] running state, accumulated in float32
    b_h = tl.zeros([BD, BD], dtype=tl.float32)
    for i in range(0, tl.cdiv(T, BT)):
        p_q = tl.make_block_ptr(q + i_bh * s_qh, (T, BD), (s_qt, s_qd), (i * BT, 0), (BT, BD), (1, 0))
        p_k = tl.make_block_ptr(k + i_bh * s_qh, (BD, T), (s_qd, s_qt), (0, i * BT), (BD, BT), (0, 1))
        p_v = tl.make_block_ptr(v + i_bh * s_qh, (T, BD), (s_qt, s_qd), (i * BT, 0), (BT, BD), (1, 0))
        p_h = tl.make_block_ptr(h + i_bh * s_hh, (TD, BD), (s_ht, s_qd), (i * BD, 0), (BD, BD), (1, 0))
        p_o = tl.make_block_ptr(o + i_bh * s_qh, (T, BD), (s_qt, s_qd), (i * BT, 0), (BT, BD), (1, 0))
        # [BT, BD]
        b_q = tl.load(p_q)
        b_q = (b_q * scale).to(b_q.dtype)
        # [BD, BT]
        b_k = tl.load(p_k)
        # [BT, BD]
        b_v = tl.load(p_v)
        # [BT, BT]
        b_s = tl.dot(b_q, b_k, allow_tf32=False)
        # [BT, BD]
        b_o = tl.dot(b_q, b_h.to(b_q.dtype), allow_tf32=False)
        b_o += tl.dot(b_s.to(b_q.dtype), b_v, allow_tf32=False)
        # tl.store(p_h, b_h.to(p_h.dtype.element_ty))
        tl.store(p_o, b_o.to(p_o.dtype.element_ty))
        # [BD, BD]
        b_h = b_h + tl.dot(b_k, b_v, allow_tf32=False)


@triton.jit
def attention_store_fwd_kernel(
    q,
    k,
    v,
    h,
    o,
    s_qh,
    s_qt,
    s_qd,
    s_hh,
    s_ht,
    H,
    T,
    TD,
    scale,
    BT: tl.constexpr,
    BD: tl.constexpr
):
    # Identical to attention_nostore_fwd_kernel except that the intermediate
    # state b_h is additionally written back to HBM via tl.store(p_h, ...).
    i_bh = tl.program_id(0)
    # [BD, BD]
    b_h = tl.zeros([BD, BD], dtype=tl.float32)
    for i in range(0, tl.cdiv(T, BT)):
        p_q = tl.make_block_ptr(q + i_bh * s_qh, (T, BD), (s_qt, s_qd), (i * BT, 0), (BT, BD), (1, 0))
        p_k = tl.make_block_ptr(k + i_bh * s_qh, (BD, T), (s_qd, s_qt), (0, i * BT), (BD, BT), (0, 1))
        p_v = tl.make_block_ptr(v + i_bh * s_qh, (T, BD), (s_qt, s_qd), (i * BT, 0), (BT, BD), (1, 0))
        p_h = tl.make_block_ptr(h + i_bh * s_hh, (TD, BD), (s_ht, s_qd), (i * BD, 0), (BD, BD), (1, 0))
        p_o = tl.make_block_ptr(o + i_bh * s_qh, (T, BD), (s_qt, s_qd), (i * BT, 0), (BT, BD), (1, 0))
        # [BT, BD]
        b_q = tl.load(p_q)
        b_q = (b_q * scale).to(b_q.dtype)
        # [BD, BT]
        b_k = tl.load(p_k)
        # [BT, BD]
        b_v = tl.load(p_v)
        # [BT, BT]
        b_s = tl.dot(b_q, b_k, allow_tf32=False)
        # [BT, BD]
        b_o = tl.dot(b_q, b_h.to(b_q.dtype), allow_tf32=False)
        b_o += tl.dot(b_s.to(b_q.dtype), b_v, allow_tf32=False)
        tl.store(p_h, b_h.to(p_h.dtype.element_ty))
        tl.store(p_o, b_o.to(p_o.dtype.element_ty))
        # [BD, BD]
        b_h = b_h + tl.dot(b_k, b_v, allow_tf32=False)


class AttentionFunction(torch.autograd.Function):

    @staticmethod
    def forward(ctx, q, k, v, store=False):
        batch_size, n_heads, seq_len, d_head = q.shape
        scale = d_head ** -0.5
        BD = q.shape[-1]
        BT = 32
        num_stages = 3 if d_head <= 64 else 2
        num_warps = 4

        h = q.new_empty(batch_size, n_heads, triton.cdiv(seq_len, BT) * BD, BD)
        o = torch.empty_like(q)
        grid = (batch_size * n_heads,)
        kernel = attention_store_fwd_kernel if store else attention_nostore_fwd_kernel
        kernel[grid](
            q, k, v, h, o,
            q.stride(1), q.stride(2), q.stride(3), h.stride(1), h.stride(2),
            n_heads, seq_len, h.shape[2], scale,
            BT=BT, BD=BD,
            num_warps=num_warps,
            num_stages=num_stages
        )
        return o


if __name__ == '__main__':
    B, H, T, D = 2, 8, 1024, 128
    dtype = torch.bfloat16
    torch.manual_seed(42)
    # [batch_size, n_heads, seq_len, d_head]
    q = torch.randn((B, H, T, D), dtype=dtype, device='cuda')
    k = torch.randn((B, H, T, D), dtype=dtype, device='cuda')
    v = torch.randn((B, H, T, D), dtype=dtype, device='cuda')

    print('Testing BFloat16...')
    ref = AttentionFunction.apply(q, k, v, True)
    tri = AttentionFunction.apply(q, k, v, False)
    print(ref[0, 0])
    print(tri[0, 0])
    print('Diff:', (ref - tri).abs().max(), '\n\n')

    print('Testing Float...')
    q, k, v = q.float(), k.float(), v.float()
    ref = AttentionFunction.apply(q, k, v, True)
    tri = AttentionFunction.apply(q, k, v, False)
    print(ref[0, 0])
    print(tri[0, 0])
    print('Diff:', (ref - tri).abs().max(), '\n\n')
```
I have pasted the tailored code above for ease of reproduction. The only difference between `attention_nostore_fwd_kernel` and `attention_store_fwd_kernel` is the line `tl.store(p_h, b_h.to(p_h.dtype.element_ty))`, which saves the intermediate results to HBM. The output is:
```
Testing BFloat16...
tensor([[ -3.0000, 4.3125, -5.2812, ..., -3.1094, 4.4062, 1.4141],
        [ -4.1875, 8.4375, 7.3750, ..., -4.2188, 0.7227, 4.2188],
        [ -8.3750, 9.2500, -3.0938, ..., -10.4375, 3.5312, -1.4688],
        ...,
        [ 33.0000, -45.5000, -27.5000, ..., 47.5000, 11.7500, -42.5000],
        [-12.8125, -13.3125, -29.5000, ..., -28.2500, -17.8750, -8.5625],
        [ -8.0625, 19.0000, -25.1250, ..., 44.0000, 31.8750, 0.7148]],
       device='cuda:0', dtype=torch.bfloat16)
tensor([[ -3.0000, 4.3125, -5.2812, ..., -3.1094, 4.4062, 1.4141],
        [ -4.1875, 8.4375, 7.3750, ..., -4.2188, 0.7227, 4.2188],
        [ -8.3750, 9.2500, -3.0938, ..., -10.4375, 3.5312, -1.4688],
        ...,
        [ 21.0000, -57.2500, 94.0000, ..., -6.8125, -43.5000, -3.3281],
        [ 91.0000, 29.6250, 0.9414, ..., 15.3750, -4.5000, 13.4375],
        [ 8.0625, -24.6250, 21.8750, ..., 1.3672, -21.3750, 96.0000]],
       device='cuda:0', dtype=torch.bfloat16)
Diff: tensor(223., device='cuda:0', dtype=torch.bfloat16)

Testing Float...
tensor([[ -2.9934, 4.3165, -5.2756, ..., -3.1058, 4.4168, 1.4156],
        [ -4.1945, 8.4496, 7.3791, ..., -4.2401, 0.7162, 4.2193],
        [ -8.4092, 9.2836, -3.0875, ..., -10.4867, 3.5389, -1.4640],
        ...,
        [ 32.8595, -45.5166, -27.4276, ..., 47.6876, 11.7916, -42.2014],
        [-12.8712, -13.3366, -29.3908, ..., -28.2100, -17.8268, -8.5858],
        [ -8.0902, 19.0255, -25.1759, ..., 44.1477, 31.9002, 0.8230]],
       device='cuda:0')
tensor([[ -2.9934, 4.3165, -5.2756, ..., -3.1058, 4.4168, 1.4156],
        [ -4.1945, 8.4496, 7.3791, ..., -4.2401, 0.7162, 4.2193],
        [ -8.4092, 9.2836, -3.0875, ..., -10.4867, 3.5389, -1.4640],
        ...,
        [ 32.8595, -45.5166, -27.4276, ..., 47.6876, 11.7916, -42.2014],
        [-12.8712, -13.3366, -29.3908, ..., -28.2100, -17.8268, -8.5858],
        [ -8.0902, 19.0255, -25.1759, ..., 44.1477, 31.9002, 0.8230]],
       device='cuda:0')
Diff: tensor(0., device='cuda:0')
```
The results are consistent under float32. With this minor code change, however, there is an unacceptably large difference in the final outputs under bfloat16. Also, the bfloat16 results do match if the inputs are restricted to a very small range, e.g., divided by 1024.
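A minimal sketch of that check (assuming the script above has already been run, so that `AttentionFunction` and `B, H, T, D` are in scope; the factor 1024 is just the one mentioned here, not a tuned value):

```python
# Hypothetical small-range check, reusing the definitions from the script above.
# Dividing by 1024 (a power of two) only shifts the exponent, so the inputs
# stay exactly representable in bfloat16 while all partial sums become tiny.
q_s = torch.randn((B, H, T, D), dtype=torch.bfloat16, device='cuda') / 1024
k_s = torch.randn((B, H, T, D), dtype=torch.bfloat16, device='cuda') / 1024
v_s = torch.randn((B, H, T, D), dtype=torch.bfloat16, device='cuda') / 1024
ref_s = AttentionFunction.apply(q_s, k_s, v_s, True)
tri_s = AttentionFunction.apply(q_s, k_s, v_s, False)
print('Diff (inputs / 1024):', (ref_s - tri_s).abs().max())
```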
I guess the evil stems from the precision of bfloat16, but I can't figure out why `tl.store` introduces such a big difference, or how to fix it.
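For context on the precision point (this illustration is not part of the original report): bfloat16 keeps only 8 significand bits, so at magnitudes in the hundreds, as seen in the outputs above, the gap between adjacent representable values is already 1.0, and any change in how the partial products are rounded or summed can shift the result visibly.

```python
import torch

# bfloat16 has an 8-bit significand, so its machine epsilon is 2**-7.
print(torch.finfo(torch.bfloat16).eps)   # 0.0078125

x = torch.tensor(200.0, dtype=torch.bfloat16)
# In the range [128, 256) adjacent bfloat16 values are 1.0 apart,
# so increments smaller than 0.5 are rounded away entirely.
print(x + 0.4)                           # tensor(200., dtype=torch.bfloat16)
print(x + 1.0)                           # tensor(201., dtype=torch.bfloat16)
```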
Could you give me some hints?
The environment is Triton 2.1 & A100-SXM4-40GB.
Thanks.
Mine is normal: NVIDIA A100 80GB PCIe, Triton nightly release.
```
Testing BFloat16...
tensor([[ -3.0000, 4.3125, -5.2812, ..., -3.1094, 4.4062, 1.4141],
        [ -4.1875, 8.4375, 7.3750, ..., -4.2188, 0.7227, 4.2188],
        [ -8.3750, 9.2500, -3.0938, ..., -10.4375, 3.5312, -1.4688],
        ...,
        [ 33.0000, -45.5000, -27.5000, ..., 47.5000, 11.7500, -42.5000],
        [-12.8125, -13.3125, -29.5000, ..., -28.2500, -17.8750, -8.5625],
        [ -8.0625, 19.0000, -25.1250, ..., 44.0000, 31.8750, 0.7148]],
       device='cuda:0', dtype=torch.bfloat16)
tensor([[ -3.0000, 4.3125, -5.2812, ..., -3.1094, 4.4062, 1.4141],
        [ -4.1875, 8.4375, 7.3750, ..., -4.2188, 0.7227, 4.2188],
        [ -8.3750, 9.2500, -3.0938, ..., -10.4375, 3.5312, -1.4688],
        ...,
        [ 33.0000, -45.5000, -27.5000, ..., 47.5000, 11.7500, -42.5000],
        [-12.8125, -13.3125, -29.5000, ..., -28.2500, -17.8750, -8.5625],
        [ -8.0625, 19.0000, -25.1250, ..., 44.0000, 31.8750, 0.7148]],
       device='cuda:0', dtype=torch.bfloat16)
Diff: tensor(0., device='cuda:0', dtype=torch.bfloat16)

Testing Float...
tensor([[ -2.9934, 4.3165, -5.2756, ..., -3.1058, 4.4168, 1.4156],
        [ -4.1945, 8.4496, 7.3791, ..., -4.2401, 0.7162, 4.2193],
        [ -8.4092, 9.2836, -3.0875, ..., -10.4867, 3.5389, -1.4640],
        ...,
        [ 32.8595, -45.5166, -27.4276, ..., 47.6876, 11.7916, -42.2014],
        [-12.8712, -13.3366, -29.3908, ..., -28.2100, -17.8268, -8.5858],
        [ -8.0902, 19.0255, -25.1759, ..., 44.1477, 31.9002, 0.8230]],
       device='cuda:0')
tensor([[ -2.9934, 4.3165, -5.2756, ..., -3.1058, 4.4168, 1.4156],
        [ -4.1945, 8.4496, 7.3791, ..., -4.2401, 0.7162, 4.2193],
        [ -8.4092, 9.2836, -3.0875, ..., -10.4867, 3.5389, -1.4640],
        ...,
        [ 32.8595, -45.5166, -27.4276, ..., 47.6876, 11.7916, -42.2014],
        [-12.8712, -13.3366, -29.3908, ..., -28.2100, -17.8268, -8.5858],
        [ -8.0902, 19.0255, -25.1759, ..., 44.1477, 31.9002, 0.8230]],
       device='cuda:0')
Diff: tensor(0., device='cuda:0')
```
@Jokeren Hi, @sustcsonglin and I found that with Triton 2.2 & H100 (A100 works fine), running this check still gives very strange results:
```
>>> python tests/test_fused_chunk.py
DTYPE           STORE  IFCOND  DIFF
torch.float32   False  False   0.0
torch.float32   False  True    0.0
torch.float32   True   False   0.0
torch.float32   True   True    0.0
torch.bfloat16  False  False   218.81393432617188
torch.bfloat16  False  True    0.6739959716796875
torch.bfloat16  True   False   218.81393432617188
torch.bfloat16  True   True    0.6739959716796875
```
Can you figure out what is happening?
The bug can be bypassed by adding a conditional check at the beginning of the for loop, as in the IFCOND=True runs above.
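For concreteness, one hypothetical form of such a guard (the actual check in tests/test_fused_chunk.py may differ) is to wrap the loop body in a condition that is always true; the sketch below is the store kernel from the first post with only that change, assuming the same imports:

```python
@triton.jit
def attention_store_ifcond_fwd_kernel(
    q, k, v, h, o,
    s_qh, s_qt, s_qd, s_hh, s_ht,
    H, T, TD, scale,
    BT: tl.constexpr,
    BD: tl.constexpr
):
    # Hypothetical "IFCOND" variant: identical to attention_store_fwd_kernel,
    # except that the loop body sits under an always-true condition.
    i_bh = tl.program_id(0)
    b_h = tl.zeros([BD, BD], dtype=tl.float32)
    for i in range(0, tl.cdiv(T, BT)):
        if i >= 0:  # always true; only here to alter the generated code
            p_q = tl.make_block_ptr(q + i_bh * s_qh, (T, BD), (s_qt, s_qd), (i * BT, 0), (BT, BD), (1, 0))
            p_k = tl.make_block_ptr(k + i_bh * s_qh, (BD, T), (s_qd, s_qt), (0, i * BT), (BD, BT), (0, 1))
            p_v = tl.make_block_ptr(v + i_bh * s_qh, (T, BD), (s_qt, s_qd), (i * BT, 0), (BT, BD), (1, 0))
            p_h = tl.make_block_ptr(h + i_bh * s_hh, (TD, BD), (s_ht, s_qd), (i * BD, 0), (BD, BD), (1, 0))
            p_o = tl.make_block_ptr(o + i_bh * s_qh, (T, BD), (s_qt, s_qd), (i * BT, 0), (BT, BD), (1, 0))
            b_q = tl.load(p_q)
            b_q = (b_q * scale).to(b_q.dtype)
            b_k = tl.load(p_k)
            b_v = tl.load(p_v)
            b_s = tl.dot(b_q, b_k, allow_tf32=False)
            b_o = tl.dot(b_q, b_h.to(b_q.dtype), allow_tf32=False)
            b_o += tl.dot(b_s.to(b_q.dtype), b_v, allow_tf32=False)
            tl.store(p_h, b_h.to(p_h.dtype.element_ty))
            tl.store(p_o, b_o.to(p_o.dtype.element_ty))
            b_h = b_h + tl.dot(b_k, b_v, allow_tf32=False)
```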
Can you try with triton/main?