performance of float16 with fast tuning
klxy0304 opened this issue · 0 comments
Hello,
I tried to run fast tuning of a GEMM with float16:
```python
from bitblas.base.roller.policy import TensorCorePolicy, DefaultPolicy
from bitblas.base.arch import CUDA
from bitblas.base.utils import apply_and_build
# Needed for the Tensor Core path below; in my install this lives in
# bitblas.gpu.matmul_analysis.
from bitblas.gpu.matmul_analysis import get_tensorized_func_and_tags
import tvm
from tvm.script import tir as T

M = 8
N = 152064
K = 3584

@tvm.script.ir_module
class MatmulNT:
    @T.prim_func
    def main(a: T.handle, b: T.handle, c: T.handle):
        T.func_attr({"global_symbol": "main", "tir.noalias": True})
        A = T.match_buffer(a, [M, K], dtype="float16")
        B = T.match_buffer(b, [N, K], dtype="float16")
        C = T.match_buffer(c, [M, N], dtype="float16")

        # C = A @ B^T, accumulated in float16
        for i, j, k in T.grid(M, N, K):
            with T.block("B"):
                vi, vj, vk = T.axis.remap("SSR", [i, j, k])
                with T.init():
                    C[vi, vj] = tvm.tir.const(0, "float16")
                C[vi, vj] = C[vi, vj] + A[vi, vk].astype("float16") * B[
                    vj, vk
                ].astype("float16")

ir_module = MatmulNT
func = ir_module["main"]
target = tvm.target.Target("nvidia/nvidia-a100")
arch = CUDA(target)

# Tune with SIMT CUDA Core
policy = DefaultPolicy(func=func, arch=arch)
try:
    tensorized_func, tags = get_tensorized_func_and_tags(func, arch.target)
except Exception:
    tags = None

# Tune with Tensor Core if possible
if tags:
    policy = TensorCorePolicy(func=tensorized_func, arch=arch, tags=tags)

configs = policy.emit_config(topk=20)
cpresults, best = apply_and_build(func, configs, arch, parallel_build=True)
print("[BitBLAS] The best latency of top 1 is {:.3f} ms".format(cpresults[0].latency))
print("[BitBLAS] The best latency of top 20 is {:.3f} ms".format(best.latency))
```
But the results are not what I expected:

```
[BitBLAS] The best latency of top 1 is 11.767 ms
[BitBLAS] The best latency of top 20 is 5.987 ms
```
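As a sanity check on my side (simple arithmetic, assuming `2*M*N*K` FLOPs for this GEMM shape), those latencies correspond to throughput far below what an A100 should reach on float16:

```python
# Back-of-the-envelope throughput check.
# Assumption: FLOPs = 2 * M * N * K for the GEMM above (one multiply + one add per element).
M, N, K = 8, 152064, 3584
flops = 2 * M * N * K  # ~8.72e9 FLOPs

for label, latency_ms in [("top 1", 11.767), ("top 20", 5.987)]:
    tflops = flops / (latency_ms * 1e-3) / 1e12
    print(f"{label}: {latency_ms} ms -> {tflops:.2f} TFLOPS")
# top 1: 11.767 ms -> 0.74 TFLOPS
# top 20: 5.987 ms -> 1.46 TFLOPS
```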
For comparison, I tuned a single-layer model with TVM's Meta Schedule; the model is nn.Linear(3584, 152064) with a batch size of 8. Below are the tuning log results:
| ID | Name | FLOP | Weight | Speed (GFLOPS) | Latency (us) | Weighted Latency (us) | Trials | Done |
| --- | --- | --- | --- | --- | --- | --- | --- | --- |
| 0 | fused_nn_dense_add | 8721174528 | 1 | 13285.4769 | 656.4442 | 656.4442 | 1535 | |
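For reference, the FLOP count in that log is consistent with the fused dense + bias workload: `2*M*N*K + M*N = 2*8*152064*3584 + 8*152064 = 8721174528`, and 8721174528 FLOPs / 656.4 us is about 13.3 TFLOPS, which matches the Speed column. That is roughly 9x the throughput of the BitBLAS top-20 result above.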
The Meta Schedule result is 656 us, so I would like to know whether I am using the BitBLAS tuning method incorrectly.
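In case it matters for the answer: is the higher-level operator API the intended entry point for tuning a shape like this? A minimal sketch of what I mean, based on my reading of the BitBLAS README (the exact `MatmulConfig` fields here are my assumption):

```python
import bitblas

# Hypothetical usage sketch based on my reading of the BitBLAS README;
# the exact MatmulConfig fields are my assumption, not verified.
matmul_config = bitblas.MatmulConfig(
    M=8, N=152064, K=3584,
    A_dtype="float16", W_dtype="float16",
    accum_dtype="float16", out_dtype="float16",
    layout="nt", with_bias=False,
)
matmul = bitblas.Matmul(config=matmul_config)
matmul.hardware_aware_finetune(topk=20)  # same topk as in the script above
```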