microsoft/BitBLAS

Any example to use `float16xfp4_e2m1` matmul?


Hi @LeiWang1999,

I encounter the following error when I try to build a matmul with `a_dtype = 'float16'` and `b_dtype = 'fp4_e2m1'` in bitblas:

tvm.error.InternalError: Traceback (most recent call last):
  1: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::runtime::TypedPackedFunc<tvm::PrimExpr (tvm::runtime::DataType const&, tvm::PrimExpr, tvm::Span)>::AssignTypedLambda<tvm::PrimExpr (*)(tvm::runtime::DataType const&, tvm::PrimExpr, tvm::Span)>(tvm::PrimExpr (*)(tvm::runtime::DataType const&, tvm::PrimExpr, tvm::Span), std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}> >::Call(tvm::runtime::PackedFuncObj const*, tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)
  0: tvm::reinterpret(tvm::runtime::DataType const&, tvm::PrimExpr, tvm::Span)
  File "/root/BitBLAS/3rdparty/tvm/src/tir/op/op.cc", line 415
InternalError: Check failed: (value.dtype().bits() * value.dtype().lanes() == t.bits() * t.lanes()) is false: Bitcast requires size match float16 vs uint32

which is caused by the following Python function:

def _tir_u32_to_f4_to_f16(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, dtype: str):
    assert nbit == 4
    assert dtype == "float16"
    assert val.dtype == "uint32"
    # e_f4 == 0 -> e_f16 = 0
    # e_f4 != 0 -> e_f16 = e_f4 + 8 = e_f4 | (1000)_2
    mask = tvm.tir.const((1 << nbit) - 1, "uint32")
    f4 = (val >> (pos.astype("uint32") * tir.const(nbit, "uint32"))) & mask
    s = f4 >> tir.const(3, "uint32")
    e_f4 = f4 & tir.const(7, "uint32")
    e_f16 = e_f4 | tir.const(8, "uint32")
    val_f16 = tir.reinterpret("float16",
                              (e_f16 | (s << tir.const(5, "uint32"))) << tir.const(10, "uint32"))
    return tir.Select(e_f4 == tir.const(0, "uint32"), tir.const(0, "float16"), val_f16)

At this line

    val_f16 = tir.reinterpret("float16",
                              (e_f16 | (s << tir.const(5, "uint32"))) << tir.const(10, "uint32"))

it seems bitblas is trying to reinterpret a "uint32" as a "float16", which makes TVM complain.
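To double-check the diagnosis, here is a quick pure-Python trace of the same bit arithmetic on one sample nibble (the input value and the numpy round-trip are just for illustration):

# decode one 4-bit code the way the TIR helper does, assuming the nibble
# sits at position 0 of the packed word
import numpy as np

val, pos, nbit = 0b0110, 0, 4
f4 = (val >> (pos * nbit)) & 0xF           # extract the 4-bit code
s = f4 >> 3                                # sign bit
e_f4 = f4 & 7                              # remaining 3 bits
bits16 = ((e_f4 | 8) | (s << 5)) << 10     # assemble the fp16 bit pattern
print(np.uint16(bits16).view(np.float16))  # prints 0.5 for this input

In plain Python the width is pinned by the final np.uint16 cast; in the TIR version the whole expression stays "uint32", so the reinterpret is a 32-bit-to-16-bit bitcast and TVM rejects it.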

I am using the following matmul config.

config = MatmulConfig(
    M=4096,
    N=4096,
    K=4096,
    A_dtype="float16",
    W_dtype="fp4_e2m1",
    out_dtype="float16",
    group_size=128,
    accum_dtype="float32",
    with_scaling=True,
    with_zeros=True,
    zeros_mode="original",
    storage_dtype="uint32",
)
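
For completeness, this is roughly how I intended to drive the operator once it builds, following the quickstart-style examples (the random tensors are placeholders, and I have omitted the scale/zeros handling that with_scaling/with_zeros would need):

import bitblas
import torch

# construction is where the error above is raised
matmul = bitblas.Matmul(config=config)

# placeholder fp16 activation and weight; transform_weight quantizes and
# packs the weight into the storage layout expected by the kernel
A = torch.rand((4096, 4096), dtype=torch.float16).cuda()
W = torch.rand((4096, 4096), dtype=torch.float16).cuda()
W_packed = matmul.transform_weight(W)
C = matmul(A, W_packed)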

Do you have an example of running the fp16xfp4 matmul in bitblas?

(I am using version v0.0.1.dev15.)

@yaoyaoding Thanks for reporting. Yeah, I think it should be:

def _tir_u32_to_f4_to_f16(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, dtype: str):
    assert nbit == 4
    assert dtype == "float16"
    assert val.dtype == "uint32"
    # e_f4 == 0 -> e_f16 = 0
    # e_f4 != 0 -> e_f16 = e_f4 + 8 = e_f4 | (1000)_2
    mask = tvm.tir.const((1 << nbit) - 1, "uint16")
    f4 = (val >> (pos.astype("uint16") * tir.const(nbit, "uint16"))) & mask
    s = f4 >> tir.const(3, "uint16")
    e_f4 = f4 & tir.const(7, "uint16")
    e_f16 = e_f4 | tir.const(8, "uint16")
    val_f16 = tir.reinterpret("float16",
                              ((e_f16 | (s << tir.const(5, "uint16"))) << tir.const(10, "uint16")).astype("uint16"))
    return tir.Select(e_f4 == tir.const(0, "uint16"), tir.const(0, "float16"), val_f16)

The extra `.astype("uint16")` narrows the expression so the reinterpret sees matching 16-bit widths. And we should extend this to a general cast to make it compatible with any storage dtype.
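
A rough sketch of that generalization could look like the following (untested; the helper name is hypothetical, and I assume the packed values are addressed within a single unsigned storage word, with `tir` in scope as in the helpers above):

def _tir_packed_to_f4_to_f16(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, dtype: str):
    # hypothetical generalization: extract the 4-bit code in whatever unsigned
    # storage dtype `val` carries, then narrow to uint16 before the bit-level
    # reinterpret so the bitcast widths always match
    assert nbit == 4
    assert dtype == "float16"
    sdtype = val.dtype  # e.g. "uint8", "uint16", "uint32"
    mask = tir.const((1 << nbit) - 1, sdtype)
    f4 = ((val >> (pos.astype(sdtype) * tir.const(nbit, sdtype))) & mask).astype("uint16")
    s = f4 >> tir.const(3, "uint16")
    e_f4 = f4 & tir.const(7, "uint16")
    e_f16 = e_f4 | tir.const(8, "uint16")
    val_f16 = tir.reinterpret(
        "float16",
        (e_f16 | (s << tir.const(5, "uint16"))) << tir.const(10, "uint16"))
    return tir.Select(e_f4 == tir.const(0, "uint16"), tir.const(0, "float16"), val_f16)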

During benchmarking, we use nf4 for our fp4 benchmarks.

Got it, thanks @LeiWang1999 for the timely response and fix!