astramind-ai/BitMat

Error with specific matrix sizes

There is a peculiar problem: our BMM kernel runs fine unless the matmul is between tensors of specific shapes.

We mask loads and stores to prevent illegal memory accesses, but for certain dimensions the kernel still hits one.
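For context, the kernel guards its loads and stores with the standard Triton boundary mask. A minimal sketch of that idiom (this is not the actual BMM kernel, just the pattern we use):

import torch
import triton
import triton.language as tl

@triton.jit
def masked_copy_kernel(x_ptr, out_ptr, n_elements, BLOCK: tl.constexpr):
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offsets < n_elements  # lanes past the end of the tensor are masked off
    x = tl.load(x_ptr + offsets, mask=mask, other=0)
    tl.store(out_ptr + offsets, x, mask=mask)

x = torch.arange(1000, device="cuda", dtype=torch.float32)
out = torch.empty_like(x)
masked_copy_kernel[(triton.cdiv(x.numel(), 256),)](x, out, x.numel(), BLOCK=256)

If the equivalent mask in the BMM kernel is computed wrongly for some block/shape combination, out-of-bounds lanes could slip through, which would match what we see below.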
Here is a simple reproducer:

import torch
# Import paths below are assumptions about the package layout; adjust as needed.
from bitmat.utils import pack_ternary
from bitmat.triton_kernels.bitmat_kernel import batched_bitmat

x = torch.randint(-128, 128, (1, 128, 4096), dtype=torch.int8).cuda()
# Note: randint's upper bound is exclusive, so this draws from {-1, 0}.
w = torch.randint(-1, 1, [4096 * 4, 4096], dtype=torch.int8).cuda()

packed_w = pack_ternary(w, 4)
c = batched_bitmat(x, packed_w, 4)
matmul = x.to(torch.float16) @ w.to(torch.float16).t()
assert (c != matmul).sum() == 0

This crashes with an illegal memory access, but perturbing the output dimension by one:

w = torch.randint(-1, 1, [4096 * 4 + 1, 4096], dtype=torch.int8).cuda()  # or 4096 * 4 - 1

works fine.
Here are some further variations that make the kernel work (a shape-sweep sketch follows them):

x = torch.randint(-128, 128, (2, 128, 4096), dtype=torch.int8).cuda()  # anything with batch dim > 1
w = torch.randint(-1, 1, [4096 * 4, 4096], dtype=torch.int8).cuda()
w = torch.randint(-1, 1, [4096, 4096], dtype=torch.int8).cuda()
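To narrow this down, here is a diagnostic sketch that reruns the reproducer over a few output dimensions, one subprocess per shape (an illegal memory access poisons the CUDA context, so an in-process try/except is not reliable). The import paths in the template are the same assumptions as above; running a single failing shape under compute-sanitizer should pinpoint the exact out-of-bounds access.

import subprocess
import sys

# Template reproducer; import paths are assumptions about the BitMat layout.
REPRO = """
import torch
from bitmat.utils import pack_ternary
from bitmat.triton_kernels.bitmat_kernel import batched_bitmat

x = torch.randint(-128, 128, (1, 128, 4096), dtype=torch.int8).cuda()
w = torch.randint(-1, 1, [{n}, 4096], dtype=torch.int8).cuda()
c = batched_bitmat(x, pack_ternary(w, 4), 4)
torch.cuda.synchronize()  # force the async kernel error to surface here
"""

for n in [4096, 4096 * 4 - 1, 4096 * 4, 4096 * 4 + 1]:
    r = subprocess.run([sys.executable, "-c", REPRO.format(n=n)], capture_output=True)
    print(f"N = {n}: {'ok' if r.returncode == 0 else 'CRASH'}")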

Further investigation is needed.