Post-Scheduling Fusion with TensorCore
SunflowerAries opened this issue · 2 comments
Hello, I just read your Hidet paper, and it looks pretty powerful, but I have a few questions.
Does Hidet support codegen for Tensor Cores right now? And can it generate code for fused operators like batchmatmul+add+reshape+transpose? This fusion pattern comes from BERT, and I want to fuse these operators into one kernel to speed up network execution.
Below is my test script. Although I've set the precision and mma flags, I do not find wmma instructions in the generated CUDA code. How can I run this operator on Tensor Cores? Hope for your help, thanks!
```python
import hidet

# change the cache directory
hidet.option.cache_dir('./outs/cache')
# save the tensor program level ir in operator cache
hidet.option.save_lower_ir()


def main():
    # construct a simple graph
    x = hidet.symbol([16, 256, 512], device='cuda')
    w = hidet.randn([16, 512, 512], device='cuda')
    b = hidet.randn([512], device='cuda')
    x = hidet.ops.batch_matmul(x, w)
    x = x + b
    x = hidet.ops.reshape(x, [16, 256, 8, 64])
    x = hidet.ops.transpose(x, [0, 2, 1, 3])
    # x = hidet.ops.pad(x, [3, 3, 3, 3])
    # x = hidet.ops.conv2d(x, w, stride=2)
    # x = hidet.ops.relu(x)
    graph = hidet.trace_from(x)
    print(graph)

    # graph optimizations
    with hidet.graph.PassContext() as ctx:
        # save the computation graph level ir
        ctx.save_graph_instrument(out_dir='./outs/graphs')
        ctx.set_precision(dtype='float16')
        ctx.set_reduce_precision(dtype='float32')
        ctx.set_mma('mma')
        graph_opt = hidet.graph.optimize(graph)

    # run the optimized graph
    xx = hidet.randn([16, 256, 512], device='cuda')
    yy = graph_opt(xx)


if __name__ == '__main__':
    main()
```
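For reference, the same batchmatmul+add+reshape+transpose pattern can be written in plain NumPy to see what it computes and the expected output shape (a sketch independent of hidet; this is the usual multi-head projection reshaping from BERT):

```python
import numpy as np

# same shapes as the hidet script above
x = np.random.randn(16, 256, 512).astype(np.float32)
w = np.random.randn(16, 512, 512).astype(np.float32)
b = np.random.randn(512).astype(np.float32)

y = np.matmul(x, w) + b           # batched matmul + bias: [16, 256, 512]
y = y.reshape(16, 256, 8, 64)     # split the last axis into 8 heads of size 64
y = y.transpose(0, 2, 1, 3)       # move heads forward: [16, 8, 256, 64]

print(y.shape)  # (16, 8, 256, 64)
```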
Hi @SunflowerAries,
Thanks for your interest in hidet.
Short answer: use `hidet.ops.matmul` instead of `hidet.ops.batch_matmul`:
```python
import hidet

# change the cache directory
hidet.option.cache_dir('./outs/cache')
# save the tensor program level ir in operator cache
hidet.option.save_lower_ir()


def main():
    # construct a simple graph
    x = hidet.symbol([16, 256, 512], device='cuda')
    w = hidet.randn([16, 512, 512], device='cuda')
    b = hidet.randn([512], device='cuda')
    x = hidet.ops.matmul(x, w)
    x = x + b
    x = hidet.ops.reshape(x, [16, 256, 8, 64])
    x = hidet.ops.transpose(x, [0, 2, 1, 3])
    # x = hidet.ops.pad(x, [3, 3, 3, 3])
    # x = hidet.ops.conv2d(x, w, stride=2)
    # x = hidet.ops.relu(x)
    graph = hidet.trace_from(x)
    print(graph)

    # graph optimizations
    with hidet.graph.PassContext() as ctx:
        # save the computation graph level ir
        ctx.save_graph_instrument(out_dir='./outs/graphs')
        ctx.set_precision(dtype='float16')
        ctx.set_reduce_precision(dtype='float32')
        ctx.set_mma('mma')
        graph_opt = hidet.graph.optimize(graph)

    # run the optimized graph
    xx = hidet.randn([16, 256, 512], device='cuda')
    yy = graph_opt(xx)


if __name__ == '__main__':
    main()
```
During optimization, `hidet.ops.matmul` will be resolved to either `hidet.ops.batch_matmul` (which uses CUDA cores and supports all data types) or `matmul_fp16_pk` (defined in `hidet.ops.matmul.matmul_fp16`, which uses Tensor Cores).
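One quick way to confirm the Tensor Core path was actually taken is to scan the generated CUDA sources in the cache directory for `mma` instructions. A minimal sketch using only the standard library (the cache layout and file extensions may differ between hidet versions, so treat the `'./outs/cache'` path and `.cu` suffix as assumptions):

```python
import os


def find_mma_kernels(cache_dir):
    """Return paths of generated .cu files that contain tensor-core mma instructions."""
    hits = []
    for root, _dirs, files in os.walk(cache_dir):
        for name in files:
            if name.endswith('.cu'):
                path = os.path.join(root, name)
                with open(path, 'r', errors='ignore') as f:
                    if 'mma.sync' in f.read():
                        hits.append(path)
    return hits


# e.g. after running the script above:
# for path in find_mma_kernels('./outs/cache'):
#     print(path)
```

If the list is empty, the matmul was likely resolved to the CUDA-core `batch_matmul` variant instead.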
OK, thanks for your help. I successfully ran this fused kernel on Tensor Cores.