Kernel readability and correctness
Closed this issue · 1 comment
michaelfeil commented
Trying to reproduce your results.
This annotation below makes a lot of sense.
@triton.jit
def _ternary_mm_kernel(
"""Kernel for computing the matmul C = A x B.
A has shape (M, K), float
B has shape (K//n_bits, N), int, packed boolean
"""
Question: is A meant to be a packed float tensor, i.e. a float tensor of shape (M, K // n_bits)?
I assume you switched the datatypes around? But then the dimensions don't make sense.
def bitmat(a, b, int_per_2_bits=4, activation=""):
"""
a: float tensor (M, K // n_bits)
b: int tensor (K, N)
n_bits: int, number of bits that each element in b represents
"""
Lastly, pack_ternary currently packs into (K, N // n_element_in_one_int), but it seems like the above kernel wants (K // n_element_in_one_int, N).
def pack_ternary(x, n_element_in_one_int=4):
"""
Pack ternary values into integers.
x: tensor of shape (*, K, N)
n_element_in_one_int: int, number of elements in one integer
return: tensor of shape (*, K, N // n_element_in_one_int)
"""
mlinmg commented
Yes, this is an old version, you're right, I'll push a fix.