hidet-org/hidet

[Enquiry] developing Flash Attention Transformer example using Hidet

keneoneth opened this issue · 3 comments

Hello guys, really appreciate your work on Hidet. It is an awesome tool, and it really makes developers' lives easier when writing custom schedules for their CUDA kernels for performance optimization👍👍!

To test Hidet's features, I am currently writing an example of the Flash Attention Transformer (link to research work: https://arxiv.org/abs/2205.14135) using the Hidet tool stack. I have written my custom testing setup (which contains my own host/device memory allocation, performance tracking, and precision comparison code) in my "flash_attention_main.cu", and I am trying to call the kernel functions in the Hidet-generated CUDA dynamic library.

May I know if there is a standard way of doing this? I tried using "dlopen" to load the library and launch the kernel functions, but unfortunately it did not work properly. I therefore manually copied the Hidet-generated CUDA source code into two separate header files, "flash_attention_kernel_func.h" and "normal_transformer_kernel_func.h", and included them in my "flash_attention_main.cu". Compiling "flash_attention_main.cu" directly this way works properly.

Let me share some source code below for illustration.

Here is my flash_attention_example.py, which includes the flash attention custom schedule and the normal approach.

import os
import math
import time
import numpy as np
import torch
import torch.nn as nn
torch.manual_seed(123)

# NOTE: this script is a simplified implementation of the following research work using Hidet
# Dao, T., Fu, D., Ermon, S., Rudra, A., & Ré, C. (2022). Flashattention: Fast and memory-efficient exact attention with io-awareness. Advances in Neural Information Processing Systems, 35, 16344-16359.
# link to paper: https://arxiv.org/abs/2205.14135

import hidet
from hidet.ir.compute import compute, reduce
from hidet.ir.task import Task
from hidet.ir.func import IRModule
from hidet.ir.primitives.cuda.atomic import atomic_add
from hidet.lang import f16, spatial, repeat, tensor, attr, grid, printf
from hidet.lang.cuda import blockIdx, threadIdx, syncthreads
from hidet.graph.ops.definitions.utils import input_like
from hidet.ir.expr import cast, address
from hidet.ir.primitives import exp, max, printf

# define Flash Attention Task
class FlashAttentionTask(Task):

    def allow_epilogue(self) -> bool:
        return False

    def flash_attention_implement_cuda(self, working_dir: str) -> IRModule:
        # override this method to use template-based scheduling
        return flash_attention_schedule(self)
    
    # Require: matrices Q, K, V in HBM, each of size N x d; on-chip SRAM of size M.
    # NOTE: typical SRAM size 100 kB, default to 48 kB
    # NOTE: max thread num is set to 1024
    def __init__(self,N=512,d=128,H=16,B=1,M=48*1024,ratio=12,max_thread_num=1024,disable_flash_attention=False):

        # 1. set block sizes Bc = ceil(M/(ratio*d)), Br = min(ceil(M/(ratio*d)), d); the paper uses ratio = 4
        Bc = math.ceil(M/(ratio*d))
        Br = min(math.ceil(M/(ratio*d)),d)
        Tr = math.ceil(N/Br)
        Tc = math.ceil(N/Bc)
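        # With the defaults N=512, d=128, M=48*1024, ratio=12:
        #   Bc = ceil(49152/1536) = 32, Br = min(32, 128) = 32, Tr = Tc = ceil(512/32) = 16,
        # so the schedule launches BLK = Tr = 16 thread blocks of THD = Br*Bc = 1024 threads.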
        GLOBAL_Q = input_like(hidet.randn([N, d], dtype='float16', device='cuda'),name='GLOBAL_Q')
        GLOBAL_K = input_like(hidet.randn([N, d], dtype='float16', device='cuda'),name='GLOBAL_K')
        GLOBAL_V = input_like(hidet.randn([N, d], dtype='float16', device='cuda'),name='GLOBAL_V')
        
        def normal_transformer():
            matmulQK = compute(
                    name = 'GLOBAL_QK',
                    shape = [N, N],
                    fcompute = lambda i, j: reduce(
                        shape=[d],
                        fcompute=lambda k: GLOBAL_Q[i, k] * GLOBAL_K[j, k],
                        reduce_type='sum',
                    )
                )

            max_val = lambda i : reduce(shape=[N], fcompute=lambda j: matmulQK[i,j], reduce_type='max')
            S = compute(
                    name = 'S',
                    shape = [N, N],
                    fcompute = lambda i,j: matmulQK[i,j] - max_val(i)
                )
            exp_s = compute(
                    name = 'exp_s',
                    shape = [N, N],
                    fcompute = lambda i,j: exp(S[i,j])
                )
            exp_sum = lambda i : reduce(shape=[N], fcompute=lambda j: exp_s[i,j], reduce_type='sum')
            softmax = compute('softmax', shape=[N,N], fcompute=lambda i,j: exp_s[i,j] / exp_sum(i))
            matmulPV = compute(
                    name = 'GLOBAL_O',
                    shape = [N, d],
                    fcompute = lambda i, j: reduce(
                        shape=[N],
                        fcompute=lambda k: softmax[i, k] * GLOBAL_V[k, j],
                        reduce_type='sum',
                    )
                )
            return matmulPV
        
        super().__init__(
            name='flash_attention_task',
            inputs=[GLOBAL_Q,GLOBAL_K,GLOBAL_V],
            outputs=[normal_transformer()],
            attributes={
                'B' : B,
                'H' : H,
                'N' : N,
                'd' : d,
                'Bc' : Bc,
                'Br' : Br,
                'Tc' : Tc,
                'Tr' : Tr,
                'BLK' : Tr,
                'THD' : Br * Bc,
                'MAX_THD' : max_thread_num
            },
        )
        if not disable_flash_attention:
            self.implement_cuda = self.flash_attention_implement_cuda
            self.define = "-DRUN_FLASH_ATTN"
        else:
            self.define = ""

# define custom schedule
def flash_attention_schedule(task:FlashAttentionTask) -> IRModule:
    
    print_debug = False

    B = task.attrs['B']
    H = task.attrs['H']
    N = task.attrs['N']
    d = task.attrs['d']
    Bc = task.attrs['Bc']
    Br = task.attrs['Br']
    Tr = task.attrs['Tr']
    Tc = task.attrs['Tc']

    dims = ( task.attrs['BLK'] )
    threads = task.attrs['THD']
    assert threads <= task.attrs['MAX_THD'], f'err: {threads} not <= {task.attrs["MAX_THD"]}'
    assert d % Bc == 0, f'err: d is not divisible by Bc'
    assert d % Br == 0, f'err: d is not divisible by Br'


    largest_fp16_value = 65504

    print(f'task.attrs {task.attrs}')
    
    
    # define the tensor program
    with hidet.script_module() as module:
        """Flash attention kernel."""

        @hidet.script
        def QK_matmul_compute(A:f16[Br,d],B:f16[d,Bc],C:f16[Br,Bc]):
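            # Sij = Qi @ Kj^T: each of the Br*Bc threads owns a single output element C[m, n]
            # and accumulates its dot product over the d dimension via atomic_add.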
            for m,n in spatial(Br,Bc).on(threadIdx.x):
                C[m,n] = 0.0
            syncthreads()
            for m,k,n in spatial(Br,1,Bc).repeat(1,d,1).on(threadIdx.x):   
                atomic_add(~C[m,n],A[m,k] * B[k,n])
            syncthreads()

        @hidet.script
        def PV_matmul_compute(A:f16[Br,Bc],B:f16[Bc,d],C:f16[Br,d]):
            for m,n in spatial(Br,Bc).repeat(1,d//Bc).on(threadIdx.x):
                C[m,n] = 0.0
            syncthreads()
            for m,k,n in spatial(Br,1,Bc).repeat(1,Bc,d//Bc).on(threadIdx.x):   
                atomic_add(~C[m,n],A[m,k] * B[k,n])
            syncthreads()

        @hidet.script
        def rowmax_compute(A:f16[Br,Bc],M:f16[Br],T:f16[Br,Bc]):
            for i,j in spatial(Br,Bc).on(threadIdx.x):
                T.write([i,j],A[i,j],protected=True)
            syncthreads()

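            # Tree reduction over the Bc columns of each row: the stride k doubles each
            # step, so after log2(Bc) steps T[i, 0] holds the maximum of row i.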
            for i,j in spatial(Br,Bc).on(threadIdx.x):
                k = 1
                while k < Bc:
                    if j % (k*2) == 0:
                        T.write([i,j],max(T[i,j],T[i,j+k]),protected=True)
                    syncthreads()
                    k *= 2

            for i in spatial(Br).on(threadIdx.x):
                if threadIdx.x < Br:
                    M[i] = T[i,0]
            syncthreads()

        @hidet.script
        def rowsum_compute(A:f16[Br,Bc],L:f16[Br],T:f16[Br,Bc]):
            for i,j in spatial(Br,Bc).on(threadIdx.x):
                T.write([i,j],A[i,j],protected=True)
            syncthreads()

            for i,j in spatial(Br,Bc).on(threadIdx.x):
                k = 1
                while k < Bc:
                    if j % (k*2) == 0:
                        T.write([i,j],(T[i,j]+T[i,j+k]),protected=True)
                    syncthreads()
                    k *= 2

            for i in spatial(Br).on(threadIdx.x):
                if threadIdx.x < Br:
                    L[i] = T[i,0]
            syncthreads()

        @hidet.script
        def local_softmax_compute(S:f16[Br,Bc],M:f16[Br]):
            for i,j in spatial(Br,Bc).on(threadIdx.x):
                if False and blockIdx.x==0:
                    printf("S[i,j] before %d %d %d %d : %f - %f\n",blockIdx.x,threadIdx.x,i,j,cast(S[i,j],"float32"),cast(M[i],"float32"))
                S[i,j] = exp(S[i,j] - M[i])
                if False and blockIdx.x==0:
                    printf("S[i,j] %d %d %d %d : %f\n",blockIdx.x,threadIdx.x,i,j,cast(S[i,j],"float32"))
            syncthreads()
        
        @hidet.script
        def local_update_compute(M:f16[Br],M_new:f16[Br],M_local:f16[Br],L:f16[Br],L_new:f16[Br],L_local:f16[Br]):
            for i in spatial(Br).on(threadIdx.x):
                if threadIdx.x < Br:
                    M_new[i] = max(M[i],M_local[i])
                    L_new[i] = exp(M[i] - M_new[i]) * L[i] + exp(M_local[i] - M_new[i]) * L_local[i]
            syncthreads()

        @hidet.script
        def global_update_compute(PV:f16[Br,d],O:f16[Br,d],M_local:f16[Br],M_new:f16[Br],M:f16[Br],L_new:f16[Br],L:f16[Br]):
            for i,j in spatial(Br,Bc).repeat(1,(d//Bc)).on(threadIdx.x):
                O.write(
                    [i,j],
                    ((L_new[i]**-1) * (L[i]*exp(M[i]-M_new[i])) * O[i,j]) + (exp(M_local[i]-M_new[i]) * PV[i,j]),
                    protected=True
                )
            syncthreads()

        @hidet.script
        def flash_attention_kernel(
            Q: f16[N,d],
            K: f16[N,d],
            V: f16[N,d],
            O: f16[N,d]
        ):
            
            attr.cuda_grid_dim = dims
            attr.cuda_block_dim = threads

            # Init O=(0), N x d in HBM
            for i,j in spatial(Br,Bc).repeat(1,(d//Bc)).on(threadIdx.x):
                offset_i = blockIdx.x * (Br)
                O[offset_i:,:].write([i,j], 0, protected=True)
            syncthreads()

            smem_q = tensor('shared', 'float16', [Br, d])
            smem_k = tensor('shared', 'float16', [d, Bc]) # transposed
            smem_v = tensor('shared', 'float16', [Bc, d])
            smem_o = tensor('shared', 'float16', [Br, d])
            
            smem_l = tensor('shared', 'float16', [Br])
            smem_l_local = tensor('shared', 'float16', [Br])
            smem_l_new = tensor('shared', 'float16', [Br])
            smem_m = tensor('shared', 'float16', [Br])
            smem_m_local = tensor('shared', 'float16', [Br])
            smem_m_new = tensor('shared', 'float16', [Br])
            smem_sp = tensor('shared', 'float16', [Br,Bc])
            smem_pv = tensor('shared', 'float16', [Br,d])
            smem_temp = tensor('shared', 'float16', [Br,Bc])

            for a,b in spatial(Br,Bc).repeat(1,(d//Bc)).on(threadIdx.x):
                # load Qi from HBM to on-chip SRAM
                # initialization of o,l,m
                offset_i = blockIdx.x * (Br)
                smem_q[a,b] = Q[offset_i:,:].read([a,b],protected=True)
                smem_o[a,b] = 0
                smem_l[a] = 0
                smem_m[a] = -largest_fp16_value
            syncthreads()

            if print_debug and (blockIdx.x==0 and threadIdx.x==0):
                idx = 0
                for i,j in grid(Br,d):
                    printf("idx: %d, Q val: %f\n",idx,cast(smem_q[i,j],"float32"))
                    idx += 1
            syncthreads()

            for j in grid(Tc):

                for a,b in spatial(Bc,Br).repeat(1,(d//Br)).on(threadIdx.x):
                    # load Kj,Vj from HBM to on-chip SRAM
                    offset_j = j * (Bc)
                    smem_k[b,a] = K[offset_j:,:].read([a,b],protected=True)
                    smem_v[a,b] = V[offset_j:,:].read([a,b],protected=True)
                syncthreads()
                
                if print_debug and (blockIdx.x==0 and threadIdx.x==0):
                    idx = 0
                    for i,j in grid(d,Bc):
                        printf("idx: %d, K val: %f\n",idx,cast(smem_k[i,j],"float32"))
                        idx += 1
                    for i,j in grid(Bc,d):
                        printf("idx: %d, V val: %f\n",idx,cast(smem_v[i,j],"float32"))
                        idx += 1
                syncthreads()
                
                # on chip, compute Sij = Qi @ (Kj)^T, Br X Bc
                QK_matmul_compute(smem_q,smem_k,smem_sp)
                if print_debug and (blockIdx.x==0 and threadIdx.x==0):
                    idx = 0
                    for i,j in grid(Br,Bc):
                        printf("idx: %d, S val: %f\n",idx,cast(smem_sp[i,j],"float32"))
                        idx += 1
                syncthreads()
                
                # on chip, compute m'_ij = rowmax(Sij), Br; Pij = exp(Sij - m'_ij), Br x Bc (pointwise); l'_ij = rowsum(P'ij), Br
                rowmax_compute(smem_sp,smem_m_local,smem_temp)

                if print_debug and (blockIdx.x==0 and threadIdx.x==0):
                    for i in grid(Br):
                        printf("i: %d M val: %f\n",i,cast(smem_m_local[i],"float32"))
                        # for j in grid(Bc):
                        #     printf("j: %d, S val: %f\n",j,cast(smem_sp[i,j],"float32"))
                syncthreads()
                
                local_softmax_compute(smem_sp,smem_m_local)

                if print_debug and (blockIdx.x==0 and threadIdx.x==0):
                    idx = 0
                    for i,j in grid(Br,Bc):
                        printf("idx: %d, P val: %f\n",idx,cast(smem_sp[i,j],"float32"))
                        idx += 1
                syncthreads()


                rowsum_compute(smem_sp,smem_l_local,smem_temp)

                if print_debug and (blockIdx.x==0 and threadIdx.x==0):
                    for i in grid(Br):
                        printf("i: %d L val: %f\n",i,cast(smem_l_local[i],"float32"))
                        # for j in grid(Bc):
                        #     printf("j: %d, P val: %f\n",j,cast(smem_sp[i,j],"float32"))
                syncthreads()

                
                # on chip, compute m_new_i = max(m_i,m'_ij), Br; l_new_i = e^(m_i - m_new_i) * l_i + e^(m'_ij - m_i_new) * l'_ij, Br
                local_update_compute(smem_m,smem_m_new,smem_m_local,smem_l,smem_l_new,smem_l_local)
                if print_debug and (blockIdx.x==0 and threadIdx.x==0):
                    for i in grid(Br):
                        printf("i: %d smem_m val: %f\n",i,cast(smem_m[i],"float32"))
                        printf("i: %d smem_m_new val: %f\n",i,cast(smem_m_new[i],"float32"))
                        printf("i: %d smem_m_local val: %f\n",i,cast(smem_m_local[i],"float32"))
                        printf("i: %d smem_l val: %f\n",i,cast(smem_l[i],"float32"))
                        printf("i: %d smem_l_new val: %f\n",i,cast(smem_l_new[i],"float32"))
                        printf("i: %d smem_l_local val: %f\n",i,cast(smem_l_local[i],"float32"))
                syncthreads()
                # write Oi = diag(l_i_new)^-1 * (diag(l_i) * e^(m_i - m_i_new) @ Oi + e^(m'_ij - m_i_new) * (P'ij @ Vj))

                PV_matmul_compute(smem_sp,smem_v,smem_pv)
                if print_debug and (blockIdx.x==0 and threadIdx.x==0):
                    idx = 0
                    for i,j in grid(Br,d):
                        printf("idx: %d, PV val: %f\n",idx,cast(smem_pv[i,j],"float32"))
                        idx += 1
                syncthreads()

                global_update_compute(smem_pv,smem_o,smem_m_local,smem_m_new,smem_m,smem_l_new,smem_l)

                if j + 1 == Tc:
                    for i,j in spatial(Br,Bc).repeat(1,(d//Bc)).on(threadIdx.x):
                        offset_i = blockIdx.x * (Br)
                        O[offset_i:,:].write([i,j], smem_o[i,j], protected=True)
                    syncthreads()

                # write l_i = l_i_new, m_i = m_i_new
                for i in spatial(Br).on(threadIdx.x):
                    if threadIdx.x < Br:
                        smem_m[i] = smem_m_new[i]
                        smem_l[i] = smem_l_new[i]
                syncthreads()

            if print_debug and (blockIdx.x==15 and threadIdx.x==0):
                idx = 0
                for i,j in grid(Br,d):
                    offset_i = blockIdx.x * (Br)
                    printf("blockIdx %d : output idx: %d, val: %f\n",blockIdx.x,idx,cast(O[offset_i+i,j],"float32"))
                    idx += 1
            syncthreads()
            return

        @hidet.script
        def flash_attention_launch_func( 
            G_Q: f16[B, H, N, d],
            G_K: f16[B, H, N, d],
            G_V: f16[B, H, N, d],
            G_O: f16[B, H, N, d]
        ):
            # NOTE: this section needs to be written in flash_attention_main.cu
            for b,h in grid(B,H):
                flash_attention_kernel(
                    address(G_Q[b,h,0,0]),
                    address(G_K[b,h,0,0]),
                    address(G_V[b,h,0,0]),
                    address(G_O[b,h,0,0])
                )
            
    # build ir module
    ir_module = module.ir_module()
    return ir_module

# gen Python gold data as reference
def gen_gold(attrs,r1=-3,r2=3):

    Q = torch.FloatTensor(attrs['B'],attrs['H'],attrs['N'],attrs['d']).uniform_(r1, r2).half()
    K = torch.FloatTensor(attrs['B'],attrs['H'],attrs['N'],attrs['d']).uniform_(r1, r2).half()
    V = torch.FloatTensor(attrs['B'],attrs['H'],attrs['N'],attrs['d']).uniform_(r1, r2).half()
    t = time.process_time()
    Q.half().numpy().tofile('mat_Q.bin')
    K.half().numpy().tofile('mat_K.bin')
    V.half().numpy().tofile('mat_V.bin')
    S = torch.from_numpy(Q.numpy() @ torch.transpose(K, -2, -1).numpy())

    row_max, _ = torch.max(S,dim=-1)
    S = torch.from_numpy(np.exp((S - row_max.reshape(attrs['B'],attrs['H'],attrs['N'],1)).numpy()))
    row_sum = torch.sum(S,dim=-1).reshape(attrs['B'],attrs['H'],attrs['N'],1)
    P = S / row_sum

    # TODO: test with softmax float precision
    # P = nn.Softmax(dim=-1)(S.float()).half()
    O = torch.from_numpy(P.numpy() @ V.numpy())
    elapsed_time = (time.process_time() - t)*1000
    print(f"Python gold gen run elapsed time {round(elapsed_time,3)} msec")
    O.half().numpy().tofile('gold_mat_O.bin')

# run task
def run_task(disable_flash_attention=False):
    # define the task here
    flash_attention_task = FlashAttentionTask(disable_flash_attention=disable_flash_attention)
    # build the task
    ret = flash_attention_task.build(target='cuda')

    # copy source file and lib to current directory
    source_path = ret.src_path
    library_path = ret.lib_path
    print(f'source_path {source_path} library_path {library_path}')

    import shutil
    shutil.copyfile(source_path,os.path.join("./","flash_attention_"+os.path.basename(source_path)))
    shutil.copyfile(library_path,os.path.join("./","flash_attention_"+os.path.basename(library_path)))

    # generate golden data
    gen_gold(flash_attention_task.attrs)

    def exe_f(command='', shell=True):
        print(f'running {command}')
        import subprocess
        process = subprocess.Popen(command, shell=shell)
        code = process.wait()
        process.communicate()
        return code
    
    # launch testcase flash_attention_main.cu
    HIDET_CUDA_INCLUDE_PATH = "../../include/"
    CUDA_SAMPLES_INCLUDE_PATH = "../cuda-samples-master/Common/"
    ret = exe_f(f'nvcc flash_attention_main.cu {flash_attention_task.define} -gencode arch=compute_86,code=sm_86 -I {CUDA_SAMPLES_INCLUDE_PATH} -I {HIDET_CUDA_INCLUDE_PATH} -std=c++11 -o fa.out && ./fa.out -BATCH={flash_attention_task.attrs["B"]} -HEAD={flash_attention_task.attrs["H"]} -BLK={flash_attention_task.attrs["BLK"]} -THD={flash_attention_task.attrs["THD"]}')
    print('test done' if ret==0 else 'test error')

# main function
if __name__ == '__main__':
    # normal approach execution
    run_task(disable_flash_attention=True)
    # flash attention approach execution
    run_task(disable_flash_attention=False)

Here is my flash_attention_main.cu, which includes the performance tracking, precision comparison, and memory allocation operations, and launches the test kernels.

// System includes
#include <stdio.h>
#include <sys/stat.h>
#include <dlfcn.h>
#include <cuda_fp16.h>
#include <cuda_bf16.h>
#include <assert.h>
#include <vector>
#include <string>
#include <fstream>

// CUDA runtime
#include <cuda_runtime.h>
#include <cuda_profiler_api.h>

// Helper functions and utilities to work with CUDA,
#include <helper_functions.h>
#include <helper_cuda.h>
#include <cuda_fp16.h>

// Import kernel functions
#include "flash_attention_kernel_func.h"
#include "normal_transformer_kernel_func.h"


// test function, execute kernel, compare with gold data
int flash_attention_test(
    unsigned int B, unsigned int H,
    unsigned int block_size, unsigned int thread_size,
    half *h_Q, unsigned int size_Q,
    half *h_K, unsigned int size_K,
    half *h_V, unsigned int size_V,
    half *h_gold_O, unsigned int size_O)
{

    cudaStream_t stream;
    const unsigned int BH = B * H;
    // Allocate device memory
    half *d_Q, *d_K, *d_V, *d_O, *h_O;
    checkCudaErrors(cudaMallocHost(&h_O, size_O * sizeof(half)));

    if (h_O == NULL)
    {
        fprintf(stderr, "Failed to allocate host matrix O!\n");
        exit(EXIT_FAILURE);
    }

    checkCudaErrors(cudaMalloc(reinterpret_cast<void **>(&d_Q), size_Q * sizeof(half)));
    checkCudaErrors(cudaMalloc(reinterpret_cast<void **>(&d_K), size_K * sizeof(half)));
    checkCudaErrors(cudaMalloc(reinterpret_cast<void **>(&d_V), size_V * sizeof(half)));
    checkCudaErrors(cudaMalloc(reinterpret_cast<void **>(&d_O), size_O * sizeof(half)));
    // Allocate CUDA events that we'll use for timing
    cudaEvent_t start, stop;
    checkCudaErrors(cudaEventCreate(&start));
    checkCudaErrors(cudaEventCreate(&stop));

    checkCudaErrors(cudaStreamCreateWithFlags(&stream, cudaStreamNonBlocking));

    // copy host memory to device
    checkCudaErrors(
        cudaMemcpyAsync(d_Q, h_Q, size_Q * sizeof(half), cudaMemcpyHostToDevice, stream));
    checkCudaErrors(
        cudaMemcpyAsync(d_K, h_K, size_K * sizeof(half), cudaMemcpyHostToDevice, stream));
    checkCudaErrors(
        cudaMemcpyAsync(d_V, h_V, size_V * sizeof(half), cudaMemcpyHostToDevice, stream));

    const unsigned int k_size_Q = (size_Q / BH);
    const unsigned int k_size_K = (size_K / BH);
    const unsigned int k_size_V = (size_V / BH);
    const unsigned int k_size_O = (size_O / BH);
    printf("k_size_Q %u k_size_K %u k_size_V %u k_size_O %u\n", k_size_Q, k_size_K, k_size_V, k_size_O);

    // Record the start event
    checkCudaErrors(cudaEventRecord(start, stream));

    const int32_t num_args = 4;

    for (unsigned int b = 0; b < B; b++)
    {
        for (unsigned int h = 0; h < H; h++)
        {
            unsigned int offset_index = (b * H) + h;

            half *param[num_args] = {
                d_Q + offset_index * k_size_Q,
                d_K + offset_index * k_size_K,
                d_V + offset_index * k_size_V,
                d_O + offset_index * k_size_O};

#ifdef RUN_FLASH_ATTN
            // run flash attention kernel
            flash_attention_kernel<<<dim3(16, 1, 1), dim3(1024, 1, 1), 0, (cudaStream_t)stream>>>(((half *)(param[0])), ((half *)(param[1])), ((half *)(param[2])), ((half *)(param[3])));
#else
            // run normal transformer kernel
            uint8_t *buffer;

            checkCudaErrors(cudaMalloc(reinterpret_cast<void **>(&buffer), int64_t(2097152ll)));

            half *GLOBAL_QK = ((half *)(&buffer[(((int64_t)(0)) + (int64_t(0ll) * ((int64_t)(1))))]));
            half *S = ((half *)(&buffer[(((int64_t)(0)) + (int64_t(524288ll) * ((int64_t)(1))))]));
            half *exp_s = ((half *)(&buffer[(((int64_t)(0)) + (int64_t(1048576ll) * ((int64_t)(1))))]));
            half *softmax = ((half *)(&buffer[(((int64_t)(0)) + (int64_t(1572864ll) * ((int64_t)(1))))]));

            hidet_compute_GLOBAL_QK<<<dim3(512, 1, 1), dim3(512, 1, 1), 0, (cudaStream_t)stream>>>(param[0], param[1], GLOBAL_QK);
            hidet_compute_S<<<dim3(512, 1, 1), dim3(512, 1, 1), 0, (cudaStream_t)stream>>>(GLOBAL_QK, param[0], param[1], S);
            hidet_compute_exp_s<<<dim3(512, 1, 1), dim3(512, 1, 1), 0, (cudaStream_t)stream>>>(S, exp_s);
            hidet_compute_softmax<<<dim3(512, 1, 1), dim3(512, 1, 1), 0, (cudaStream_t)stream>>>(S, param[0], param[1], GLOBAL_QK, exp_s, softmax);
            hidet_compute_GLOBAL_O<<<dim3(128, 1, 1), dim3(512, 1, 1), 0, (cudaStream_t)stream>>>(softmax, ((half *)(param[2])), param[0], param[1], GLOBAL_QK, S, exp_s, ((half *)(param[3])));
#endif // RUN_FLASH_ATTN

        }
    }

    checkCudaErrors(cudaStreamSynchronize(stream));

    // Record the stop event
    checkCudaErrors(cudaEventRecord(stop, stream));
    printf("test done !!!\n");

    // Wait for the stop event to complete
    checkCudaErrors(cudaEventSynchronize(stop));

    float msecTotal = 0.0f;
    checkCudaErrors(cudaEventElapsedTime(&msecTotal, start, stop));

    // Compute and print the performance
#ifdef RUN_FLASH_ATTN
    printf("flash attention elapsed time = %.3f msec\n", msecTotal);
#else
    printf("normal approach elapsed time = %.3f msec\n", msecTotal);
#endif // RUN_FLASH_ATTN
    // Copy result from device to host
    checkCudaErrors(
        cudaMemcpyAsync(h_O, d_O, size_O * sizeof(half), cudaMemcpyDeviceToHost, stream));
    checkCudaErrors(cudaStreamSynchronize(stream));

    printf("Checking computed result for correctness: \n");

    double eps = 0.01; // 1% error with python output
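    // NOTE: eps is used both as the per-element relative-error threshold here and as the
    // allowed fraction of mismatching elements in the final pass/fail check below.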
    const unsigned int max_print_count = 100;
    uint32_t total_count = 0;
    uint32_t total_err_count = 0;
    for (int i = 0; i < static_cast<int>(size_O); i++)
    {
        double gold_val = fabs((double)h_gold_O[i]);
        double abs_val = fabs((double)h_O[i]);
        double abs_err = fabs(abs_val - gold_val);
        double rel_err = abs_err / abs_val;

        if (rel_err > eps)
        {
            if (total_err_count < max_print_count)
                printf("Error! Matrix[%05d]=%.8f, ref=%.8f error term %E is > %E\n",
                       i, (double)h_O[i], (double)h_gold_O[i], rel_err, eps);
            total_err_count++;
        }
        total_count++;
    }
    double error_ratio = (double)total_err_count / (double)total_count;
    bool correct = error_ratio < eps;
    printf("total count %u total error count %u (%.8f %%)\n", total_count, total_err_count, error_ratio * 100);
    printf("%s\n", correct ? "Result = PASS" : "Result = FAIL");

    // Clean up memory
    checkCudaErrors(cudaFree(d_Q));
    checkCudaErrors(cudaFree(d_K));
    checkCudaErrors(cudaFree(d_V));
    checkCudaErrors(cudaFree(d_O));
    checkCudaErrors(cudaEventDestroy(start));
    checkCudaErrors(cudaEventDestroy(stop));

    if (correct)
    {
        return EXIT_SUCCESS;
    }
    else
    {
        return EXIT_FAILURE;
    }
}

inline bool file_exists(const std::string &name)
{
    struct stat buffer;
    return (stat(name.c_str(), &buffer) == 0);
}

void load_data(std::vector<half> &matrix, const std::string bin_file)
{
    printf("loading %s\n", bin_file.c_str());
    assert(file_exists(bin_file) && "Error! binary file doesn't exist");

    std::ifstream fin(bin_file, std::ios::binary);
    half elem;
    while (fin.read(reinterpret_cast<char *>(&elem), sizeof(half)))
    {
        matrix.push_back(elem);
    }
}

int main(int argc, char **argv)
{
    printf("[Flash Attention Using CUDA] - Starting...\n");

    if (checkCmdLineFlag(argc, (const char **)argv, "help") ||
        checkCmdLineFlag(argc, (const char **)argv, "?"))
    {

        printf("Usage -device=n (n >= 0 for deviceID)\n");
        printf("      -BATCH=number of Batch\n");
        printf("      -HEAD=number of Head\n");
        printf("      -BLK=block size\n");
        printf("      -THD=thread size\n");
        exit(EXIT_SUCCESS);
    }

    // This will pick the best possible CUDA capable device, otherwise
    // override the device ID based on input provided at the command line
    int dev = findCudaDevice(argc, (const char **)argv);

    unsigned int batch = 1;
    if (checkCmdLineFlag(argc, (const char **)argv, "BATCH"))
    {
        batch = getCmdLineArgumentInt(argc, (const char **)argv, "BATCH");
    }
    unsigned int head = 1;
    if (checkCmdLineFlag(argc, (const char **)argv, "HEAD"))
    {
        head = getCmdLineArgumentInt(argc, (const char **)argv, "HEAD");
    }
    unsigned int block_size = 1;
    if (checkCmdLineFlag(argc, (const char **)argv, "BLK"))
    {
        block_size = getCmdLineArgumentInt(argc, (const char **)argv, "BLK");
    }
    unsigned int thread_size = 1;
    if (checkCmdLineFlag(argc, (const char **)argv, "THD"))
    {
        thread_size = getCmdLineArgumentInt(argc, (const char **)argv, "THD");
    }

    // load Q
    std::vector<half> mat_Q;
    load_data(mat_Q, "./mat_Q.bin");

    // load K
    std::vector<half> mat_K;
    load_data(mat_K, "./mat_K.bin");

    // load V
    std::vector<half> mat_V;
    load_data(mat_V, "./mat_V.bin");

    // load golden data O
    std::vector<half> gold_mat_O;
    load_data(gold_mat_O, "./gold_mat_O.bin");

    printf("batch %u head %u block_size %u thread_size %u\n", batch, head, block_size, thread_size);

    printf("Q size %lu K size %lu V size %lu O size %lu\n", mat_Q.size(), mat_K.size(), mat_V.size(), gold_mat_O.size());

    checkCudaErrors(cudaProfilerStart());
    int result = flash_attention_test(
        batch, head, block_size, thread_size,
        &mat_Q[0], mat_Q.size(),
        &mat_K[0], mat_K.size(),
        &mat_V[0], mat_V.size(),
        &gold_mat_O[0], gold_mat_O.size());
    checkCudaErrors(cudaProfilerStop());

    exit(result);
}

Here are the flash_attention_kernel_func.h and normal_transformer_kernel_func.h files, respectively.

// flash_attention_kernel_func.h
__global__ void __launch_bounds__(1024) flash_attention_kernel(half *__restrict__ Q, half *__restrict__ K, half *__restrict__ V, half *__restrict__ O)
{
    for (int32_t i = 0; (i < 4); i = (i + 1))
    {
        O[(((((int)blockIdx.x * 32) + ((int)threadIdx.x / 32)) * 128) + ((((int)threadIdx.x % 32) * 4) + i))] = ((half)(0));
    }
    __syncthreads();
    __shared__ half smem_q[4096];
    __shared__ half smem_k[4096];
    __shared__ half smem_v[4096];
    __shared__ half smem_o[4096];
    __shared__ half smem_l[32];
    __shared__ half smem_l_local[32];
    __shared__ half smem_l_new[32];
    __shared__ half smem_m[32];
    __shared__ half smem_m_local[32];
    __shared__ half smem_m_new[32];
    __shared__ half smem_sp[1024];
    __shared__ half smem_pv[4096];
    __shared__ half smem_temp[1024];
    for (int32_t i_1 = 0; (i_1 < 4); i_1 = (i_1 + 1))
    {
        smem_q[((((int)threadIdx.x / 32) * 128) + ((((int)threadIdx.x % 32) * 4) + i_1))] = Q[(((((int)blockIdx.x * 32) + ((int)threadIdx.x / 32)) * 128) + ((((int)threadIdx.x % 32) * 4) + i_1))];
        smem_o[((((int)threadIdx.x / 32) * 128) + ((((int)threadIdx.x % 32) * 4) + i_1))] = ((half)(0));
        smem_l[((int)threadIdx.x / 32)] = ((half)(0));
        smem_m[((int)threadIdx.x / 32)] = ((half)((-65504)));
    }
    __syncthreads();
    __syncthreads();
    for (int32_t j = 0; (j < 16); j = (j + 1))
    {
        for (int32_t i_2 = 0; (i_2 < 4); i_2 = (i_2 + 1))
        {
            int32_t offset_j = (j * 32);
            smem_k[((((((int)threadIdx.x % 32) * 4) + i_2) * 32) + ((int)threadIdx.x / 32))] = K[(((offset_j + ((int)threadIdx.x / 32)) * 128) + ((((int)threadIdx.x % 32) * 4) + i_2))];
            smem_v[((((int)threadIdx.x / 32) * 128) + ((((int)threadIdx.x % 32) * 4) + i_2))] = V[(((offset_j + ((int)threadIdx.x / 32)) * 128) + ((((int)threadIdx.x % 32) * 4) + i_2))];
        }
        __syncthreads();
        __syncthreads();
        half *A = smem_q;
        half *B = smem_k;
        half *C = smem_sp;
        C[((((int)threadIdx.x / 32) * 32) + ((int)threadIdx.x % 32))] = ((half)(0.0f));
        __syncthreads();
        for (int32_t i_3 = 0; (i_3 < 128); i_3 = (i_3 + 1))
        {
            atomicAdd(&C[((((int)threadIdx.x / 32) * 32) + ((int)threadIdx.x % 32))], (A[((((int)threadIdx.x / 32) * 128) + i_3)] * B[((i_3 * 32) + ((int)threadIdx.x % 32))]));
        }
        __syncthreads();
        __syncthreads();
        half *A_1 = smem_sp;
        half *M = smem_m_local;
        half *T = smem_temp;
        T[((((int)threadIdx.x / 32) * 32) + ((int)threadIdx.x % 32))] = A_1[((((int)threadIdx.x / 32) * 32) + ((int)threadIdx.x % 32))];
        __syncthreads();
        int32_t k = 1;
        while ((k < 32))
        {
            if ((((int)threadIdx.x % 32) % (k * 2)) == 0)
            {
                T[((((int)threadIdx.x / 32) * 32) + ((int)threadIdx.x % 32))] = __hmax(T[((((int)threadIdx.x / 32) * 32) + ((int)threadIdx.x % 32))], T[((((int)threadIdx.x / 32) * 32) + (((int)threadIdx.x % 32) + k))]);
            }
            __syncthreads();
            k = (k * 2);
        }
        if ((int)threadIdx.x < 32)
        {
            M[((int)threadIdx.x % 32)] = T[(((int)threadIdx.x % 32) * 32)];
        }
        __syncthreads();
        __syncthreads();
        half *S = smem_sp;
        half *M_1 = smem_m_local;
        S[((((int)threadIdx.x / 32) * 32) + ((int)threadIdx.x % 32))] = hexp((S[((((int)threadIdx.x / 32) * 32) + ((int)threadIdx.x % 32))] - M_1[((int)threadIdx.x / 32)]));
        __syncthreads();
        __syncthreads();
        half *A_2 = smem_sp;
        half *L = smem_l_local;
        half *T_1 = smem_temp;
        T_1[((((int)threadIdx.x / 32) * 32) + ((int)threadIdx.x % 32))] = A_2[((((int)threadIdx.x / 32) * 32) + ((int)threadIdx.x % 32))];
        __syncthreads();
        int32_t k_1 = 1;
        while ((k_1 < 32))
        {
            if ((((int)threadIdx.x % 32) % (k_1 * 2)) == 0)
            {
                T_1[((((int)threadIdx.x / 32) * 32) + ((int)threadIdx.x % 32))] = (T_1[((((int)threadIdx.x / 32) * 32) + ((int)threadIdx.x % 32))] + T_1[((((int)threadIdx.x / 32) * 32) + (((int)threadIdx.x % 32) + k_1))]);
            }
            __syncthreads();
            k_1 = (k_1 * 2);
        }
        if ((int)threadIdx.x < 32)
        {
            L[((int)threadIdx.x % 32)] = T_1[(((int)threadIdx.x % 32) * 32)];
        }
        __syncthreads();
        __syncthreads();
        half *M_2 = smem_m;
        half *M_new = smem_m_new;
        half *M_local = smem_m_local;
        half *L_1 = smem_l;
        half *L_new = smem_l_new;
        half *L_local = smem_l_local;
        if ((int)threadIdx.x < 32)
        {
            M_new[((int)threadIdx.x % 32)] = __hmax(M_2[((int)threadIdx.x % 32)], M_local[((int)threadIdx.x % 32)]);
            L_new[((int)threadIdx.x % 32)] = ((hexp((M_2[((int)threadIdx.x % 32)] - M_new[((int)threadIdx.x % 32)])) * L_1[((int)threadIdx.x % 32)]) + (hexp((M_local[((int)threadIdx.x % 32)] - M_new[((int)threadIdx.x % 32)])) * L_local[((int)threadIdx.x % 32)]));
        }
        __syncthreads();
        __syncthreads();
        half *A_3 = smem_sp;
        half *B_1 = smem_v;
        half *C_1 = smem_pv;
        for (int32_t i_4 = 0; (i_4 < 4); i_4 = (i_4 + 1))
        {
            C_1[((((int)threadIdx.x / 32) * 128) + ((((int)threadIdx.x % 32) * 4) + i_4))] = ((half)(0.0f));
        }
        __syncthreads();
        for (int32_t i_5 = 0; (i_5 < 32); i_5 = (i_5 + 1))
        {
            for (int32_t i_6 = 0; (i_6 < 4); i_6 = (i_6 + 1))
            {
                atomicAdd(&C_1[((((int)threadIdx.x / 32) * 128) + ((((int)threadIdx.x % 32) * 4) + i_6))], (A_3[((((int)threadIdx.x / 32) * 32) + i_5)] * B_1[((i_5 * 128) + ((((int)threadIdx.x % 32) * 4) + i_6))]));
            }
        }
        __syncthreads();
        __syncthreads();
        half *PV = smem_pv;
        half *O_1 = smem_o;
        half *M_local_1 = smem_m_local;
        half *M_new_1 = smem_m_new;
        half *M_3 = smem_m;
        half *L_new_1 = smem_l_new;
        half *L_2 = smem_l;
        for (int32_t i_7 = 0; (i_7 < 4); i_7 = (i_7 + 1))
        {
            O_1[((((int)threadIdx.x / 32) * 128) + ((((int)threadIdx.x % 32) * 4) + i_7))] = (((((half)(powf((float)(L_new_1[((int)threadIdx.x / 32)]), ((float)((-1)))))) * (L_2[((int)threadIdx.x / 32)] * hexp((M_3[((int)threadIdx.x / 32)] - M_new_1[((int)threadIdx.x / 32)])))) * O_1[((((int)threadIdx.x / 32) * 128) + ((((int)threadIdx.x % 32) * 4) + i_7))]) + (hexp((M_local_1[((int)threadIdx.x / 32)] - M_new_1[((int)threadIdx.x / 32)])) * PV[((((int)threadIdx.x / 32) * 128) + ((((int)threadIdx.x % 32) * 4) + i_7))]));
        }
        __syncthreads();
        if ((j + 1) == 16)
        {
            for (int32_t i_8 = 0; (i_8 < 4); i_8 = (i_8 + 1))
            {
                O[(((((int)blockIdx.x * 32) + ((int)threadIdx.x / 32)) * 128) + ((((int)threadIdx.x % 32) * 4) + i_8))] = smem_o[((((int)threadIdx.x / 32) * 128) + ((((int)threadIdx.x % 32) * 4) + i_8))];
            }
            __syncthreads();
        }
        if ((int)threadIdx.x < 32)
        {
            smem_m[((int)threadIdx.x % 32)] = smem_m_new[((int)threadIdx.x % 32)];
            smem_l[((int)threadIdx.x % 32)] = smem_l_new[((int)threadIdx.x % 32)];
        }
        __syncthreads();
    }
    __syncthreads();
    return;
}
// normal_transformer_kernel_func.h
__global__ void __launch_bounds__(512) hidet_compute_GLOBAL_QK(half * __restrict__ GLOBAL_Q, half * __restrict__ GLOBAL_K, half * __restrict__ GLOBAL_QK) {
  half acc_Sum = half(0.0f);
  for (int32_t v = 0; (v < 128); v = (v + 1)) {
    acc_Sum = (acc_Sum + (GLOBAL_Q[(((int)blockIdx.x * 128) + v)] * GLOBAL_K[(((int)threadIdx.x * 128) + v)]));
  } 
  GLOBAL_QK[(((int)blockIdx.x * 512) + (int)threadIdx.x)] = acc_Sum;
}

__global__ void __launch_bounds__(512) hidet_compute_S(half * __restrict__ GLOBAL_QK, half * __restrict__ GLOBAL_Q, half * __restrict__ GLOBAL_K, half * __restrict__ S) {
  half acc_Max = half(-65504.0f);
  for (int32_t v = 0; (v < 512); v = (v + 1)) {
    acc_Max = __hmax(acc_Max, GLOBAL_QK[(((int)blockIdx.x * 512) + v)]);
  } 
  S[(((int)blockIdx.x * 512) + (int)threadIdx.x)] = (GLOBAL_QK[(((int)blockIdx.x * 512) + (int)threadIdx.x)] - acc_Max);
}

__global__ void __launch_bounds__(512) hidet_compute_exp_s(half * __restrict__ S, half * __restrict__ exp_s) {
  exp_s[(((int)blockIdx.x * 512) + (int)threadIdx.x)] = hexp(S[(((int)blockIdx.x * 512) + (int)threadIdx.x)]);
}

__global__ void __launch_bounds__(512) hidet_compute_softmax(half * __restrict__ S, half * __restrict__ GLOBAL_Q, half * __restrict__ GLOBAL_K, half * __restrict__ GLOBAL_QK, half * __restrict__ exp_s, half * __restrict__ softmax) {
  half acc_Sum = half(0.0f);
  for (int32_t v = 0; (v < 512); v = (v + 1)) {
    acc_Sum = (acc_Sum + hexp(S[(((int)blockIdx.x * 512) + v)]));
  } 
  softmax[(((int)blockIdx.x * 512) + (int)threadIdx.x)] = (hexp(S[(((int)blockIdx.x * 512) + (int)threadIdx.x)]) / acc_Sum);
}

__global__ void __launch_bounds__(512) hidet_compute_GLOBAL_O(half * __restrict__ softmax, half * __restrict__ GLOBAL_V, half * __restrict__ GLOBAL_Q, half * __restrict__ GLOBAL_K, half * __restrict__ GLOBAL_QK, half * __restrict__ S, half * __restrict__ exp_s, half * __restrict__ GLOBAL_O) {
  half acc_Sum = half(0.0f);
  for (int32_t v = 0; (v < 512); v = (v + 1)) {
    acc_Sum = (acc_Sum + (softmax[((((((int)blockIdx.x * 512) + (int)threadIdx.x) / 128) * 512) + v)] * GLOBAL_V[((v * 128) + ((((int)blockIdx.x * 512) + (int)threadIdx.x) % 128))]));
  } 
  GLOBAL_O[((((((int)blockIdx.x * 512) + (int)threadIdx.x) / 128) * 128) + ((((int)blockIdx.x * 512) + (int)threadIdx.x) % 128))] = acc_Sum;
}

Again, really wonderful work on Hidet! Any help would be much appreciated 🙏 If any further info is needed, please let me know.

Hi @keneoneth,

Very glad to see that you have tried Hidet and used it to implement flash attention, even before we have written the documentation for Hidet Script!

You mentioned you have tried the "dlopen" method to load the generated "lib.so"; can I know what error you encountered?

To use Hidet-generated kernels in other packages:

  1. We currently recommend you use the CompiledTask API (see hidet.runtime.compiled_task). You can first build the task and store the compiled kernel in one directory. In subsequent runs, you can use hidet.runtime.compiled_task.load_compiled_task(...) to load the compiled task from disk. You can pass PyTorch tensors or Hidet tensors to the CompiledTask to launch the kernel and get the results. Currently, we only support returning Hidet tensors; in the future, we can add support for returning PyTorch tensors when specified.
  2. If you want to use Hidet-generated kernels in another C++ project, we recommend you directly load the compiled dynamic library (i.e., lib.so) using dlopen, and also load the hidet runtime library (i.e., hidet_runtime.so) to have fine-grained control over the operator execution (e.g., controlling the CUDA stream to launch the kernel on, registering callback functions for memory management). A minimal sketch of this dlopen flow is shown after this list.
  3. The third method is what you did, but it is not recommended, at least for now.
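For illustration, a minimal sketch of the dlopen flow in method 2 could look like the code below. The library paths, the exported symbol name "launch", and its signature are placeholders rather than the real Hidet ABI; the actual entry point name and parameter list can be read from the generated source stored next to lib.so (and from the hidet runtime headers).

// dlopen_method2_sketch.cpp - minimal sketch of method 2; paths, the symbol name
// "launch", and its signature are placeholders, not the real Hidet ABI.
#include <dlfcn.h>
#include <cstdio>
#include <cstdlib>

int main() {
    // Load the hidet runtime first; RTLD_GLOBAL lets lib.so resolve the runtime symbols it needs.
    void *runtime = dlopen("/path/to/hidet/build/lib/hidet_runtime.so", RTLD_NOW | RTLD_GLOBAL);
    if (runtime == nullptr) {
        std::fprintf(stderr, "dlopen hidet_runtime.so failed: %s\n", dlerror());
        return EXIT_FAILURE;
    }
    // Load the Hidet-generated kernel library.
    void *lib = dlopen("/path/to/task/cache/dir/lib.so", RTLD_NOW);
    if (lib == nullptr) {
        std::fprintf(stderr, "dlopen lib.so failed: %s\n", dlerror());
        return EXIT_FAILURE;
    }
    // Resolve the launch entry point. "launch" and this signature are placeholders;
    // check the generated source shipped next to lib.so for the real exported function.
    typedef void (*launch_fn)(void *q, void *k, void *v, void *o);
    launch_fn launch = reinterpret_cast<launch_fn>(dlsym(lib, "launch"));
    if (launch == nullptr) {
        std::fprintf(stderr, "dlsym failed: %s\n", dlerror());
        return EXIT_FAILURE;
    }
    // launch(d_Q, d_K, d_V, d_O);  // call with device pointers allocated by the host code
    dlclose(lib);
    dlclose(runtime);
    return EXIT_SUCCESS;
}

The sketch can be compiled with something like g++ dlopen_method2_sketch.cpp -ldl, with the paths and symbol adjusted to your build.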

FYI: my colleague @hjjq has written a flash attention implementation in Hidet (see https://github.com/hidet-org/hidet/blob/main/python/hidet/graph/ops/attention/attention.py), which can serve as a reference attention implementation.

If you only want to benchmark the performance, it is recommended to do this directly in Python; it would be much easier.

Let me know if you have other questions about using Hidet and I am happy to help.

Thanks @yaoyaoding for your help. I took a look at method 2) and the source code of hidet_runtime.so, and updated my code as below. The kernels work properly now 👍.

This is my updated flash_attention_example.py

import os
import math
import time
import numpy as np
import torch
import torch.nn as nn
torch.manual_seed(123)

# NOTE: this script is a simplified implementation of the following research work using Hidet
# Dao, T., Fu, D., Ermon, S., Rudra, A., & Ré, C. (2022). Flashattention: Fast and memory-efficient exact attention with io-awareness. Advances in Neural Information Processing Systems, 35, 16344-16359.
# link to paper: https://arxiv.org/abs/2205.14135

import hidet
from hidet.ir.compute import compute, reduce
from hidet.ir.task import Task
from hidet.ir.func import IRModule
from hidet.ir.primitives.cuda.atomic import atomic_add
from hidet.lang import f16, spatial, repeat, tensor, attr, grid, printf
from hidet.lang.cuda import blockIdx, threadIdx, syncthreads
from hidet.graph.ops.definitions.utils import input_like
from hidet.ir.expr import cast, address
from hidet.ir.primitives import exp, max, printf
from hidet import driver
from hidet.runtime.module import CompiledTaskCache

HIDET_BUILD_PATH = os.path.join(os.environ['HIDET_HOME'],"build/lib")
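# NOTE: HIDET_BUILD_PATH (the Hidet runtime build directory) is forwarded to the .cu test
# harness through the -DHIDET_BUILD_PATH define set in FlashAttentionTask.__init__ below.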

# define Flash Attention Task
class FlashAttentionTask(Task):
    
    def allow_epilogue(self) -> bool:
        return False

    def flash_attention_implement_cuda(self, working_dir: str) -> IRModule:
        # override this method to use template-based scheduling
        return flash_attention_schedule(self)
    
    # Require: matrices Q, K, V in HBM, each of size N x d; on-chip SRAM of size M.
    # NOTE: typical SRAM size 100 kB, default to 48 kB
    # NOTE: max thread num is set to 1024
    def __init__(self,N=512,d=128,H=16,B=1,M=48*1024,ratio=12,max_thread_num=1024,disable_flash_attention=False):

        # 1. set block sizes Bc = ceil(M/(ratio*d)), Br = min(ceil(M/(ratio*d)), d); the paper uses ratio = 4
        Bc = math.ceil(M/(ratio*d))
        Br = min(math.ceil(M/(ratio*d)),d)
        Tr = math.ceil(N/Br)
        Tc = math.ceil(N/Bc)
        GLOBAL_Q = input_like(hidet.randn([N, d], dtype='float16', device='cuda'),name='GLOBAL_Q')
        GLOBAL_K = input_like(hidet.randn([N, d], dtype='float16', device='cuda'),name='GLOBAL_K')
        GLOBAL_V = input_like(hidet.randn([N, d], dtype='float16', device='cuda'),name='GLOBAL_V')
        
        def normal_transformer():
            matmulQK = compute(
                    name = 'GLOBAL_QK',
                    shape = [N, N],
                    fcompute = lambda i, j: reduce(
                        shape=[d],
                        fcompute=lambda k: GLOBAL_Q[i, k] * GLOBAL_K[j, k],
                        reduce_type='sum',
                    )
                )

            max_val = lambda i : reduce(shape=[N], fcompute=lambda j: matmulQK[i,j], reduce_type='max')
            S = compute(
                    name = 'S',
                    shape = [N, N],
                    fcompute = lambda i,j: matmulQK[i,j] - max_val(i)
                )
            exp_s = compute(
                    name = 'exp_s',
                    shape = [N, N],
                    fcompute = lambda i,j: exp(S[i,j])
                )
            exp_sum = lambda i : reduce(shape=[N], fcompute=lambda j: exp_s[i,j], reduce_type='sum')
            softmax = compute('softmax', shape=[N,N], fcompute=lambda i,j: exp_s[i,j] / exp_sum(i))
            matmulPV = compute(
                    name = 'GLOBAL_O',
                    shape = [N, d],
                    fcompute = lambda i, j: reduce(
                        shape=[N],
                        fcompute=lambda k: softmax[i, k] * GLOBAL_V[k, j],
                        reduce_type='sum',
                    )
                )
            return matmulPV
        
        super().__init__(
            name='flash_attention_task',
            inputs=[GLOBAL_Q,GLOBAL_K,GLOBAL_V],
            outputs=[normal_transformer()],
            attributes={
                'B' : B,
                'H' : H,
                'N' : N,
                'd' : d,
                'Bc' : Bc,
                'Br' : Br,
                'Tc' : Tc,
                'Tr' : Tr,
                'BLK' : Tr,
                'THD' : Br * Bc,
                'MAX_THD' : max_thread_num
            },
        )
        if not disable_flash_attention:
            self.implement_cuda = self.flash_attention_implement_cuda
            self.define = f'-DRUN_FLASH_ATTN -DHIDET_BUILD_PATH=\\"{HIDET_BUILD_PATH}\\"'
        else:
            self.define = f'-DHIDET_BUILD_PATH=\\"{HIDET_BUILD_PATH}\\"'

# define custom schedule
def flash_attention_schedule(task:FlashAttentionTask) -> IRModule:
    
    print_debug = False

    B = task.attrs['B']
    H = task.attrs['H']
    N = task.attrs['N']
    d = task.attrs['d']
    Bc = task.attrs['Bc']
    Br = task.attrs['Br']
    Tr = task.attrs['Tr']
    Tc = task.attrs['Tc']

    dims = ( task.attrs['BLK'] )
    threads = task.attrs['THD']
    assert threads <= task.attrs['MAX_THD'], f'err: {threads} not <= {task.attrs["MAX_THD"]}'
    assert d % Bc == 0, f'err: d is not divisible by Bc'
    assert d % Br == 0, f'err: d is not divisible by Br'


    largest_fp16_value = 65504

    print(f'task.attrs {task.attrs}')
    
    
    # define the tensor program
    with hidet.script_module() as module:
        """Flash attention kernel."""

        @hidet.script
        def QK_matmul_compute(A:f16[Br,d],B:f16[d,Bc],C:f16[Br,Bc]):
            for m,n in spatial(Br,Bc).on(threadIdx.x):
                C[m,n] = 0.0
            syncthreads()
            for m,k,n in spatial(Br,1,Bc).repeat(1,d,1).on(threadIdx.x):   
                atomic_add(~C[m,n],A[m,k] * B[k,n])
            syncthreads()

        @hidet.script
        def PV_matmul_compute(A:f16[Br,Bc],B:f16[Bc,d],C:f16[Br,d]):
            for m,n in spatial(Br,Bc).repeat(1,d//Bc).on(threadIdx.x):
                C[m,n] = 0.0
            syncthreads()
            for m,k,n in spatial(Br,1,Bc).repeat(1,Bc,d//Bc).on(threadIdx.x):   
                atomic_add(~C[m,n],A[m,k] * B[k,n])
            syncthreads()

        @hidet.script
        def rowmax_compute(A:f16[Br,Bc],M:f16[Br],T:f16[Br,Bc]):
            for i,j in spatial(Br,Bc).on(threadIdx.x):
                T.write([i,j],A[i,j],protected=True)
            syncthreads()

            for i,j in spatial(Br,Bc).on(threadIdx.x):
                k = 1
                while k < Bc:
                    if j % (k*2) == 0:
                        T.write([i,j],max(T[i,j],T[i,j+k]),protected=True)
                    syncthreads()
                    k *= 2

            for i in spatial(Br).on(threadIdx.x):
                if threadIdx.x < Br:
                    M[i] = T[i,0]
            syncthreads()

        @hidet.script
        def rowsum_compute(A:f16[Br,Bc],L:f16[Br],T:f16[Br,Bc]):
            for i,j in spatial(Br,Bc).on(threadIdx.x):
                T.write([i,j],A[i,j],protected=True)
            syncthreads()

            for i,j in spatial(Br,Bc).on(threadIdx.x):
                k = 1
                while k < Bc:
                    if j % (k*2) == 0:
                        T.write([i,j],(T[i,j]+T[i,j+k]),protected=True)
                    syncthreads()
                    k *= 2

            for i in spatial(Br).on(threadIdx.x):
                if threadIdx.x < Br:
                    L[i] = T[i,0]
            syncthreads()

        @hidet.script
        def local_softmax_compute(S:f16[Br,Bc],M:f16[Br]):
            for i,j in spatial(Br,Bc).on(threadIdx.x):
                if False and blockIdx.x==0:
                    printf("S[i,j] before %d %d %d %d : %f - %f\n",blockIdx.x,threadIdx.x,i,j,cast(S[i,j],"float32"),cast(M[i],"float32"))
                S[i,j] = exp(S[i,j] - M[i])
                if False and blockIdx.x==0:
                    printf("S[i,j] %d %d %d %d : %f\n",blockIdx.x,threadIdx.x,i,j,cast(S[i,j],"float32"))
            syncthreads()
        
        @hidet.script
        def local_update_compute(M:f16[Br],M_new:f16[Br],M_local:f16[Br],L:f16[Br],L_new:f16[Br],L_local:f16[Br]):
            for i in spatial(Br).on(threadIdx.x):
                if threadIdx.x < Br:
                    M_new[i] = max(M[i],M_local[i])
                    L_new[i] = exp(M[i] - M_new[i]) * L[i] + exp(M_local[i] - M_new[i]) * L_local[i]
            syncthreads()

        @hidet.script
        def global_update_compute(PV:f16[Br,d],O:f16[Br,d],M_local:f16[Br],M_new:f16[Br],M:f16[Br],L_new:f16[Br],L:f16[Br]):
            for i,j in spatial(Br,Bc).repeat(1,(d//Bc)).on(threadIdx.x):
                O.write(
                    [i,j],
                    ((L_new[i]**-1) * (L[i]*exp(M[i]-M_new[i])) * O[i,j]) + (exp(M_local[i]-M_new[i]) * PV[i,j]),
                    protected=True
                )
            syncthreads()

        @hidet.script
        def flash_attention_kernel(
            Q: f16[N,d],
            K: f16[N,d],
            V: f16[N,d],
            O: f16[N,d]
        ):
            
            attr.cuda_grid_dim = dims
            attr.cuda_block_dim = threads

            # Init O=(0), N x d in HBM
            for i,j in spatial(Br,Bc).repeat(1,(d//Bc)).on(threadIdx.x):
                offset_i = blockIdx.x * (Br)
                O[offset_i:,:].write([i,j], 0, protected=True)
            syncthreads()

            smem_q = tensor('shared', 'float16', [Br, d])
            smem_k = tensor('shared', 'float16', [d, Bc]) # transposed
            smem_v = tensor('shared', 'float16', [Bc, d])
            smem_o = tensor('shared', 'float16', [Br, d])
            
            smem_l = tensor('shared', 'float16', [Br])
            smem_l_local = tensor('shared', 'float16', [Br])
            smem_l_new = tensor('shared', 'float16', [Br])
            smem_m = tensor('shared', 'float16', [Br])
            smem_m_local = tensor('shared', 'float16', [Br])
            smem_m_new = tensor('shared', 'float16', [Br])
            smem_sp = tensor('shared', 'float16', [Br,Bc])
            smem_pv = tensor('shared', 'float16', [Br,d])
            smem_temp = tensor('shared', 'float16', [Br,Bc])

            for a,b in spatial(Br,Bc).repeat(1,(d//Bc)).on(threadIdx.x):
                # load Qi from HBM to on-chip SRAM
                # initialization of o,l,m
                offset_i = blockIdx.x * (Br)
                smem_q[a,b] = Q[offset_i:,:].read([a,b],protected=True)
                smem_o[a,b] = 0
                smem_l[a] = 0
                smem_m[a] = -largest_fp16_value
            syncthreads()

            if print_debug and (blockIdx.x==0 and threadIdx.x==0):
                idx = 0
                for i,j in grid(Br,d):
                    printf("idx: %d, Q val: %f\n",idx,cast(smem_q[i,j],"float32"))
                    idx += 1
            syncthreads()

            for j in grid(Tc):

                for a,b in spatial(Bc,Br).repeat(1,(d//Br)).on(threadIdx.x):
                    # load Kj,Vj from HBM to on-chip SRAM
                    offset_j = j * (Bc)
                    smem_k[b,a] = K[offset_j:,:].read([a,b],protected=True)
                    smem_v[a,b] = V[offset_j:,:].read([a,b],protected=True)
                syncthreads()
                
                if print_debug and (blockIdx.x==0 and threadIdx.x==0):
                    idx = 0
                    for i,j in grid(d,Bc):
                        printf("idx: %d, K val: %f\n",idx,cast(smem_k[i,j],"float32"))
                        idx += 1
                    for i,j in grid(Bc,d):
                        printf("idx: %d, V val: %f\n",idx,cast(smem_v[i,j],"float32"))
                        idx += 1
                syncthreads()
                
                # on chip, compute Sij = Qi @ (Kj)^T, Br X Bc
                QK_matmul_compute(smem_q,smem_k,smem_sp)
                if print_debug and (blockIdx.x==0 and threadIdx.x==0):
                    idx = 0
                    for i,j in grid(Br,Bc):
                        printf("idx: %d, S val: %f\n",idx,cast(smem_sp[i,j],"float32"))
                        idx += 1
                syncthreads()
                
                # on chip, compute m'_ij = rowmax(Sij), Br; Pij = exp(Sij - m'_ij), Br x Bc (pointwise); l'_ij = rowsum(P'ij), Br
                rowmax_compute(smem_sp,smem_m_local,smem_temp)

                if print_debug and (blockIdx.x==0 and threadIdx.x==0):
                    for i in grid(Br):
                        printf("i: %d M val: %f\n",i,cast(smem_m_local[i],"float32"))
                        # for j in grid(Bc):
                        #     printf("j: %d, S val: %f\n",j,cast(smem_sp[i,j],"float32"))
                syncthreads()
                
                local_softmax_compute(smem_sp,smem_m_local)

                if print_debug and (blockIdx.x==0 and threadIdx.x==0):
                    idx = 0
                    for i,j in grid(Br,Bc):
                        printf("idx: %d, P val: %f\n",idx,cast(smem_sp[i,j],"float32"))
                        idx += 1
                syncthreads()


                rowsum_compute(smem_sp,smem_l_local,smem_temp)

                if print_debug and (blockIdx.x==0 and threadIdx.x==0):
                    for i in grid(Br):
                        printf("i: %d L val: %f\n",i,cast(smem_l_local[i],"float32"))
                        # for j in grid(Bc):
                        #     printf("j: %d, P val: %f\n",j,cast(smem_sp[i,j],"float32"))
                syncthreads()

                
                # on chip, compute m_new_i = max(m_i,m'_ij), Br; l_new_i = e^(m_i - m_new_i) * l_i + e^(m'_ij - m_i_new) * l'_ij, Br
                local_update_compute(smem_m,smem_m_new,smem_m_local,smem_l,smem_l_new,smem_l_local)
                if print_debug and (blockIdx.x==0 and threadIdx.x==0):
                    for i in grid(Br):
                        printf("i: %d smem_m val: %f\n",i,cast(smem_m[i],"float32"))
                        printf("i: %d smem_m_new val: %f\n",i,cast(smem_m_new[i],"float32"))
                        printf("i: %d smem_m_local val: %f\n",i,cast(smem_m_local[i],"float32"))
                        printf("i: %d smem_l val: %f\n",i,cast(smem_l[i],"float32"))
                        printf("i: %d smem_l_new val: %f\n",i,cast(smem_l_new[i],"float32"))
                        printf("i: %d smem_l_local val: %f\n",i,cast(smem_l_local[i],"float32"))
                syncthreads()
                # write Oi = diag(l_i_new)^-1 * (diag(l_i) * e^(m_i - m_i_new) @ Oi + e^(m'_ij - m_i_new) * (P'ij @ Vj))

                PV_matmul_compute(smem_sp,smem_v,smem_pv)
                if print_debug and (blockIdx.x==0 and threadIdx.x==0):
                    idx = 0
                    for i,j in grid(Br,d):
                        printf("idx: %d, PV val: %f\n",idx,cast(smem_pv[i,j],"float32"))
                        idx += 1
                syncthreads()

                global_update_compute(smem_pv,smem_o,smem_m_local,smem_m_new,smem_m,smem_l_new,smem_l)

                if j + 1 == Tc:
                    for i,j in spatial(Br,Bc).repeat(1,(d//Bc)).on(threadIdx.x):
                        offset_i = blockIdx.x * (Br)
                        O[offset_i:,:].write([i,j], smem_o[i,j], protected=True)
                    syncthreads()

                # write l_i = l_i_new, m_i = m_i_new
                for i in spatial(Br).on(threadIdx.x):
                    if threadIdx.x < Br:
                        smem_m[i] = smem_m_new[i]
                        smem_l[i] = smem_l_new[i]
                syncthreads()

            if print_debug and (blockIdx.x==15 and threadIdx.x==0):
                idx = 0
                for i,j in grid(Br,d):
                    offset_i = blockIdx.x * (Br)
                    printf("blockIdx %d : output idx: %d, val: %f\n",blockIdx.x,idx,cast(O[offset_i+i,j],"float32"))
                    idx += 1
            syncthreads()
            return

        @hidet.script
        def flash_attention_launch_func( 
            G_Q: f16[B, H, N, d],
            G_K: f16[B, H, N, d],
            G_V: f16[B, H, N, d],
            G_O: f16[B, H, N, d]
        ):
            # NOTE: this section needs to be written in flash_attention_main.cu
            for b,h in grid(B,H):
                flash_attention_kernel(
                    address(G_Q[b,h,0,0]),
                    address(G_K[b,h,0,0]),
                    address(G_V[b,h,0,0]),
                    address(G_O[b,h,0,0])
                )
            
    # build ir module
    ir_module = module.ir_module()
    return ir_module

# gen Python gold data as reference
def gen_gold(attrs,r1=-3,r2=3):

    Q = torch.FloatTensor(attrs['B'],attrs['H'],attrs['N'],attrs['d']).uniform_(r1, r2).half()
    K = torch.FloatTensor(attrs['B'],attrs['H'],attrs['N'],attrs['d']).uniform_(r1, r2).half()
    V = torch.FloatTensor(attrs['B'],attrs['H'],attrs['N'],attrs['d']).uniform_(r1, r2).half()
    t = time.process_time()
    Q.half().numpy().tofile('mat_Q.bin')
    K.half().numpy().tofile('mat_K.bin')
    V.half().numpy().tofile('mat_V.bin')
    S = torch.from_numpy(Q.numpy() @ torch.transpose(K, -2, -1).numpy())

    row_max, _ = torch.max(S,dim=-1)
    S = torch.from_numpy(np.exp((S - row_max.reshape(attrs['B'],attrs['H'],attrs['N'],1)).numpy()))
    row_sum = torch.sum(S,dim=-1).reshape(attrs['B'],attrs['H'],attrs['N'],1)
    P = S / row_sum

    # TODO: test with softmax float precision
    # P = nn.Softmax(dim=-1)(S.float()).half()
    O = torch.from_numpy(P.numpy() @ V.numpy())
    elapsed_time = (time.process_time() - t)*1000
    print(f"Python gold gen run elapsed time {round(elapsed_time,3)} msec")
    O.half().numpy().tofile('gold_mat_O.bin')

# run task
def run_task(disable_flash_attention=False):
    # clear cache
    driver.compiled_task_cache = CompiledTaskCache()
    # define the task here
    flash_attention_task = FlashAttentionTask(disable_flash_attention=disable_flash_attention)
    # build the task
    ret = flash_attention_task.build('cuda')

    # copy source file and lib to current directory
    source_path = ret.src_path
    library_path = ret.lib_path
    print(f'source_path {source_path} library_path {library_path}')

    import shutil
    shutil.move(source_path,os.path.join("./","flash_attention_"+os.path.basename(source_path)))
    shutil.move(library_path,os.path.join("./","flash_attention_"+os.path.basename(library_path)))

    # generate golden data
    gen_gold(flash_attention_task.attrs)

    def exe_f(command='', shell=True):
        print(f'running {command}')
        import subprocess
        process = subprocess.Popen(command, shell=shell)
        code = process.wait()
        process.communicate()
        return code
    
    # launch testcase flash_attention_main.cu
    HIDET_CUDA_INCLUDE_PATH = "../cuda-samples-master/Common/"
    CUDA_SAMPLES_INCLUDE_PATH = "../../include/"
    ret = exe_f(f'nvcc flash_attention_main.cu {flash_attention_task.define} -gencode arch=compute_86,code=sm_86 -I {CUDA_SAMPLES_INCLUDE_PATH} -I {HIDET_CUDA_INCLUDE_PATH} -std=c++11 -o fa.out && ./fa.out -BATCH={flash_attention_task.attrs["B"]} -HEAD={flash_attention_task.attrs["H"]} -BLK={flash_attention_task.attrs["BLK"]} -THD={flash_attention_task.attrs["THD"]}')
    print('test done' if ret==0 else 'test error')

# main function
if __name__ == '__main__':
    # normal approach execution
    run_task(disable_flash_attention=True)
    # flash attention approach execution
    run_task(disable_flash_attention=False)

And here is my flash_attention_main.cu, the host code that loads the generated library and calls hidet_launch:

// System includes
#include <stdio.h>
#include <sys/stat.h>
#include <dlfcn.h>
#include <cuda_fp16.h>
#include <cuda_bf16.h>
#include <assert.h>
#include <string>
#include <vector>
#include <fstream>

// CUDA runtime
#include <cuda_runtime.h>
#include <cuda_profiler_api.h>

// Helper functions and utilities to work with CUDA,
#include <helper_functions.h>
#include <helper_cuda.h>
#include <cuda_fp16.h>


typedef void (*hidet_launch_t)(int32_t num_args, int32_t * __restrict__ arg_types, void* * __restrict__ args);
typedef void (*set_cuda_stream_t)(void*);
typedef void (*register_callback_t)(const char* name, void *func_ptr);
// typedef uint8_t* (*cudaMallocCallback_t)(uint64_t nbytes);

void *cuda_lib_handle;
void *runtime_lib_handle;

hidet_launch_t hidet_launch_func;
void * allocate_cuda_storage_func;


uint8_t * cudaMallocCallback(uint64_t nbytes) {

    uint8_t * buffer;
    checkCudaErrors(cudaMalloc(reinterpret_cast<void **>(&buffer), nbytes));
    return buffer;
}
    

void set_kernel_func() {
    printf("setting kernel func ...\n");
    char *error;
    cuda_lib_handle = dlopen("./flash_attention_lib.so", RTLD_LAZY | RTLD_LOCAL);
    if (!cuda_lib_handle)
    {
        fprintf(stderr, "%s\n", dlerror());
        exit(EXIT_FAILURE);
    }

    dlerror();

    hidet_launch_func = (hidet_launch_t) dlsym(cuda_lib_handle, "hidet_launch");

    if ((error = dlerror()) != NULL)
    {
        fprintf(stderr, "%s\n", error);
        exit(EXIT_FAILURE);
    }
}

// #define STRINGIFY(x) #x
void setup_kernel_run(cudaStream_t & stream) {
    printf("setting up kernel run ...\n");
    char *error;

    std::string libpath = HIDET_BUILD_PATH "/libhidet_runtime.so";
    printf("hidet runtime libpath %s", libpath.c_str());
    runtime_lib_handle = dlopen(libpath.c_str(), RTLD_LAZY | RTLD_LOCAL);
    if (!runtime_lib_handle)
    {
        fprintf(stderr, "%s\n", dlerror());
        exit(EXIT_FAILURE);
    }

    dlerror();

    set_cuda_stream_t set_cuda_stream_func = (set_cuda_stream_t) dlsym(runtime_lib_handle, "set_cuda_stream");

    if ((error = dlerror()) != NULL)
    {
        fprintf(stderr, "%s\n", error);
        exit(EXIT_FAILURE);
    }
    // set stream
    set_cuda_stream_func(stream);


    register_callback_t register_callback_func = (register_callback_t) dlsym(runtime_lib_handle, "register_callback");
    if ((error = dlerror()) != NULL)
    {
        fprintf(stderr, "%s\n", error);
        exit(EXIT_FAILURE);
    }

    assert (register_callback_func!=nullptr);

    register_callback_func("allocate_cuda_storage", (void *) (cudaMallocCallback));

    register_callback_func("cuda_memset", (void *) (cudaMemset));
}



// set_cuda_stream

// test function, execute kernel, compare with gold data
int flash_attention_test(
    unsigned int B, unsigned int H,
    unsigned int block_size, unsigned int thread_size,
    half *h_Q, unsigned int size_Q,
    half *h_K, unsigned int size_K,
    half *h_V, unsigned int size_V,
    half *h_gold_O, unsigned int size_O)
{

    
    
    // set up kernel function
    set_kernel_func();
    assert(hidet_launch_func!=nullptr);
    
    cudaStream_t stream;
    const unsigned int BH = B * H;
    // Allocate device memory
    half *d_Q, *d_K, *d_V, *d_O, *h_O;
    checkCudaErrors(cudaMallocHost(&h_O, size_O * sizeof(half)));

    if (h_O == NULL)
    {
        fprintf(stderr, "Failed to allocate host matrix O!\n");
        exit(EXIT_FAILURE);
    }

    checkCudaErrors(cudaMalloc(reinterpret_cast<void **>(&d_Q), size_Q * sizeof(half)));
    checkCudaErrors(cudaMalloc(reinterpret_cast<void **>(&d_K), size_K * sizeof(half)));
    checkCudaErrors(cudaMalloc(reinterpret_cast<void **>(&d_V), size_V * sizeof(half)));
    checkCudaErrors(cudaMalloc(reinterpret_cast<void **>(&d_O), size_O * sizeof(half)));
    // Allocate CUDA events that we'll use for timing
    cudaEvent_t start, stop;
    checkCudaErrors(cudaEventCreate(&start));
    checkCudaErrors(cudaEventCreate(&stop));

    checkCudaErrors(cudaStreamCreateWithFlags(&stream, cudaStreamNonBlocking));

    // copy host memory to device
    checkCudaErrors(
        cudaMemcpyAsync(d_Q, h_Q, size_Q * sizeof(half), cudaMemcpyHostToDevice, stream));
    checkCudaErrors(
        cudaMemcpyAsync(d_K, h_K, size_K * sizeof(half), cudaMemcpyHostToDevice, stream));
    checkCudaErrors(
        cudaMemcpyAsync(d_V, h_V, size_V * sizeof(half), cudaMemcpyHostToDevice, stream));

    const unsigned int k_size_Q = (size_Q / BH);
    const unsigned int k_size_K = (size_K / BH);
    const unsigned int k_size_V = (size_V / BH);
    const unsigned int k_size_O = (size_O / BH);
    printf("k_size_Q %u k_size_K %u k_size_V %u k_size_O %u\n", k_size_Q, k_size_K, k_size_V, k_size_O);

    // prepare setup for kernel run
    setup_kernel_run(stream);

    // Record the start event
    checkCudaErrors(cudaEventRecord(start, stream));

    const int32_t num_args = 4;
    int32_t arg_types[num_args] = {3,3,3,3};

    for (unsigned int b = 0; b < B; b++)
    {
        for (unsigned int h = 0; h < H; h++)
        {
            unsigned int offset_index = (b * H) + h;

            half *param[num_args] = {
                d_Q + offset_index * k_size_Q,
                d_K + offset_index * k_size_K,
                d_V + offset_index * k_size_V,
                d_O + offset_index * k_size_O};

            void* args[num_args] = {param[0],param[1],param[2],param[3]};
            hidet_launch_func(num_args, arg_types, args);

        }
    }

    // stream sync
    checkCudaErrors(cudaStreamSynchronize(stream));

    // Record the stop event
    checkCudaErrors(cudaEventRecord(stop, stream));

    printf("test done !!!\n");

    // Wait for the stop event to complete
    checkCudaErrors(cudaEventSynchronize(stop));

    float msecTotal = 0.0f;
    checkCudaErrors(cudaEventElapsedTime(&msecTotal, start, stop));

    // Compute and print the performance
#if RUN_FLASH_ATTN
    printf("flash attention elapsed time = %.3f msec\n", msecTotal);
#else
    printf("normal approach elapsed time = %.3f msec\n", msecTotal);
#endif // RUN_FLASH_ATTN
    // Copy result from device to host
    checkCudaErrors(
        cudaMemcpyAsync(h_O, d_O, size_O * sizeof(half), cudaMemcpyDeviceToHost, stream));
    checkCudaErrors(cudaStreamSynchronize(stream));

    printf("Checking computed result for correctness: \n");

    double eps = 0.01; // 1% error with python output
    const unsigned int max_print_count = 100;
    uint32_t total_count = 0;
    uint32_t total_err_count = 0;
    for (int i = 0; i < static_cast<int>(size_O); i++)
    {
        double gold_val = fabs((double)h_gold_O[i]);
        double abs_val = fabs((double)h_O[i]);
        double abs_err = fabs(abs_val - gold_val);
        double rel_err = abs_err / abs_val;

        if (rel_err > eps)
        {
            if (total_err_count < max_print_count)
                printf("Error! Matrix[%05d]=%.8f, ref=%.8f error term %E is > %E\n",
                       i, (double)h_O[i], (double)h_gold_O[i], rel_err, eps);
            total_err_count++;
        }
        total_count++;
    }
    double error_ratio = (double)total_err_count / (double)total_count;
    bool correct = error_ratio < eps;
    printf("total count %u total error count %u (%.8f %%)\n", total_count, total_err_count, error_ratio * 100);
    printf("%s\n", correct ? "Result = PASS" : "Result = FAIL");

    // Clean up memory
    checkCudaErrors(cudaFree(d_Q));
    checkCudaErrors(cudaFree(d_K));
    checkCudaErrors(cudaFree(d_V));
    checkCudaErrors(cudaFree(d_O));
    checkCudaErrors(cudaEventDestroy(start));
    checkCudaErrors(cudaEventDestroy(stop));

    // close dynamic library
    if(cuda_lib_handle!=nullptr) dlclose(cuda_lib_handle);
    if(runtime_lib_handle!=nullptr) dlclose(runtime_lib_handle);

    if (correct)
    {
        return EXIT_SUCCESS;
    }
    else
    {
        return EXIT_FAILURE;
    }
}

inline bool file_exists(const std::string &name)
{
    struct stat buffer;
    return (stat(name.c_str(), &buffer) == 0);
}

void load_data(std::vector<half> &matrix, const std::string bin_file)
{
    printf("loading %s\n", bin_file.c_str());
    assert(file_exists(bin_file) && "Error! binary file doesn't exist");

    std::ifstream fin(bin_file, std::ios::binary);
    half elem;
    while (fin.read(reinterpret_cast<char *>(&elem), sizeof(half)))
    {
        matrix.push_back(elem);
    }
}

int main(int argc, char **argv)
{
    printf("[Flash Attention Using CUDA] - Starting...\n");

    if (checkCmdLineFlag(argc, (const char **)argv, "help") ||
        checkCmdLineFlag(argc, (const char **)argv, "?"))
    {

        printf("Usage -device=n (n >= 0 for deviceID)\n");
        printf("      -BATCH=number of Batch\n");
        printf("      -HEAD=number of Head\n");
        printf("      -BLK=block size\n");
        printf("      -THD=thread size\n");
        exit(EXIT_SUCCESS);
    }

    // This will pick the best possible CUDA capable device, otherwise
    // override the device ID based on input provided at the command line
    int dev = findCudaDevice(argc, (const char **)argv);

    unsigned int batch = 1;
    if (checkCmdLineFlag(argc, (const char **)argv, "BATCH"))
    {
        batch = getCmdLineArgumentInt(argc, (const char **)argv, "BATCH");
    }
    unsigned int head = 1;
    if (checkCmdLineFlag(argc, (const char **)argv, "HEAD"))
    {
        head = getCmdLineArgumentInt(argc, (const char **)argv, "HEAD");
    }
    unsigned int block_size = 1;
    if (checkCmdLineFlag(argc, (const char **)argv, "BLK"))
    {
        block_size = getCmdLineArgumentInt(argc, (const char **)argv, "BLK");
    }
    unsigned int thread_size = 1;
    if (checkCmdLineFlag(argc, (const char **)argv, "THD"))
    {
        thread_size = getCmdLineArgumentInt(argc, (const char **)argv, "THD");
    }

    // load Q
    std::vector<half> mat_Q;
    load_data(mat_Q, "./mat_Q.bin");

    // load K
    std::vector<half> mat_K;
    load_data(mat_K, "./mat_K.bin");

    // load V
    std::vector<half> mat_V;
    load_data(mat_V, "./mat_V.bin");

    // load golden data O
    std::vector<half> gold_mat_O;
    load_data(gold_mat_O, "./gold_mat_O.bin");

    printf("batch %u head %u block_size %u thread_size %u\n", batch, head, block_size, thread_size);

    printf("Q size %lu K size %lu V size %lu O size %lu\n", mat_Q.size(), mat_K.size(), mat_V.size(), gold_mat_O.size());

    checkCudaErrors(cudaProfilerStart());
    int result = flash_attention_test(
        batch, head, block_size, thread_size,
        &mat_Q[0], mat_Q.size(),
        &mat_K[0], mat_K.size(),
        &mat_V[0], mat_V.size(),
        &gold_mat_O[0], gold_mat_O.size());
    checkCudaErrors(cudaProfilerStop());

    exit(result);
}

And I have some follow-up questions:
i) To use approach 2), is it necessary (or the intended way) for users to register the CUDA memory-allocation functions etc. themselves? For example, I wrote a cudaMallocCallback function myself instead of using cudaMalloc directly, so that it fits the callback function signature, and I am not sure whether that is the right approach.
ii) Is there a standard/suggested way to compare the CUDA result against golden data when using Hidet? The purpose is to check whether the scheduling is correct.
iii) With DeepView https://github.com/CentML/DeepView.Profile#getting-started (it seems it is developed by your team as well?), can I estimate the performance gain from compiling with Hidet on different GPU versions, without access to the actual chip?

Thank you very much again for your help 🙏

Hi @keneoneth,

Glad you were able to run the generated lib.so successfully.

For your questions:
i) Here is the dependency graph of hidet generated libraries:
(image: dependency graph of the Hidet-generated libraries)
hidet_runtime.so is the shared library linked by all Hidet-generated libraries. It is intentionally designed to avoid depending on any backend-specific library (such as the CUDA driver library libcuda.so and the CUDA runtime library libcudart.so). Instead, it relies on callback functions to provide the backend-specific functionality. For the CUDA backend, the host system (in the common case our Python interpreter, in your case your C++ host program) must register these callback functions to supply the functionality normally provided by the backend libraries. Currently we do not have an official C++ host system, which is why you find it hard to use the Hidet-generated library from C++. We welcome community contributions on this part.
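
For reference, here is a minimal sketch of the registration a C++ host could perform. It only uses the callback names that already appear in your flash_attention_main.cu ("allocate_cuda_storage", "cuda_memset") plus set_cuda_stream; the full set of callbacks expected by hidet_runtime.so may differ across versions, so treat this as an illustration rather than an official C++ API.

// Sketch only: assumes libhidet_runtime.so has already been opened with dlopen()
// and exposes register_callback / set_cuda_stream with the signatures used above.
#include <dlfcn.h>
#include <cstdio>
#include <cstdint>
#include <cuda_runtime.h>

typedef void (*register_callback_t)(const char *name, void *func_ptr);
typedef void (*set_cuda_stream_t)(void *stream);

static uint8_t *allocate_cuda_storage_cb(uint64_t nbytes) {
    uint8_t *buffer = nullptr;
    cudaMalloc(reinterpret_cast<void **>(&buffer), nbytes);  // error handling omitted for brevity
    return buffer;
}

bool register_hidet_callbacks(void *runtime_handle, cudaStream_t stream) {
    register_callback_t register_callback = (register_callback_t) dlsym(runtime_handle, "register_callback");
    set_cuda_stream_t set_cuda_stream = (set_cuda_stream_t) dlsym(runtime_handle, "set_cuda_stream");
    if (register_callback == nullptr || set_cuda_stream == nullptr) {
        fprintf(stderr, "dlsym failed: %s\n", dlerror());
        return false;
    }
    set_cuda_stream(stream);  // stream the generated kernels will be launched on
    register_callback("allocate_cuda_storage", (void *) allocate_cuda_storage_cb);
    register_callback("cuda_memset", (void *) cudaMemset);
    return true;
}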

ii) In C++, no. In Python, please refer to our tests (https://github.com/hidet-org/hidet/tree/main/tests/operators) to see how we compare the results with numpy/pytorch.
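
As a rough illustration (not the exact pattern from the hidet test suite), the comparison can be done with numpy.testing.assert_allclose; here hidet_output is assumed to already be a host-side numpy array holding the kernel result, and the reference is recomputed in torch in the same way as your gen_gold above:

import numpy as np
import torch

def check_against_torch(hidet_output: np.ndarray, Q, K, V, rtol=1e-2, atol=1e-2):
    # recompute the attention output in torch as the reference (same math as gen_gold)
    S = Q.float() @ K.float().transpose(-2, -1)
    P = torch.softmax(S, dim=-1)
    ref = (P @ V.float()).half().numpy()
    # raises an AssertionError listing the mismatching elements if tolerances are exceeded
    np.testing.assert_allclose(hidet_output, ref, rtol=rtol, atol=atol)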

iii) DeepView is, for now, designed to profile and predict the performance of training workloads, and it is built on top of PyTorch. You still need to run the kernel on the target GPU to know its actual performance. That being said, the idea behind DeepView also applies to Hidet.