Error with specific matrix sizes
Closed this issue · 1 comments
mlinmg commented
There is a peculiar problem:
our BMM kernel runs fine unless the matmul is between specifically shaped tensors.
We apply masking to prevent illegal memory accesses, but for certain dimensions the kernel always hits an illegal memory access anyway.
Here is a simple reproducer:
```python
import torch

x = torch.randint(-128, 128, (1, 128, 4096), dtype=torch.int8).cuda()
w = torch.randint(-1, 1, [4096 * 4, 4096], dtype=torch.int8).cuda()

packed_w = pack_ternary(w, 4)
c = batched_bitmat(x, packed_w, 4)

# reference result in fp16
matmul = x.to(torch.float16) @ w.to(torch.float16).t()
assert (c != matmul).sum() == 0
```
This will not work, but

```python
w = torch.randint(-1, 1, [4096 * 4 +/- 1, 4096], dtype=torch.int8).cuda()
```

will work.
Here are some variations that make the kernel work:

```python
x = torch.randint(-128, 128, (2, 128, 4096), dtype=torch.int8).cuda()  # anything with dim=0 > 1
w = torch.randint(-1, 1, [4096 * 4, 4096], dtype=torch.int8).cuda()

w = torch.randint(-1, 1, [4096, 4096], dtype=torch.int8).cuda()
```
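The failure pattern (exact multiples of a large power of two break, off-by-one shapes work) is consistent with a grid-size off-by-one. This is only a hypothesis, not the confirmed bug; the `BLOCK` constant and the grid helpers below are illustrative assumptions, not code from the actual kernel:

```python
# Hypothetical sketch: how an off-by-one grid computation launches an
# extra (out-of-range) block exactly when the dimension is a multiple
# of the block size. Names here are assumptions for illustration only.

def grid_buggy(n, block):
    # Launches one block too many whenever n % block == 0.
    return n // block + 1

def grid_ok(n, block):
    # Correct ceil-division grid size.
    return (n + block - 1) // block

BLOCK = 128  # assumed tile size

n_fail = 4096 * 4      # failing shape from the report
n_work = 4096 * 4 + 1  # "working" off-by-one variant

print(grid_buggy(n_fail, BLOCK), grid_ok(n_fail, BLOCK))  # 129 128
print(grid_buggy(n_work, BLOCK), grid_ok(n_work, BLOCK))  # 129 129
```

If the extra block's offsets are then masked against the wrong bound, all of its loads fall outside the tensor, which would surface as an illegal memory access only at exact multiples of the block size.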
Further investigation is needed.
GiacomoLeoneMaria commented
Fixed in commit fix issues #7