[BUG] Vectorized bias add with atomicAdd may lead to correctness issues
Closed this issue · 3 comments
LeiWang1999 commented
When split-K is enabled, the generated epilogue first adds the bias to the accumulator tile in shared memory and then accumulates the tile into C with atomicAdd:
#pragma unroll
for (int i_10 = 0; i_10 < 4; ++i_10) {
__syncthreads();
uint4 __1;
uint4 v_ = *(uint4*)(((half_t*)buf_dyn_shmem) + (((i_10 * 1024) + (((int)threadIdx.x) * 8)) + 3072));
uint4 v__1 = *(uint4*)(Bias + (((((int)blockIdx.x) * 64) + ((((int)threadIdx.x) >> 5) * 16)) + ((((int)threadIdx.x) & 1) * 8)));
((half2*)(&(__1.x)))->x = (((half2*)(&(v_.x)))->x+((half2*)(&(v__1.x)))->x);
((half2*)(&(__1.x)))->y = (((half2*)(&(v_.x)))->y+((half2*)(&(v__1.x)))->y);
((half2*)(&(__1.y)))->x = (((half2*)(&(v_.y)))->x+((half2*)(&(v__1.y)))->x);
((half2*)(&(__1.y)))->y = (((half2*)(&(v_.y)))->y+((half2*)(&(v__1.y)))->y);
((half2*)(&(__1.z)))->x = (((half2*)(&(v_.z)))->x+((half2*)(&(v__1.z)))->x);
((half2*)(&(__1.z)))->y = (((half2*)(&(v_.z)))->y+((half2*)(&(v__1.z)))->y);
((half2*)(&(__1.w)))->x = (((half2*)(&(v_.w)))->x+((half2*)(&(v__1.w)))->x);
((half2*)(&(__1.w)))->y = (((half2*)(&(v_.w)))->y+((half2*)(&(v__1.w)))->y);
*(uint4*)(((half_t*)buf_dyn_shmem) + (((i_10 * 1024) + (((int)threadIdx.x) * 8)) + 3072)) = __1;
}
__syncthreads();
#pragma unroll
for (int i_11 = 0; i_11 < 16; ++i_11) {
atomicAddx2((&(C[(((((((int)blockIdx.y) * 65536) + (i_11 * 4096)) + ((((int)threadIdx.x) >> 5) * 1024)) + (((int)blockIdx.x) * 64)) + ((((int)threadIdx.x) & 31) * 2))])), (&(((half_t*)buf_dyn_shmem)[(((((((i_11 >> 2) * 1024) + (((((int)threadIdx.x) & 31) >> 3) * 256)) + ((i_11 & 3) * 64)) + ((((int)threadIdx.x) >> 5) * 16)) + ((((int)threadIdx.x) & 7) * 2)) + 3072)])));
}
This split-K epilogue has correctness issues, while the variant below without atomicAdd (a plain vectorized store to C) is correct:
for (int i_14 = 0; i_14 < 4; ++i_14) {
__syncthreads();
uint4 __1;
uint4 v_ = *(uint4*)(((half_t*)buf_dyn_shmem) + (((((i_14 * 1024) + (((((int)threadIdx.x) & 7) >> 1) * 256)) + ((((int)threadIdx.x) >> 3) * 16)) + ((((int)threadIdx.x) & 1) * 8)) + 3072));
uint4 v__1 = *(uint4*)(Bias + ((((int)blockIdx.x) * 64) + ((((int)threadIdx.x) & 7) * 8)));
((half2*)(&(__1.x)))->x = (((half2*)(&(v_.x)))->x+((half2*)(&(v__1.x)))->x);
((half2*)(&(__1.x)))->y = (((half2*)(&(v_.x)))->y+((half2*)(&(v__1.x)))->y);
((half2*)(&(__1.y)))->x = (((half2*)(&(v_.y)))->x+((half2*)(&(v__1.y)))->x);
((half2*)(&(__1.y)))->y = (((half2*)(&(v_.y)))->y+((half2*)(&(v__1.y)))->y);
((half2*)(&(__1.z)))->x = (((half2*)(&(v_.z)))->x+((half2*)(&(v__1.z)))->x);
((half2*)(&(__1.z)))->y = (((half2*)(&(v_.z)))->y+((half2*)(&(v__1.z)))->y);
((half2*)(&(__1.w)))->x = (((half2*)(&(v_.w)))->x+((half2*)(&(v__1.w)))->x);
((half2*)(&(__1.w)))->y = (((half2*)(&(v_.w)))->y+((half2*)(&(v__1.w)))->y);
*(uint4*)(((half_t*)buf_dyn_shmem) + (((((i_14 * 1024) + (((((int)threadIdx.x) & 7) >> 1) * 256)) + ((((int)threadIdx.x) >> 3) * 16)) + ((((int)threadIdx.x) & 1) * 8)) + 3072)) = __1;
}
__syncthreads();
#pragma unroll
for (int i_15 = 0; i_15 < 4; ++i_15) {
*(uint4*)(C + (((((((int)blockIdx.y) * 65536) + (i_15 * 16384)) + ((((int)threadIdx.x) >> 3) * 1024)) + (((int)blockIdx.x) * 64)) + ((((int)threadIdx.x) & 7) * 8))) = *(uint4*)(((half_t*)buf_dyn_shmem) + (((((i_15 * 1024) + (((((int)threadIdx.x) & 7) >> 1) * 256)) + ((((int)threadIdx.x) >> 3) * 16)) + ((((int)threadIdx.x) & 1) * 8)) + 3072));
}
Currently we disable atomicAdd (i.e., split-K) when a bias is present to sidestep this issue.
Reproduce:
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import bitblas
import bitblas.testing
from bitblas import Linear as BitBLASLinear
import torch
import time
import numpy as np
import torch.nn as nn
torch.manual_seed(0)
bitblas.set_log_level("DEBUG")

def correctness_consistent(m, in_features, out_features, bias):
    linear_torch = (nn.Linear(in_features, out_features, bias=bias).to(torch.float16).cuda())
    linear_bitblas = BitBLASLinear(
        in_features,
        out_features,
        bias=bias,
        A_dtype="float16",
        W_dtype="float16",
        accum_dtype="float16",
        out_dtype="float16",
        opt_M=m,
    ).cuda()

    with torch.no_grad():
        linear_bitblas.load_and_transform_weight(linear_torch.weight.clone())
        if bias:
            linear_bitblas.bias = nn.Parameter(linear_torch.bias.clone())

    with torch.no_grad():
        if not isinstance(m, int):
            # When m is a list, average m
            m = sum(m) // len(m)
        input_data = torch.randn(m, in_features, dtype=torch.float16).cuda()
        output_torch = linear_torch(input_data)
        output_bitblas = linear_bitblas(input_data)
        print(output_torch)
        print(output_bitblas)
        bitblas.testing.torch_assert_close(output_torch, output_bitblas, rtol=1e-1, atol=1e-2)

def test_correctness_consistent():
    correctness_consistent(1, 1024, 1024, False)
    correctness_consistent(1, 1024, 1024, True)
    correctness_consistent(1024, 1024, 1024, True)
    correctness_consistent([1, 1024], 1024, 1024, True)

def correctness_weight_only_dequantize(
        m,
        in_features,
        out_features,
        bias,
        W_dtype,
        group_size,
        with_scaling,
        with_zeros,
        zeros_mode,
):
    import numpy as np
    from bitblas.quantization.utils import general_compress
    from bitblas.cache import global_operator_cache

    global_operator_cache.clear()
    linear_bitblas = BitBLASLinear(
        in_features,
        out_features,
        bias=bias,
        A_dtype="float16",
        W_dtype=W_dtype,
        accum_dtype="float16",
        out_dtype="float16",
        group_size=group_size,
        with_scaling=with_scaling,
        with_zeros=with_zeros,
        opt_M=m,
    ).cuda()
    if not isinstance(m, int):
        # average m
        m = sum(m) // len(m)
    input_shape = (m, in_features)
    weight_shape = (out_features, in_features)
    output_shape = (m, out_features)
    inputs = []
    inputs.append(torch.rand(input_shape, dtype=torch.float16).cuda() - 0.5)
    source_format, bit = (
        linear_bitblas.bitblas_matmul.source_format,
        linear_bitblas.bitblas_matmul.bit,
    )

    maxq = 2**(bit - 1)
    zeros = maxq
    if source_format == "uint":
        inputs.append(torch.randint(0, maxq, weight_shape, dtype=torch.int8).cuda())
    elif source_format == "int":
        inputs.append(torch.randint(-maxq, maxq, weight_shape, dtype=torch.int8).cuda())
    else:
        raise NotImplementedError

    inputs.append(torch.rand(output_shape, dtype=torch.float16).cuda())

    intweight = inputs[1]
    intweight = intweight.cpu().to(torch.int8)
    if source_format == "int":
        intweight = intweight + maxq
    if with_zeros:
        inputs[1] = inputs[1] - zeros
    bias_tensor = torch.rand((output_shape[-1],), dtype=torch.float16).cuda()
    ref_result = torch.matmul(inputs[0], (inputs[1].t()).to(torch.float16))
    if bias:
        ref_result = ref_result + bias_tensor

    with torch.no_grad():
        permuted_inputs = []
        permuted_inputs.append(inputs[0])
        if linear_bitblas.bitblas_matmul.weight_transform is not None:
            permuted_inputs.append(
                linear_bitblas.bitblas_matmul.weight_transform(intweight.cpu()).cuda())
        else:
            permuted_inputs.append(inputs[1])
        linear_bitblas.qweight.data = permuted_inputs[-1].clone()
        if with_scaling:
            if group_size == -1:
                group_size = in_features
            permuted_inputs.append(
                torch.ones([out_features, in_features // group_size], dtype=torch.float16).cuda())
            linear_bitblas.scales.data = permuted_inputs[-1].clone()
        if with_zeros:
            if zeros_mode == "original":
                permuted_inputs.append(
                    torch.ones([out_features, in_features // group_size],
                               dtype=torch.float16).cuda() * zeros)
            elif zeros_mode == "rescale":
                original_zeros = (
                    torch.ones([out_features, in_features // group_size],
                               dtype=torch.float16).cuda() * zeros)
                scaled_zeros = original_zeros * permuted_inputs[-1]
                permuted_inputs.append(scaled_zeros)
            elif zeros_mode == "quantized":
                original_zeros = (
                    torch.ones([in_features // group_size, out_features], dtype=torch.int8).cuda() *
                    zeros)
                qzeros = general_compress(
                    original_zeros.cpu().numpy(), source_bits=bit, storage_dtype=np.int8)
                permuted_inputs.append(torch.from_numpy(qzeros).cuda())
            else:
                raise NotImplementedError
            linear_bitblas.zeros.data = permuted_inputs[-1].clone()
        if bias:
            permuted_inputs.append(bias_tensor)
            linear_bitblas.bias.data = bias_tensor.clone()

    with torch.no_grad():
        output_bitblas = linear_bitblas(inputs[0])

    rtol = 1e0
    atol = 1e0
    if zeros_mode == "original":
        rtol = 1e2
        atol = 1e2
    print(output_bitblas)
    print(ref_result)
    torch.testing.assert_close(output_bitblas, ref_result, rtol=rtol, atol=atol)

def test_correctness_weight_only_dequantize():
    correctness_weight_only_dequantize(1, 1024, 1024, False, "uint4", -1, False, False, None)
    correctness_weight_only_dequantize(1, 1024, 1024, False, "uint4", -1, False, False, None)
    correctness_weight_only_dequantize(1024, 1024, 1024, True, "uint4", -1, False, False, None)
    correctness_weight_only_dequantize(1, 1024, 1024, True, "uint2", -1, True, False, None)
    correctness_weight_only_dequantize(1, 1024, 1024, True, "uint2", 128, True, True, "original")
    correctness_weight_only_dequantize(1024, 1024, 1024, True, "uint2", 128, True, True, "original")
    correctness_weight_only_dequantize(1, 1024, 1024, True, "uint2", 128, True, True, "rescale")

def profile(model, input_data):
    model = model.cuda()
    model.eval()

    def get_runtime(num_repeats=1):
        tic = time.time()
        for _ in range(num_repeats):
            _ = model(input_data)
        torch.cuda.synchronize()
        return (time.time() - tic) * 1000 / num_repeats

    with torch.no_grad():
        # print("Warming up ...")
        st = time.time()
        while time.time() - st < 1.0:
            get_runtime()  # warmup
        warmup_runtime = get_runtime()
        num_repeats = max(1, int(1000 / warmup_runtime))
        times = get_runtime(num_repeats)
    return np.mean(times)

if __name__ == "__main__":
    # bitblas.testing.main()
    correctness_weight_only_dequantize(1024, 1024, 1024, True, "uint4", -1, False, False, None)
LeiWang1999 commented
A cleaner script to help reproduce:
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from bitblas import tvm as tvm
import bitblas.testing
from tvm import tl
from bitblas.ops.general_matmul.tilelang.dense.matmul_tensorcore import (
MatmulBlockScheduler,
MatmulFineGrainScheduler,
MatmulWeightPropagationScheduler,
)
from bitblas.ops.general_matmul.tilelang.dequantize import (
MatmulDequantizeScheduler,
MatmulDequantizeFineGrainedScheduler,
MatmulDequantizeWeightPropagationScheduler,
MatmulINT4DequantizeFineGrainedScheduler,
MatmulINT4DequantizeWeightPropagationScheduler,
)
from bitblas.ops.general_matmul.tilelang.dense.matmul_tensorcore import (
MatmulINT4FineGrainScheduler,
MatmulINT4WeightPropagationScheduler,
)
import torch
import torch.backends
torch.manual_seed(0)
verbose = False

def assert_matmul_fine_grained_dequant_with_default_correctness(
        M,
        N,
        K,
        trans_A=False,
        trans_B=True,
        in_dtype="float16",
        out_dtype="float16",
        accum_dtype="float16",
        bit=4,
        storage_dtype="int8",
        source_format="uint",
        with_scaling=False,
        with_zeros=False,
        group_size=-1,
        fast_decoding=False,
        zeros_mode="original",
        with_bias=False,
        split_k_factor=1,
):
    import numpy as np
    from bitblas.quantization import general_compress, interleave_weight

    matmul = MatmulDequantizeFineGrainedScheduler(
        M=M,
        N=N,
        K=K,
        trans_A=trans_A,
        trans_B=trans_B,
        in_dtype=in_dtype,
        out_dtype=out_dtype,
        accum_dtype=accum_dtype,
        num_bits=bit,
        storage_dtype=storage_dtype,
        source_format=source_format,
        with_scaling=with_scaling,
        with_zeros=with_zeros,
        group_size=group_size,
        fast_decoding=fast_decoding,
        zeros_mode=zeros_mode,
        with_bias=with_bias,
    ).apply_config(
        block_row_warps=2,
        block_col_warps=2,
        warp_row_tiles=32,
        warp_col_tiles=32,
        chunk=32,
        num_stages=0,
        enable_rasterization=False,
        split_k_factor=split_k_factor,
    )

    mod, params = tl.lower(matmul)
    src_code = mod.imported_modules[0].get_source()
    # src_code is the generated cuda source
    assert src_code is not None

    input_shape = (M, K)
    weight_shape = (N, K)
    output_shape = (M, N)
    bias_shape = (N, )
    inputs = []
    inputs.append(torch.rand(input_shape, dtype=torch.float16).cuda() - 0.5)
    maxq = 2**(bit - 1)
    zeros = maxq
    if source_format == "uint":
        inputs.append(torch.randint(0, maxq, weight_shape, dtype=torch.int8).cuda())
    elif source_format == "int":
        inputs.append(torch.randint(-maxq, maxq, weight_shape, dtype=torch.int8).cuda())
    else:
        raise NotImplementedError

    bias = torch.ones(bias_shape, dtype=torch.float16).cuda()
    inputs.append(torch.rand(output_shape, dtype=torch.float16).cuda())

    intweight = inputs[1]
    intweight = intweight.cpu().to(torch.int8)
    if source_format == "int":
        intweight = intweight + maxq
    if with_zeros:
        inputs[1] = inputs[1] - zeros

    permuted_inputs = []
    permuted_inputs.append(inputs[0])
    qw = general_compress(intweight.cpu().numpy(), source_bits=bit, storage_dtype=np.int8)
    # lop3 transformation
    if fast_decoding:
        qw = interleave_weight(qw, bit, target_dtype=in_dtype)
    permuted_inputs.append(torch.from_numpy(qw).cuda())
    if with_scaling:
        if group_size == -1:
            group_size = K
        permuted_inputs.append(torch.ones([N, K // group_size], dtype=torch.float16).cuda())
    if with_zeros:
        if zeros_mode == "original":
            permuted_inputs.append(torch.randn((N, K // group_size), dtype=torch.float16).cuda())
        elif zeros_mode == "rescale":
            original_zeros = (torch.ones([N, K // group_size], dtype=torch.float16).cuda() * zeros)
            scaled_zeros = original_zeros * permuted_inputs[-1]
            permuted_inputs.append(scaled_zeros)
        elif zeros_mode == "quantized":
            original_zeros = (torch.ones([K // group_size, N], dtype=torch.int8).cuda() * zeros)
            qzeros = general_compress(
                original_zeros.cpu().numpy(), source_bits=bit, storage_dtype=np.int8)
            permuted_inputs.append(torch.from_numpy(qzeros).cuda())
        else:
            raise NotImplementedError
    if with_bias:
        permuted_inputs.append(bias)
    permuted_inputs.append(inputs[2])

    mod = tl.Profiler(mod, params, [], tl.TensorSupplyType.Integer)
    mod(*permuted_inputs)
    print(permuted_inputs[-1])

    args = [inputs[0]]
    b = inputs[1]
    if with_scaling:
        scale = permuted_inputs[2]
        rescale_b = torch.empty_like(b, dtype=torch.float16)
        for i in range(N):
            for j in range(K):
                if with_zeros:
                    if zeros_mode == "original":
                        rescale_b[i, j] = (b[i, j] - zeros) * scale[i, j // group_size]
                    elif zeros_mode == "rescale":
                        rescale_b[i, j] = (b[i, j] * scale[i, j // group_size] + zeros)
                    else:
                        raise NotImplementedError
                else:
                    rescale_b[i, j] = b[i, j] * scale[i, j // group_size]
        args.append(rescale_b.t().cuda())
    else:
        args.append(b.t().cuda().to(torch.float16))

    ref_result = torch.matmul(*args)
    if with_bias:
        ref_result = ref_result + bias
    print(ref_result)
    bitblas.testing.torch_assert_close(permuted_inputs[-1], ref_result, rtol=1e-1, atol=1e-1)

def test_matmul_fine_grained_dequant_with_default():
    assert_matmul_fine_grained_dequant_with_default_correctness(
        1024, 1024, 1024, source_format="uint", bit=4)
    assert_matmul_fine_grained_dequant_with_default_correctness(
        1024, 1024, 1024, source_format="uint", bit=2)
    assert_matmul_fine_grained_dequant_with_default_correctness(
        1024, 1024, 1024, source_format="uint", bit=4, with_scaling=True)
    assert_matmul_fine_grained_dequant_with_default_correctness(
        1024,
        1024,
        1024,
        source_format="uint",
        bit=4,
        with_scaling=True,
        with_zeros=True,
    )
    assert_matmul_fine_grained_dequant_with_default_correctness(
        1024, 1024, 1024, source_format="uint", bit=4, fast_decoding=True)
    assert_matmul_fine_grained_dequant_with_default_correctness(
        1024,
        1024,
        1024,
        source_format="uint",
        bit=4,
        with_scaling=True,
        fast_decoding=True,
    )
    assert_matmul_fine_grained_dequant_with_default_correctness(
        1024,
        1024,
        1024,
        source_format="uint",
        bit=4,
        with_scaling=True,
        with_zeros=True,
        fast_decoding=True,
    )

if __name__ == "__main__":
    # non-splitk + non-bias
    assert_matmul_fine_grained_dequant_with_default_correctness(
        128, 128, 128, source_format="int", bit=4, with_bias=False, split_k_factor=1)
    # non-splitk + bias
    assert_matmul_fine_grained_dequant_with_default_correctness(
        128, 128, 128, source_format="int", bit=4, with_bias=True, split_k_factor=1)
    # atomicAdd + non-bias
    assert_matmul_fine_grained_dequant_with_default_correctness(
        128, 128, 128, source_format="int", bit=4, with_bias=False, split_k_factor=2)
    # atomicAdd + bias
    assert_matmul_fine_grained_dequant_with_default_correctness(
        1024, 1024, 1024, source_format="uint", bit=4, with_bias=True, split_k_factor=2)
LeiWang1999 commented
Interesting bug, and it is now resolved: in the split-K implementation, blockIdx.z indexes the k-dimension, so the bias add must be performed by only one block along the z dimension; otherwise the bias is added once per k-slice, i.e. split_k_factor times.
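For reference, a minimal standalone sketch of that guard, assuming a hypothetical epilogue kernel and a per-slice partial-result buffer (illustrative only, not the actual change from PR #270): every k-slice atomically accumulates its partial tile into C, but only the blockIdx.z == 0 slice folds in the bias, so the bias is counted exactly once regardless of split_k_factor.

#include <cuda_fp16.h>

// Illustrative split-K epilogue: each blockIdx.z slice holds a partial GEMM tile
// and accumulates it into C with atomicAdd; only the z == 0 slice adds the bias.
__global__ void splitk_bias_epilogue(half* __restrict__ C,
                                     const half* __restrict__ partial,  // this z-slice's partial tile (assumed layout M x N)
                                     const half* __restrict__ Bias,
                                     int N) {
  int row = blockIdx.y * blockDim.y + threadIdx.y;
  int col = blockIdx.x * blockDim.x + threadIdx.x;  // each thread covers 2 adjacent halves
  half2 acc = *reinterpret_cast<const half2*>(&partial[row * N + col * 2]);
  if (blockIdx.z == 0) {
    // the bias is added by exactly one k-slice, so it is not multiplied by split_k_factor
    half2 b = *reinterpret_cast<const half2*>(&Bias[col * 2]);
    acc = __hadd2(acc, b);
  }
  // every k-slice accumulates its partial sum into C (half2 atomicAdd requires sm_60+)
  atomicAdd(reinterpret_cast<half2*>(&C[row * N + col * 2]), acc);
}

The generated kernels above fuse the bias add into the shared-memory epilogue instead of a separate kernel, but the idea is the same: condition the vectorized bias add on a single blockIdx.z.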
LeiWang1999 commented
Closed as resolved in PR #270.