hidet-org/hidet

Post-Scheduling Fusion with TensorCore

SunflowerAries opened this issue · 2 comments

Hello, I just read your Hidet paper, and it looks pretty powerful, but I have a few questions.

Does Hidet currently support code generation for Tensor Cores? And can it generate code for fused operators like batch_matmul + add + reshape + transpose? This fused pattern comes from BERT, and I want to fuse these operators into one complex operator to speed up network execution.

Below is my test script. Although I've set the precision and mma flags, I cannot find any wmma instructions in the generated CUDA code. How can I run this operator on Tensor Cores? Hoping for your help, thanks!

import hidet

# change the cache directory
hidet.option.cache_dir('./outs/cache')

# save the tensor program level ir in operator cache
hidet.option.save_lower_ir()


def main():
    # construct a simple graph
    x = hidet.symbol([16, 256, 512], device='cuda')
    w = hidet.randn([16, 512, 512], device='cuda')
    b = hidet.randn([512], device='cuda')
    x = hidet.ops.batch_matmul(x, w)
    x = x + b
    x = hidet.ops.reshape(x, [16, 256, 8, 64])
    x = hidet.ops.transpose(x, [0, 2, 1, 3])
    
    # x = hidet.ops.pad(x, [3, 3, 3, 3])
    # x = hidet.ops.conv2d(x, w, stride=2)
    # x = hidet.ops.relu(x)
    
    graph = hidet.trace_from(x)
    print(graph)

    # graph optimizations
    with hidet.graph.PassContext() as ctx:
        # save the computation graph level ir
        ctx.save_graph_instrument(out_dir='./outs/graphs')
        ctx.set_precision(dtype='float16')
        ctx.set_reduce_precision(dtype='float32')
        ctx.set_mma('mma')
        graph_opt = hidet.graph.optimize(graph)

    # run the optimized graph
    xx = hidet.randn([16, 256, 512], device='cuda')
    yy = graph_opt(xx)


if __name__ == '__main__':
    main()

Hi @SunflowerAries,

Thanks for your interest in hidet.

Short answer: use hidet.ops.matmul instead of hidet.ops.batch_matmul:

import hidet

# change the cache directory
hidet.option.cache_dir('./outs/cache')

# save the tensor program level ir in operator cache
hidet.option.save_lower_ir()


def main():
    # construct a simple graph
    x = hidet.symbol([16, 256, 512], device='cuda')
    w = hidet.randn([16, 512, 512], device='cuda')
    b = hidet.randn([512], device='cuda')
    x = hidet.ops.matmul(x, w)  # changed from batch_matmul: matmul can be resolved to a tensor-core kernel
    x = x + b
    x = hidet.ops.reshape(x, [16, 256, 8, 64])
    x = hidet.ops.transpose(x, [0, 2, 1, 3])
    
    # x = hidet.ops.pad(x, [3, 3, 3, 3])
    # x = hidet.ops.conv2d(x, w, stride=2)
    # x = hidet.ops.relu(x)
    
    graph = hidet.trace_from(x)
    print(graph)

    # graph optimizations
    with hidet.graph.PassContext() as ctx:
        # save the computation graph level ir
        ctx.save_graph_instrument(out_dir='./outs/graphs')
        ctx.set_precision(dtype='float16')
        ctx.set_reduce_precision(dtype='float32')
        ctx.set_mma('mma')
        graph_opt = hidet.graph.optimize(graph)

    # run the optimized graph
    xx = hidet.randn([16, 256, 512], device='cuda')
    yy = graph_opt(xx)


if __name__ == '__main__':
    main()

hidet.ops.matmul will be resolved during optimization to either hidet.ops.batch_matmul (which uses CUDA cores and supports all data types) or matmul_fp16_pk (defined in hidet.ops.matmul.matmul_fp16, which uses Tensor Cores).
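If you want to verify which kernel the resolve pass picked, a minimal sketch (reusing the same options as the script above; the exact operator names and cache layout may vary across hidet versions) is to print the optimized graph and read the operator names:

import hidet

# dump the lowered IR and generated CUDA source into the cache
hidet.option.cache_dir('./outs/cache')
hidet.option.save_lower_ir()

x = hidet.symbol([16, 256, 512], device='cuda')
w = hidet.randn([16, 512, 512], device='cuda')
graph = hidet.trace_from(hidet.ops.matmul(x, w))

with hidet.graph.PassContext() as ctx:
    ctx.set_precision(dtype='float16')
    ctx.set_mma('mma')
    graph_opt = hidet.graph.optimize(graph)

# the operator names in the printed graph show what matmul was
# resolved to (an fp16 tensor-core kernel vs. batch_matmul)
print(graph_opt)

One more note: with ctx.set_mma('mma'), the generated code should use mma rather than wmma PTX instructions, so searching the saved CUDA source under ./outs/cache for 'mma' (not 'wmma') is the right check.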

OK, thanks for your help. I successfully ran this fused kernel on Tensor Cores.