astramind-ai/BitMat

Kernel readability and correctness


I'm trying to reproduce your results.

The kernel's docstring below makes a lot of sense:

@triton.jit
def _ternary_mm_kernel(
    # ... kernel arguments elided in this excerpt ...
):
    """Kernel for computing the matmul C = A x B.
    A has shape (M, K), float
    B has shape (K//n_bits, N), int, packed boolean
    ...
    """
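For reference, the contract in that docstring can be checked against a plain NumPy computation. This is a minimal sketch, not the repo's implementation: unpack_ternary_k and ternary_mm_reference are hypothetical names, and the 2-bit code (stored value = x + 1, four values per uint8, value i of each group in bits 2i..2i+1) is an assumption about the packing format.

```python
import numpy as np


def unpack_ternary_k(b_packed: np.ndarray, n: int = 4) -> np.ndarray:
    """Unpack B from (K // n, N) uint8 to (K, N) int8 ternary values.

    ASSUMPTION: each value is stored as the 2-bit code x + 1
    (-1 -> 0, 0 -> 1, 1 -> 2), and logical row r * n + i lives in
    bits 2i..2i+1 of packed row r.
    """
    k_packed, n_cols = b_packed.shape
    out = np.empty((k_packed * n, n_cols), dtype=np.int8)
    for i in range(n):
        # Rows i, i + n, i + 2n, ... come from bit field i of each packed row.
        out[i::n, :] = ((b_packed >> (2 * i)) & 0b11).astype(np.int8) - 1
    return out


def ternary_mm_reference(a: np.ndarray, b_packed: np.ndarray, n: int = 4) -> np.ndarray:
    """Compute C = A x B with A of shape (M, K) and packed B of shape (K // n, N)."""
    _, k = a.shape
    if b_packed.shape[0] * n != k:
        raise ValueError(f"B has {b_packed.shape[0]} packed rows, expected K // n = {k // n}")
    b = unpack_ternary_k(b_packed, n).astype(a.dtype)
    return a @ b
```

For example, the column B = (1, -1, 0, 1) with K = 4 packs into the single byte 2 + (0 << 2) + (1 << 4) + (2 << 6) = 146 under this assumed code, and multiplying A = (1, 2, 3, 4) by it gives 1 - 2 + 0 + 4 = 3.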

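Question: the bitmat docstring below describes a as a float tensor of shape (M, K // n_bits), so the float operand carries the packed dimension, while b is an int tensor of shape (K, N). I assume you switched the datatypes around? But even then the dimensions don't match: the inner dimensions would be K // n_bits and K.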

def bitmat(a, b, int_per_2_bits=4, activation=""):
    """
        a: float tensor (M, K // n_bits)
        b: int tensor (K, N)
        n_bits: int, number of bits that each element in b represents
    """

Lastly, pack_ternary currently packs into (K, N // n_element_in_one_int), but it seems the kernel above wants (K // n_element_in_one_int, N):

def pack_ternary(x, n_element_in_one_int=4):
    """
    Pack ternary values into integers.
    x: tensor of shape (*, K, N)
    n_element_in_one_int: int, number of elements in one integer
    return: tensor of shape (*, K, N // n_element_in_one_int)
    """

Yes, this is an old version. You're right; I'll push a fix.