microsoft/BitBLAS

performance of float16 with fast tuning

klxy0304 opened this issue · 0 comments

Hello,
I tried to run fast tuning on a float16 GEMM:
```python
from bitblas.base.roller.policy import TensorCorePolicy, DefaultPolicy
from bitblas.base.arch import CUDA
from bitblas.base.utils import apply_and_build
from bitblas.gpu.matmul_analysis import get_tensorized_func_and_tags
import tvm
from tvm.script import tir as T

# GEMM shape: C[M, N] = A[M, K] @ B[N, K]^T (NT layout), all float16.
M = 8
N = 152064
K = 3584


@tvm.script.ir_module
class MatmulNT:
    @T.prim_func
    def main(a: T.handle, b: T.handle, c: T.handle):
        T.func_attr({"global_symbol": "main", "tir.noalias": True})
        A = T.match_buffer(a, [M, K], dtype="float16")
        B = T.match_buffer(b, [N, K], dtype="float16")
        C = T.match_buffer(c, [M, N], dtype="float16")

        for i, j, k in T.grid(M, N, K):
            with T.block("B"):
                vi, vj, vk = T.axis.remap("SSR", [i, j, k])
                with T.init():
                    C[vi, vj] = tvm.tir.const(0, "float16")
                C[vi, vj] = C[vi, vj] + A[vi, vk].astype("float16") * B[
                    vj, vk
                ].astype("float16")


ir_module = MatmulNT
func = ir_module["main"]
target = tvm.target.Target("nvidia/nvidia-a100")
arch = CUDA(target)

# Tune with SIMT CUDA cores by default.
policy = DefaultPolicy(func=func, arch=arch)
try:
    tensorized_func, tags = get_tensorized_func_and_tags(func, arch.target)
except Exception:
    tags = None

# Tune with Tensor Cores if the workload can be tensorized.
if tags:
    policy = TensorCorePolicy(func=tensorized_func, arch=arch, tags=tags)

configs = policy.emit_config(topk=20)

cpresults, best = apply_and_build(func, configs, arch, parallel_build=True)

print("[BitBLAS] The best latency of top 1 is {:.3f} ms".format(cpresults[0].latency))
print("[BitBLAS] The best latency of top 20 is {:.3f} ms".format(best.latency))
```

But the results are not what I expected:
[BitBLAS] The best latency of top 1 is 11.767 ms
[BitBLAS] The best latency of top 20 is 5.987 ms
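
To put those numbers in perspective, here is the rough arithmetic I use to convert latency into throughput for this shape (assuming 2\*M\*N\*K FLOPs per GEMM call):

```python
# Rough throughput implied by the measured BitBLAS latencies for M=8, N=152064, K=3584.
M, N, K = 8, 152064, 3584
flops = 2 * M * N * K  # ~8.72 GFLOP per GEMM call

for name, latency_ms in [("top 1", 11.767), ("best of top 20", 5.987)]:
    tflops = flops / (latency_ms * 1e-3) / 1e12
    print(f"{name}: {latency_ms:.3f} ms -> {tflops:.2f} TFLOPS")
# top 1: ~0.74 TFLOPS, best of top 20: ~1.46 TFLOPS
```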

For comparison, I tuned a single-layer model (nn.Linear(3584, 152064) with batch size 8) using TVM's Meta Schedule; a rough sketch of that setup follows the table below. Here are the tuning log results:

| ID | Name | FLOP | Weight | Speed (GFLOPS) | Latency (us) | Weighted Latency (us) | Trials | Done |
|----|------|------|--------|----------------|--------------|-----------------------|--------|------|
| 0 | fused_nn_dense_add | 8721174528 | 1 | 13285.4769 | 656.4442 | 656.4442 | 1535 | |
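
The sketch below shows roughly how this kind of Meta Schedule run can be set up; the work_dir and trial budget are placeholders, I assume float16 to match the BitBLAS workload, and I rebuild the nn.Linear workload directly in Relay (dense followed by a broadcast add of the bias, which is what fuses into fused_nn_dense_add) instead of importing it from PyTorch:

```python
import tvm
from tvm import relay, meta_schedule as ms

# nn.Linear(3584, 152064) at batch size 8, rebuilt as a Relay workload:
# dense followed by a broadcast add of the bias.
data = relay.var("data", shape=(8, 3584), dtype="float16")
weight = relay.var("weight", shape=(152064, 3584), dtype="float16")
bias = relay.var("bias", shape=(152064,), dtype="float16")
out = relay.add(relay.nn.dense(data, weight), bias)
mod = tvm.IRModule.from_expr(relay.Function([data, weight, bias], out))

target = tvm.target.Target("nvidia/nvidia-a100")
database = ms.relay_integration.tune_relay(
    mod=mod,
    params={},
    target=target,
    work_dir="./ms_work_dir",   # placeholder path
    max_trials_global=2000,     # placeholder trial budget
)
lib = ms.relay_integration.compile_relay(database, mod, target, params={})
```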

The Meta Schedule result is about 656 us, so I would like to know whether I am using the BitBLAS tuning method incorrectly.