Any example to use `float16xfp4_e2m1` matmul?
Hi @LeiWang1999,

I encounter the following error when I try to build a matmul with `a_dtype='float16'` and `b_dtype='fp4_e2m1'` in BitBLAS:
```
tvm.error.InternalError: Traceback (most recent call last):
  1: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::runtime::TypedPackedFunc<tvm::PrimExpr (tvm::runtime::DataType const&, tvm::PrimExpr, tvm::Span)>::AssignTypedLambda<tvm::PrimExpr (*)(tvm::runtime::DataType const&, tvm::PrimExpr, tvm::Span)>(tvm::PrimExpr (*)(tvm::runtime::DataType const&, tvm::PrimExpr, tvm::Span), std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}> >::Call(tvm::runtime::PackedFuncObj const*, tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)
  0: tvm::reinterpret(tvm::runtime::DataType const&, tvm::PrimExpr, tvm::Span)
  File "/root/BitBLAS/3rdparty/tvm/src/tir/op/op.cc", line 415
InternalError: Check failed: (value.dtype().bits() * value.dtype().lanes() == t.bits() * t.lanes()) is false: Bitcast requires size match float16 vs uint32
```
The error is caused by the following Python function:
```python
def _tir_u32_to_f4_to_f16(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, dtype: str):
    assert nbit == 4
    assert dtype == "float16"
    assert val.dtype == "uint32"
    # e_f4 == 0 -> e_f16 = 0
    # e_f4 != 0 -> e_f16 = e_f4 + 8 = e_f4 | (1000)_2
    mask = tvm.tir.const((1 << nbit) - 1, "uint32")
    f4 = (val >> (pos.astype("uint32") * tir.const(nbit, "uint32"))) & mask
    s = f4 >> tir.const(3, "uint32")
    e_f4 = f4 & tir.const(7, "uint32")
    e_f16 = e_f4 | tir.const(8, "uint32")
    val_f16 = tir.reinterpret("float16",
                              (e_f16 | (s << tir.const(5, "uint32"))) << tir.const(10, "uint32"))
    return tir.Select(e_f4 == tir.const(0, "uint32"), tir.const(0, "float16"), val_f16)
```
At this line:

```python
val_f16 = tir.reinterpret("float16",
                          (e_f16 | (s << tir.const(5, "uint32"))) << tir.const(10, "uint32"))
```

it seems BitBLAS is trying to reinterpret a `uint32` as a `float16`, which makes TVM complain since the bit widths (32 vs. 16) do not match.
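For reference, here is the same decode written in plain NumPy (my own sketch to check the bit layout, not BitBLAS code); it only works because the assembled pattern is narrowed to 16 bits before viewing it as `float16`:

```python
import numpy as np

def u32_to_f4_to_f16_ref(val: int, pos: int) -> np.float16:
    """Reference decode of the 4-bit field at `pos` in a packed uint32."""
    f4 = (val >> (pos * 4)) & 0xF  # extract one fp4_e2m1 nibble
    s = f4 >> 3                    # sign bit
    e_f4 = f4 & 7                  # exponent/mantissa payload bits
    if e_f4 == 0:
        return np.float16(0.0)     # e_f4 == 0 decodes to zero
    e_f16 = e_f4 | 8               # rebias into the fp16 exponent field
    bits = np.uint16((e_f16 | (s << 5)) << 10)  # a 16-bit pattern, unlike the TIR above
    return bits.view(np.float16)   # bitcast uint16 -> float16
```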
I am using the following matmul config:

```python
config = MatmulConfig(
    M=4096,
    N=4096,
    K=4096,
    A_dtype="float16",
    W_dtype="fp4_e2m1",
    out_dtype="float16",
    group_size=128,
    accum_dtype="float32",
    with_scaling=True,
    with_zeros=True,
    zeros_mode="original",
    storage_dtype="uint32",
)
```
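For completeness, here is roughly how I build the kernel (a minimal driver following the usage from the BitBLAS README; the error above is raised during this construction):

```python
import bitblas
from bitblas import MatmulConfig

# Constructing the operator triggers kernel compilation, which is
# where the InternalError above is thrown.
matmul = bitblas.Matmul(config=config)
```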
Do you have any example of running an fp16xfp4 matmul in BitBLAS?
(I am using version v0.0.1.dev15.)
@yaoyaoding Thanks for your report, yeah I think it should be:
```python
def _tir_u32_to_f4_to_f16(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, dtype: str):
    assert nbit == 4
    assert dtype == "float16"
    assert val.dtype == "uint32"
    # e_f4 == 0 -> e_f16 = 0
    # e_f4 != 0 -> e_f16 = e_f4 + 8 = e_f4 | (1000)_2
    mask = tvm.tir.const((1 << nbit) - 1, "uint16")
    f4 = (val >> (pos.astype("uint16") * tir.const(nbit, "uint16"))) & mask
    s = f4 >> tir.const(3, "uint16")
    e_f4 = f4 & tir.const(7, "uint16")
    e_f16 = e_f4 | tir.const(8, "uint16")
    val_f16 = tir.reinterpret("float16",
                              ((e_f16 | (s << tir.const(5, "uint16"))) << tir.const(10, "uint16")).astype("uint16"))
    return tir.Select(e_f4 == tir.const(0, "uint16"), tir.const(0, "float16"), val_f16)
```
And we should extend this to a general cast to make it compatible with any storage dtype.
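One possible shape for that generalization (just a sketch, not the actual implementation: it derives the constants from `val.dtype` and widens the extracted field to `uint16` before assembling the fp16 bits, so narrow storage types like `uint8` cannot overflow):

```python
from tvm import tir

def _tir_packed_to_f4_to_f16(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, dtype: str):
    assert nbit == 4
    assert dtype == "float16"
    storage_dtype = val.dtype  # e.g. "uint8", "uint16", or "uint32"
    mask = tir.const((1 << nbit) - 1, storage_dtype)
    # Extract the 4-bit field in the storage dtype, then widen to uint16 so the
    # shifts below cannot overflow a narrow storage type such as uint8.
    f4 = ((val >> (pos.astype(storage_dtype) * tir.const(nbit, storage_dtype))) & mask).astype("uint16")
    s = f4 >> tir.const(3, "uint16")
    e_f4 = f4 & tir.const(7, "uint16")
    e_f16 = e_f4 | tir.const(8, "uint16")
    # The assembled pattern is exactly 16 bits, so the bitcast sizes match.
    val_f16 = tir.reinterpret(
        "float16", (e_f16 | (s << tir.const(5, "uint16"))) << tir.const(10, "uint16"))
    return tir.Select(e_f4 == tir.const(0, "uint16"), tir.const(0, "float16"), val_f16)
```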
During benchmarking, we use nf4 for our fp4 benchmarks.
Got it, thanks @LeiWang1999 for the timely response and fix!