pku-liang/AMOS

gemm_bench fails to run with problem sizes (2, 2, 2), (4, 4, 4), and (8, 8, 8)

Opened this issue · 0 comments

Hi there, I'm currently benchmarking GEMM performance of AMOS on Tensor Cores. I modified mapping_gemm_tensorcore.py as follows:

import tvm
import os
from tvm import auto_tensorize as at
import argparse


def gemm(M, N, K, in_dtype, out_dtype):
    """Describe C[M, N] = sum_k A[M, k] * B[k, N] as a TVM tensor-expression DAG.

    Each product is cast to ``out_dtype`` before being summed over the
    reduction axis ``k``.

    Returns:
        ``[A, B, C]``: the two ``in_dtype`` placeholders and the output tensor.
    """
    lhs = tvm.te.placeholder([M, K], dtype=in_dtype, name="A")
    rhs = tvm.te.placeholder([K, N], dtype=in_dtype, name="B")
    red_k = tvm.te.reduce_axis([0, K], name="k")

    # Argument names ``i`` / ``j`` are kept so the compute axes get the same names.
    def _body(i, j):
        product = (lhs[i, red_k] * rhs[red_k, j]).astype(out_dtype)
        return tvm.te.sum(product, axis=red_k)

    out = tvm.te.compute([M, N], _body, name="C")
    return [lhs, rhs, out]


def mapping_tensorcore(
    M,
    N,
    K,
    layer,
    in_dtype,
    out_dtype,
    simple_mode=True,
    trials=-1,
    verbose=False,
    use_perf_model=False,
    perf_model_ratio=0.6,
):
    """Auto-tensorize an (M, N, K) GEMM for CUDA Tensor Cores and measure it.

    Args:
        M, N, K: GEMM problem size, C[M, N] = sum_k A[M, k] * B[k, N].
        layer: Tag embedded in the log-file and log-directory names.
        in_dtype: Element type of the inputs A and B.
        out_dtype: Element type of the output C.
        simple_mode: If truthy, run ``at.auto_tensorize`` and read the best
            record back from ``log_file``. Otherwise run
            ``at.auto_tensorize_v4`` and take the best params from its result.
        trials: Search budget. A negative value selects the default
            (1000 in simple mode, 4000 otherwise).
        verbose: Forwarded to the search (``verbose`` / ``transform_dump``)
            and to ``at.evaluate_params`` as ``dump``.
        use_perf_model: Forwarded as ``enable_perf_model`` (non-simple mode only).
        perf_model_ratio: Forwarded as ``perf_percentage`` (non-simple mode only).

    Returns:
        The cost from ``at.evaluate_params`` (printed as milliseconds), or
        ``None`` if no tensorization was found.

    NOTE(review): exceptions raised by the AMOS search are not caught here.
    Problem sizes smaller than every matched intrinsic shape (e.g. the WMMA
    16x16x16 / 32x8x16 / 8x32x16 shapes for float16) may be reported as an
    infeasible mapping and raise. Padding to the intrinsic tile size is a
    likely workaround -- TODO confirm against AMOS.
    """
    _, _, gemm_out = gemm(M, N, K, in_dtype, out_dtype)
    target_dag = at.compute_dag_from_tensors([gemm_out])
    target = "cuda"

    log_dir = "gemm-%s-%s-layer-%s" % (in_dtype, out_dtype, layer)
    log_file = "gemm-%s-%s-layer-%s.log" % (in_dtype, out_dtype, layer)

    measure_opt = at.MeasureOptions(target=target, timeout=100, number=200, min_repeat_ms=500)

    if simple_mode:
        trials = 1000 if trials < 0 else trials
        result = at.auto_tensorize(
            target_dag, target, log_file, measure_opt, trials=trials, verbose=verbose
        )
    else:
        trials = 4000 if trials < 0 else trials
        result = at.auto_tensorize_v4(
            target_dag,
            target,
            log_file,
            measure_opt,
            schedule_log_dir=log_dir,
            trials=trials,
            search_group_size=5,
            transform_dump=verbose,
            enable_perf_model=use_perf_model,
            perf_percentage=perf_model_ratio,
        )

    if not result.defined():
        print("Can't do tensorize.")
        return None

    schedule_app = result.sch_app
    if simple_mode:
        # Reload the search log and take its best record. The file stores
        # 1/time_cost, so invert the value back into a time.
        schedule_gen = result.sch_gen
        schedule_gen.load_from_file(log_file, clear=True)
        entry = schedule_gen.get_best_entry()
        params, value = entry.record, 1 / entry.value
    else:
        # The v4 result exposes the best params and their perf value directly.
        params, value = result.params, result.perf
    print(value)
    print(params.to_json())

    cost = at.evaluate_params(schedule_app, params, measure_opt, dump=verbose)
    print("Cost of %s is %f ms" % (log_dir, cost))
    return cost


# GEMM problem sizes (M, N, K) to benchmark; --begin/--num select a slice of this list.
# NOTE(review): sizes smaller than the matched Tensor Core intrinsic shape
# (e.g. 16x16x16 WMMA tiles for float16) may fail with an infeasible mapping.
shapes = [(16, 512, 128), (1024, 16, 256), (256, 1024, 256), (512, 256, 16), (1024, 1024, 1024)]

# (in_dtype, out_dtype) pairs this script accepts for Tensor Core mapping.
supported_dtypes = {
    ("float16", "float16"),
    ("float16", "float32"),
    ("bfloat16", "float32"),
    ("float32", "float32"),
    ("float64", "float64"),
    ("int4", "int32"),
    ("int8", "int32"),
}

# Usage examples shown in the --help epilog.
example_text = """
 example:
    python mapping_gemm_tensorcore.py --in_dtype float16 --out_dtype float16 --begin 0 --num 1 --trials 20
    python mapping_gemm_tensorcore.py --in_dtype float16 --out_dtype float32 --begin 0 --num 1 --trials 20
    python mapping_gemm_tensorcore.py --in_dtype float32 --out_dtype float32 --begin 0 --num 1 --trials 20
    python mapping_gemm_tensorcore.py --in_dtype float16 --out_dtype float16 --begin 0 --num 1 --trials 400 --simple_mode 0
"""

if __name__ == "__main__":
    import traceback

    parser = argparse.ArgumentParser(
        prog="base_maker",
        description="template maker",
        epilog=example_text,
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    parser.add_argument(
        "--in_dtype",
        type=str,
        choices=["float16", "float32", "float64", "bfloat16", "int4", "int8"],
        default="float16",
    )
    parser.add_argument(
        "--out_dtype",
        type=str,
        choices=["float16", "float32", "float64", "int32"],
        default="float16",
    )
    parser.add_argument("--begin", type=int, choices=list(range(len(shapes))), default=0)
    parser.add_argument(
        "--num", type=int, choices=list(range(1, len(shapes) + 1)), default=len(shapes)
    )
    parser.add_argument("--simple_mode", type=int, default=1, choices=[0, 1])
    parser.add_argument("--trials", type=int, default=-1)
    parser.add_argument("--verbose", action="store_true")
    parser.add_argument("--use_perf_model", action="store_true")
    parser.add_argument("--perf_model_ratio", type=float, default=0.6)

    args = parser.parse_args()
    # Validate with parser.error rather than assert: asserts are stripped
    # under `python -O`, which would silently skip these checks.
    if not 0 < args.perf_model_ratio <= 1.0:
        parser.error("--perf_model_ratio must be in (0, 1]")
    if args.use_perf_model and args.simple_mode != 0:
        parser.error("Performance model is only supported without simple_mode")
    if (args.in_dtype, args.out_dtype) not in supported_dtypes:
        parser.error(
            f"The desired dtype pair {(args.in_dtype, args.out_dtype)} is not supported by Tensor Core."
        )
    beg = args.begin
    num = args.num

    costs = []
    for M, N, K in shapes[beg : beg + num]:
        print("\n\nProblem size:")
        print(M, N, K)
        layer_name = f"({M}, {N}, {K})"
        try:
            cost = mapping_tensorcore(
                M,
                N,
                K,
                layer_name,
                args.in_dtype,
                args.out_dtype,
                simple_mode=bool(args.simple_mode),
                trials=args.trials,
                verbose=args.verbose,
                use_perf_model=args.use_perf_model,
                perf_model_ratio=args.perf_model_ratio,
            )
        except Exception as e:
            # The exception message can be empty (e.g. for infeasible
            # mappings), so also report the type and full traceback;
            # otherwise a failure is indistinguishable from any other.
            print("Fail to run\n", f"{type(e).__name__}: {e}")
            traceback.print_exc()
            cost = float("inf")
        if cost is None:
            # mapping_tensorcore returns None when no tensorization exists;
            # record it as a failure, the same way as an exception.
            cost = float("inf")
        costs.append(cost)
    for cost in costs:
        print(cost)

The only change is the customized shapes, but the benchmark throws a "Fail to run" error:

0


Problem size:
2 2 2
Possible matchings:
0 : MatchResult(hw_abs_dag:wmma_fp16_fp16, compute:nnn, shape:16x16x16)
1 : MatchResult(hw_abs_dag:wmma_fp16_fp16, compute:nnn, shape:32x8x16)
2 : MatchResult(hw_abs_dag:wmma_fp16_fp16, compute:nnn, shape:8x32x16)
3 : MatchResult(hw_abs_dag:wmma_fp16_fp16, compute:ntn, shape:16x16x16)
4 : MatchResult(hw_abs_dag:wmma_fp16_fp16, compute:ntn, shape:32x8x16)
5 : MatchResult(hw_abs_dag:wmma_fp16_fp16, compute:ntn, shape:8x32x16)
6 : MatchResult(hw_abs_dag:wmma_fp16_fp16, compute:tnn, shape:16x16x16)
7 : MatchResult(hw_abs_dag:wmma_fp16_fp16, compute:tnn, shape:32x8x16)
8 : MatchResult(hw_abs_dag:wmma_fp16_fp16, compute:tnn, shape:8x32x16)
9 : MatchResult(hw_abs_dag:wmma_fp16_fp16, compute:ttn, shape:16x16x16)
10 : MatchResult(hw_abs_dag:wmma_fp16_fp16, compute:ttn, shape:32x8x16)
11 : MatchResult(hw_abs_dag:wmma_fp16_fp16, compute:ttn, shape:8x32x16)
Logging to devnull...
Totally 1 different mappings for this matching
Logging to devnull...
Totally 1 different mappings for this matching
Catch an infeasible mapping:
{"vmap": [[1], -1]}
Fail to run



Problem size:
4 4 4
Possible matchings:
0 : MatchResult(hw_abs_dag:wmma_fp16_fp16, compute:nnn, shape:16x16x16)
1 : MatchResult(hw_abs_dag:wmma_fp16_fp16, compute:nnn, shape:32x8x16)
2 : MatchResult(hw_abs_dag:wmma_fp16_fp16, compute:nnn, shape:8x32x16)
3 : MatchResult(hw_abs_dag:wmma_fp16_fp16, compute:ntn, shape:16x16x16)
4 : MatchResult(hw_abs_dag:wmma_fp16_fp16, compute:ntn, shape:32x8x16)
5 : MatchResult(hw_abs_dag:wmma_fp16_fp16, compute:ntn, shape:8x32x16)
6 : MatchResult(hw_abs_dag:wmma_fp16_fp16, compute:tnn, shape:16x16x16)
7 : MatchResult(hw_abs_dag:wmma_fp16_fp16, compute:tnn, shape:32x8x16)
8 : MatchResult(hw_abs_dag:wmma_fp16_fp16, compute:tnn, shape:8x32x16)
9 : MatchResult(hw_abs_dag:wmma_fp16_fp16, compute:ttn, shape:16x16x16)
10 : MatchResult(hw_abs_dag:wmma_fp16_fp16, compute:ttn, shape:32x8x16)
11 : MatchResult(hw_abs_dag:wmma_fp16_fp16, compute:ttn, shape:8x32x16)
Logging to devnull...
Totally 1 different mappings for this matching
Logging to devnull...
Totally 1 different mappings for this matching
Catch an infeasible mapping:
{"vmap": [[1], -1]}
Fail to run



Problem size:
8 8 8
Possible matchings:
0 : MatchResult(hw_abs_dag:wmma_fp16_fp16, compute:nnn, shape:16x16x16)
1 : MatchResult(hw_abs_dag:wmma_fp16_fp16, compute:nnn, shape:32x8x16)
2 : MatchResult(hw_abs_dag:wmma_fp16_fp16, compute:nnn, shape:8x32x16)
3 : MatchResult(hw_abs_dag:wmma_fp16_fp16, compute:ntn, shape:16x16x16)
4 : MatchResult(hw_abs_dag:wmma_fp16_fp16, compute:ntn, shape:32x8x16)
5 : MatchResult(hw_abs_dag:wmma_fp16_fp16, compute:ntn, shape:8x32x16)
6 : MatchResult(hw_abs_dag:wmma_fp16_fp16, compute:tnn, shape:16x16x16)
7 : MatchResult(hw_abs_dag:wmma_fp16_fp16, compute:tnn, shape:32x8x16)
8 : MatchResult(hw_abs_dag:wmma_fp16_fp16, compute:tnn, shape:8x32x16)
9 : MatchResult(hw_abs_dag:wmma_fp16_fp16, compute:ttn, shape:16x16x16)
10 : MatchResult(hw_abs_dag:wmma_fp16_fp16, compute:ttn, shape:32x8x16)
11 : MatchResult(hw_abs_dag:wmma_fp16_fp16, compute:ttn, shape:8x32x16)
Logging to devnull...
Totally 1 different mappings for this matching
Logging to devnull...
Totally 1 different mappings for this matching
Catch an infeasible mapping:
{"vmap": [[1], -1]}
Fail to run

However, (16, 16, 16) runs fine. Any suggestions?