HazyResearch/ThunderKittens

[bug report] h100 attn_causal kernel

xiayuqing0622 opened this issue · 3 comments

Using the same random seed, the results of the TK H100 attn_causal kernel vary from run to run. In some cases, the max diff between the TK and PyTorch results can be larger than 2.
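One way to separate run-to-run nondeterminism from a plain accuracy gap against PyTorch is to call the kernel several times on identical inputs and compare the outputs with each other. The sketch below is plain Python so it does not depend on the kernel build. run_to_run_spread and max_abs_diff are hypothetical helpers, and fn stands in for a closure that calls the kernel on fixed inputs and returns the output flattened to a list of floats.

```python
from typing import Callable, List, Sequence


def max_abs_diff(a: Sequence[float], b: Sequence[float]) -> float:
    """Largest element-wise |a - b|; both sequences must have the same length."""
    if len(a) != len(b):
        raise ValueError("outputs have different lengths")
    return max((abs(x - y) for x, y in zip(a, b)), default=0.0)


def run_to_run_spread(fn: Callable[[], Sequence[float]], runs: int = 5) -> float:
    """Call fn `runs` times on the same inputs and return the largest max-abs-diff
    between any later run and the first one. 0.0 means every run reproduced the
    first run exactly (for finite values)."""
    if runs < 2:
        raise ValueError("need at least two runs to compare")
    baseline: List[float] = list(fn())
    worst = 0.0
    for _ in range(runs - 1):
        worst = max(worst, max_abs_diff(baseline, list(fn())))
    return worst


if __name__ == "__main__":
    import itertools

    # A deterministic stand-in: same inputs, same outputs on every call.
    print(run_to_run_spread(lambda: [1.0, 2.0, 3.0]))  # 0.0

    # A racy stand-in whose last element changes on every call.
    counter = itertools.count()
    print(run_to_run_spread(lambda: [1.0, 2.0, float(next(counter))]))  # 4.0
```

A nonzero spread with a fixed seed points at the kernel itself rather than at the inputs or the reference implementation.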

It turns out that a wgmma fence needs to be added after the wgmma async wait. I have opened a PR for your reference: #34.

Thanks for sharing your PR @xiayuqing0622.

There should be a wgmma async_wait after the wgmma is committed. As you know, the wgmma API launches an asynchronous matrix multiply on the H100 tensor cores across the 4 warps of the warpgroup, via the exposed commit_group function. To ensure the multiply has completed, you need to call async_wait on it. This was missing from the original code and has now been added to the relevant .cu file on the main branch.
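For reference, a warpgroup is four contiguous warps (128 threads) whose first warp index is a multiple of 4, and wgmma instructions are issued collectively by all of them. The small sketch below, assuming a 1-D thread block, shows how a thread index maps to its warpgroup, its warp within that group, and its lane. warpgroup_coords is a hypothetical helper written for illustration, not a ThunderKittens function.

```python
from typing import Tuple

WARP_SIZE = 32            # threads per warp on NVIDIA GPUs
WARPS_PER_WARPGROUP = 4   # a warpgroup is 4 contiguous warps = 128 threads


def warpgroup_coords(thread_idx: int) -> Tuple[int, int, int]:
    """Map a 1-D thread index to (warpgroup, warp within warpgroup, lane)."""
    warp = thread_idx // WARP_SIZE
    return warp // WARPS_PER_WARPGROUP, warp % WARPS_PER_WARPGROUP, thread_idx % WARP_SIZE


if __name__ == "__main__":
    for t in (0, 127, 128, 300):
        print(t, warpgroup_coords(t))
```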

Regarding your PR, the wgmma fence/syncthreads() should not be necessary for correctness. The fence is needed on the output register tile before you launch the wgmma async instruction. Furthermore, the syncthreads() will likely slow performance unnecessarily.
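The ordering described above can be written down as a toy checker. This is a hypothetical model for illustration only, not ThunderKittens code and not a complete model of PTX semantics. It tracks one accumulator tile through a trace of operations, requires a fence before an async multiply (and a new fence after the tile has been accessed by ordinary instructions, following the PTX ISA's description of wgmma.fence), and requires committed work to be waited on before the tile is read. The op names are invented labels, and "wait" models waiting for all outstanding groups.

```python
from typing import List


def check_wgmma_trace(trace: List[str]) -> List[str]:
    """Return the rule violations found in a trace of ops on one accumulator tile.

    Toy model (hypothetical labels, not a real API):
      fence     -> fence the tile's registers before async use
      mma_async -> launch an async multiply that accumulates into the tile
      commit    -> close the pending multiplies into a committed group
      wait      -> wait until all committed groups have finished
      read      -> ordinary (non-wgmma) access to the tile's registers
    """
    errors: List[str] = []
    fenced = False    # has a fence been issued since the last ordinary access?
    uncommitted = 0   # multiplies launched but not yet committed
    in_flight = 0     # committed groups not yet waited on
    for i, op in enumerate(trace):
        if op == "fence":
            fenced = True
        elif op == "mma_async":
            if not fenced:
                errors.append(f"{i}: mma_async without a preceding fence")
            uncommitted += 1
        elif op == "commit":
            if uncommitted:
                in_flight += 1
                uncommitted = 0
        elif op == "wait":
            in_flight = 0  # uncommitted multiplies are NOT covered by the wait
        elif op == "read":
            if uncommitted or in_flight:
                errors.append(f"{i}: read of the tile before the async wait")
            fenced = False  # a later mma_async into this tile needs a new fence
        else:
            raise ValueError(f"unknown op {op!r}")
    return errors


if __name__ == "__main__":
    print(check_wgmma_trace(["fence", "mma_async", "commit", "wait", "read"]))  # []
    print(check_wgmma_trace(["fence", "mma_async", "commit", "read"]))
    # ['3: read of the tile before the async wait']
```

In this model, a fence placed after the wait (as in the PR) is only required if the tile is fed into another async multiply after being accessed by ordinary instructions.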

Does the latest fix on main fix the randomness you were seeing?

@Aaryan0404 Thanks for your reply. After reading the CUDA documentation, I also believe the wgmma fence/syncthreads() is not necessary. However, adding the async wait alone does not fix the randomness (I just tested it on the latest main branch), and I don't know why. Here is my test script (h100_fwd_check.py with a debug function added):

import os
import sys

import torch

# Make the repo root and the locally built extension importable.
current_dir = os.path.dirname(os.path.abspath(__file__))
project_root = os.path.abspath(os.path.join(current_dir, "../../"))
sys.path.insert(0, project_root)
from src.common.pyutils.test_build_utils import __eq
sys.path.append('build/lib.linux-x86_64-cpython-312')
import h100_fwd as mod  # compiled TK H100 forward-attention extension

torch.manual_seed(0)

def debug(name, expect, actual, atol=1e-3, rtol=1e-3):
    """Report whether actual matches expect; on mismatch, print diff stats and the worst element."""
    all_close = torch.allclose(expect, actual, atol=atol, rtol=rtol)
    print(f"{name}  all_close={all_close}")
    if not all_close:
        diff = (expect - actual).abs()
        print(f"max={diff.max().item()}, min={diff.min().item()}, mean={diff.mean().item()}")
        # Locate the first element where the absolute difference is largest.
        max_indices = torch.nonzero(diff == diff.max().item())
        first_index = tuple(max_indices[0].tolist())
        print(f"Index: {first_index}, expect: {expect[first_index]}, actual: {actual[first_index]}")


def pytorch_test(Q, K, V):
    output = torch.nn.functional.scaled_dot_product_attention(Q, K, V, is_causal=True)
    return output

def h100_fwd_kernel_test(Q, K, V):
    o = torch.zeros_like(Q)
    mod.attention_forward_causal(Q, K, V, o)
    return o

def check_correctness(b, h, n, d):
    print(f"Testing with b={b}, h={h}, n={n}, d={d}")
    
    Q = torch.randn(b, h, n, d, dtype=torch.bfloat16, device='cuda').contiguous()
    K = torch.randn(b, h, n, d, dtype=torch.bfloat16, device='cuda').contiguous()
    V = torch.randn(b, h, n, d, dtype=torch.bfloat16, device='cuda').contiguous()
    
    result_pytorch = pytorch_test(Q, K, V)
    tk_result = h100_fwd_kernel_test(Q, K, V)
    
    diff = result_pytorch - tk_result
    avg_diff_mag = torch.mean(torch.abs(diff)).item()
    # Average diff as a percentage of the reference output's average magnitude.
    avg_diff_per = 100 * avg_diff_mag / torch.mean(torch.abs(result_pytorch)).item()

    print(f"Attention output - avg magnitude of diff: {avg_diff_mag:.6f} ({avg_diff_per:.4f}% of avg |output|)")
    print("-" * 40)
    debug("Attention output", result_pytorch, tk_result)

print("Correctness Tests: ")
configurations = [
    # (2,  8, 256,   64),
    # (4,  8, 512,   64),
    # (8,  8, 1024,  64),
    # (16, 8, 2048,  64),
    # (16, 8, 4096,  64),
    # (16, 8, 8192,  64),
    # (16, 8, 16384, 64),
    # (2,  8, 256,   128),
    # (4,  8, 512,   128),
    # (8,  8, 1024,  128),
    # (16, 8, 2048,  128),
    # (16, 8, 4096,  128),
    (16, 8, 8192,  128),
    # (16, 8, 16384, 128)
]
for b, h, n, d in configurations:
    check_correctness(b, h, n, d)
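A note on the tolerances in the debug helper: atol=rtol=1e-3 is tighter than the resolution of bfloat16 itself, which keeps 8 significant bits (7 stored plus the implicit leading bit), so values near 1.0 are spaced 2^-7 ≈ 0.0078 apart. A correct kernel can therefore fail allclose at that tolerance, whereas a max diff above 2 is far beyond anything rounding alone would explain. The sketch below rounds a float to the nearest bfloat16 (round-to-nearest-even on the upper 16 bits of the float32 encoding, finite inputs only) to make this concrete. to_bfloat16 is a hypothetical helper; packing the Python float as float32 first adds a second rounding step, which is harmless for this illustration.

```python
import struct


def to_bfloat16(x: float) -> float:
    """Round a finite float to the nearest bfloat16 value (ties to even)."""
    # Reinterpret the float32 encoding of x as a 32-bit unsigned integer.
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    # Round to nearest even on the upper 16 bits: add 0x7FFF, plus 1 if the
    # lowest kept bit is odd, then drop the low 16 bits.
    bits += 0x7FFF + ((bits >> 16) & 1)
    bits &= 0xFFFF0000
    return struct.unpack("<f", struct.pack("<I", bits))[0]


if __name__ == "__main__":
    atol = rtol = 1e-3
    x = 1.01
    err = abs(to_bfloat16(x) - x)
    print(to_bfloat16(x))              # 1.0078125
    print(err > atol + rtol * abs(x))  # True: one rounding already exceeds allclose's bound
```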