microsoft/BitBLAS

[BUG] Vectorized Bias Add with AtomicAdd may lead to unknown bugs

Closed this issue · 3 comments

  #pragma unroll
  for (int i_10 = 0; i_10 < 4; ++i_10) {
    __syncthreads();
    uint4 __1;
      uint4 v_ = *(uint4*)(((half_t*)buf_dyn_shmem) + (((i_10 * 1024) + (((int)threadIdx.x) * 8)) + 3072));
      uint4 v__1 = *(uint4*)(Bias + (((((int)blockIdx.x) * 64) + ((((int)threadIdx.x) >> 5) * 16)) + ((((int)threadIdx.x) & 1) * 8)));
      ((half2*)(&(__1.x)))->x = (((half2*)(&(v_.x)))->x+((half2*)(&(v__1.x)))->x);
      ((half2*)(&(__1.x)))->y = (((half2*)(&(v_.x)))->y+((half2*)(&(v__1.x)))->y);
      ((half2*)(&(__1.y)))->x = (((half2*)(&(v_.y)))->x+((half2*)(&(v__1.y)))->x);
      ((half2*)(&(__1.y)))->y = (((half2*)(&(v_.y)))->y+((half2*)(&(v__1.y)))->y);
      ((half2*)(&(__1.z)))->x = (((half2*)(&(v_.z)))->x+((half2*)(&(v__1.z)))->x);
      ((half2*)(&(__1.z)))->y = (((half2*)(&(v_.z)))->y+((half2*)(&(v__1.z)))->y);
      ((half2*)(&(__1.w)))->x = (((half2*)(&(v_.w)))->x+((half2*)(&(v__1.w)))->x);
      ((half2*)(&(__1.w)))->y = (((half2*)(&(v_.w)))->y+((half2*)(&(v__1.w)))->y);
    *(uint4*)(((half_t*)buf_dyn_shmem) + (((i_10 * 1024) + (((int)threadIdx.x) * 8)) + 3072)) = __1;
  }
  __syncthreads();
  #pragma unroll
  for (int i_11 = 0; i_11 < 16; ++i_11) {
    atomicAddx2((&(C[(((((((int)blockIdx.y) * 65536) + (i_11 * 4096)) + ((((int)threadIdx.x) >> 5) * 1024)) + (((int)blockIdx.x) * 64)) + ((((int)threadIdx.x) & 31) * 2))])), (&(((half_t*)buf_dyn_shmem)[(((((((i_11 >> 2) * 1024) + (((((int)threadIdx.x) & 31) >> 3) * 256)) + ((i_11 & 3) * 64)) + ((((int)threadIdx.x) >> 5) * 16)) + ((((int)threadIdx.x) & 7) * 2)) + 3072)])));
  }

The kernel above (bias add followed by atomicAdd into C) has correctness issues, whereas the variant below, which stores the result directly without atomicAdd, is correct:

  for (int i_14 = 0; i_14 < 4; ++i_14) {
    __syncthreads();
    uint4 __1;
      uint4 v_ = *(uint4*)(((half_t*)buf_dyn_shmem) + (((((i_14 * 1024) + (((((int)threadIdx.x) & 7) >> 1) * 256)) + ((((int)threadIdx.x) >> 3) * 16)) + ((((int)threadIdx.x) & 1) * 8)) + 3072));
      uint4 v__1 = *(uint4*)(Bias + ((((int)blockIdx.x) * 64) + ((((int)threadIdx.x) & 7) * 8)));
      ((half2*)(&(__1.x)))->x = (((half2*)(&(v_.x)))->x+((half2*)(&(v__1.x)))->x);
      ((half2*)(&(__1.x)))->y = (((half2*)(&(v_.x)))->y+((half2*)(&(v__1.x)))->y);
      ((half2*)(&(__1.y)))->x = (((half2*)(&(v_.y)))->x+((half2*)(&(v__1.y)))->x);
      ((half2*)(&(__1.y)))->y = (((half2*)(&(v_.y)))->y+((half2*)(&(v__1.y)))->y);
      ((half2*)(&(__1.z)))->x = (((half2*)(&(v_.z)))->x+((half2*)(&(v__1.z)))->x);
      ((half2*)(&(__1.z)))->y = (((half2*)(&(v_.z)))->y+((half2*)(&(v__1.z)))->y);
      ((half2*)(&(__1.w)))->x = (((half2*)(&(v_.w)))->x+((half2*)(&(v__1.w)))->x);
      ((half2*)(&(__1.w)))->y = (((half2*)(&(v_.w)))->y+((half2*)(&(v__1.w)))->y);
    *(uint4*)(((half_t*)buf_dyn_shmem) + (((((i_14 * 1024) + (((((int)threadIdx.x) & 7) >> 1) * 256)) + ((((int)threadIdx.x) >> 3) * 16)) + ((((int)threadIdx.x) & 1) * 8)) + 3072)) = __1;
  }
  __syncthreads();
  #pragma unroll
  for (int i_15 = 0; i_15 < 4; ++i_15) {
    *(uint4*)(C + (((((((int)blockIdx.y) * 65536) + (i_15 * 16384)) + ((((int)threadIdx.x) >> 3) * 1024)) + (((int)blockIdx.x) * 64)) + ((((int)threadIdx.x) & 7) * 8))) = *(uint4*)(((half_t*)buf_dyn_shmem) + (((((i_15 * 1024) + (((((int)threadIdx.x) & 7) >> 1) * 256)) + ((((int)threadIdx.x) >> 3) * 16)) + ((((int)threadIdx.x) & 1) * 8)) + 3072));
  }

As a workaround, we currently disable atomicAdd whenever a bias is present, to sidestep this situation.
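A minimal sketch of that workaround at config-selection time (the function and parameter names below are illustrative, not the actual BitBLAS API):

def select_split_k_factor(split_k_factor: int, with_bias: bool) -> int:
    # Workaround sketch: split-k accumulation via atomicAdd adds the bias once
    # per k-split, so force a single split whenever a bias term is present.
    if with_bias and split_k_factor > 1:
        return 1
    return split_k_factor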

Reproduce:

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import bitblas
import bitblas.testing
from bitblas import Linear as BitBLASLinear
import torch
import time
import numpy as np
import torch.nn as nn

torch.manual_seed(0)
bitblas.set_log_level("DEBUG")


def correctness_consistent(m, in_features, out_features, bias):
    linear_torch = (nn.Linear(in_features, out_features, bias=bias).to(torch.float16).cuda())
    linear_bitblas = BitBLASLinear(
        in_features,
        out_features,
        bias=bias,
        A_dtype="float16",
        W_dtype="float16",
        accum_dtype="float16",
        out_dtype="float16",
        opt_M=m,
    ).cuda()

    with torch.no_grad():
        linear_bitblas.load_and_transform_weight(linear_torch.weight.clone())
        if bias:
            linear_bitblas.bias = nn.Parameter(linear_torch.bias.clone())

    with torch.no_grad():
        if not isinstance(m, int):
            # When m is a list, average m
            m = sum(m) // len(m)
        input_data = torch.randn(m, in_features, dtype=torch.float16).cuda()
        output_torch = linear_torch(input_data)
        output_bitblas = linear_bitblas(input_data)
    print(output_torch)
    print(output_bitblas)
    bitblas.testing.torch_assert_close(output_torch, output_bitblas, rtol=1e-1, atol=1e-2)


def test_correctness_consistent():
    correctness_consistent(1, 1024, 1024, False)
    correctness_consistent(1, 1024, 1024, True)
    correctness_consistent(1024, 1024, 1024, True)
    correctness_consistent([1, 1024], 1024, 1024, True)


def correctness_weight_only_dequantize(
    m,
    in_features,
    out_features,
    bias,
    W_dtype,
    group_size,
    with_scaling,
    with_zeros,
    zeros_mode,
):
    import numpy as np
    from bitblas.quantization.utils import general_compress
    from bitblas.cache import global_operator_cache

    global_operator_cache.clear()
    linear_bitblas = BitBLASLinear(
        in_features,
        out_features,
        bias=bias,
        A_dtype="float16",
        W_dtype=W_dtype,
        accum_dtype="float16",
        out_dtype="float16",
        group_size=group_size,
        with_scaling=with_scaling,
        with_zeros=with_zeros,
        opt_M=m,
    ).cuda()
    if not isinstance(m, int):
        # average m
        m = sum(m) // len(m)
    input_shape = (m, in_features)
    weight_shape = (out_features, in_features)
    output_shape = (m, out_features)
    inputs = []
    inputs.append(torch.rand(input_shape, dtype=torch.float16).cuda() - 0.5)
    source_format, bit = (
        linear_bitblas.bitblas_matmul.source_format,
        linear_bitblas.bitblas_matmul.bit,
    )

    maxq = 2**(bit - 1)
    zeros = maxq
    if source_format == "uint":
        inputs.append(torch.randint(0, maxq, weight_shape, dtype=torch.int8).cuda())
    elif source_format == "int":
        inputs.append(torch.randint(-maxq, maxq, weight_shape, dtype=torch.int8).cuda())
    else:
        raise NotImplementedError

    inputs.append(torch.rand(output_shape, dtype=torch.float16).cuda())

    intweight = inputs[1]
    intweight = intweight.cpu().to(torch.int8)
    if source_format == "int":
        intweight = intweight + maxq
    if with_zeros:
        inputs[1] = inputs[1] - zeros
    bias_tensor = torch.rand((output_shape[-1],), dtype=torch.float16).cuda()
    ref_result = torch.matmul(inputs[0], (inputs[1].t()).to(torch.float16))
    if bias:
        ref_result = ref_result + bias_tensor

    with torch.no_grad():
        permuted_inputs = []
        permuted_inputs.append(inputs[0])
        if linear_bitblas.bitblas_matmul.weight_transform is not None:
            permuted_inputs.append(
                linear_bitblas.bitblas_matmul.weight_transform(intweight.cpu()).cuda())
        else:
            permuted_inputs.append(inputs[1])
        linear_bitblas.qweight.data = permuted_inputs[-1].clone()
        if with_scaling:
            if group_size == -1:
                group_size = in_features
            permuted_inputs.append(
                torch.ones([out_features, in_features // group_size], dtype=torch.float16).cuda())
            linear_bitblas.scales.data = permuted_inputs[-1].clone()
        if with_zeros:
            if zeros_mode == "original":
                permuted_inputs.append(
                    torch.ones([out_features, in_features // group_size],
                               dtype=torch.float16).cuda() * zeros)
            elif zeros_mode == "rescale":
                original_zeros = (
                    torch.ones([out_features, in_features // group_size],
                               dtype=torch.float16).cuda() * zeros)
                scaled_zeros = original_zeros * permuted_inputs[-1]
                permuted_inputs.append(scaled_zeros)
            elif zeros_mode == "quantized":
                original_zeros = (
                    torch.ones([in_features // group_size, out_features], dtype=torch.int8).cuda() *
                    zeros)
                qzeros = general_compress(
                    original_zeros.cpu().numpy(), source_bits=bit, storage_dtype=np.int8)
                permuted_inputs.append(torch.from_numpy(qzeros).cuda())
            else:
                raise NotImplementedError
            linear_bitblas.zeros.data = permuted_inputs[-1].clone()
        if bias:
            permuted_inputs.append(bias_tensor)
            linear_bitblas.bias.data = bias_tensor.clone()

    with torch.no_grad():
        output_bitblas = linear_bitblas(inputs[0])

    rtol = 1e0
    atol = 1e0
    if zeros_mode == "original":
        rtol = 1e2
        atol = 1e2
    print(output_bitblas)
    print(ref_result)
    torch.testing.assert_close(output_bitblas, ref_result, rtol=rtol, atol=atol)


def test_correctness_weight_only_dequantize():
    correctness_weight_only_dequantize(1, 1024, 1024, False, "uint4", -1, False, False, None)
    correctness_weight_only_dequantize(1, 1024, 1024, False, "uint4", -1, False, False, None)
    correctness_weight_only_dequantize(1024, 1024, 1024, True, "uint4", -1, False, False, None)
    correctness_weight_only_dequantize(1, 1024, 1024, True, "uint2", -1, True, False, None)
    correctness_weight_only_dequantize(1, 1024, 1024, True, "uint2", 128, True, True, "original")
    correctness_weight_only_dequantize(1024, 1024, 1024, True, "uint2", 128, True, True, "original")
    correctness_weight_only_dequantize(1, 1024, 1024, True, "uint2", 128, True, True, "rescale")


def profile(model, input_data):
    model = model.cuda()
    model.eval()

    def get_runtime(num_repeats=1):
        tic = time.time()
        for _ in range(num_repeats):
            _ = model(input_data)
        torch.cuda.synchronize()
        return (time.time() - tic) * 1000 / num_repeats

    with torch.no_grad():
        # print("Warming up ...")
        st = time.time()
        while time.time() - st < 1.0:
            get_runtime()  # warmup
        warmup_runtime = get_runtime()
        num_repeats = max(1, int(1000 / warmup_runtime))
        times = get_runtime(num_repeats)
    return np.mean(times)


if __name__ == "__main__":
    # bitblas.testing.main()
    correctness_weight_only_dequantize(1024, 1024, 1024, True, "uint4", -1, False, False, None)

A cleaner script to help reproduce:

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

from bitblas import tvm as tvm
import bitblas.testing
from tvm import tl
from bitblas.ops.general_matmul.tilelang.dense.matmul_tensorcore import (
    MatmulBlockScheduler,
    MatmulFineGrainScheduler,
    MatmulWeightPropagationScheduler,
)

from bitblas.ops.general_matmul.tilelang.dequantize import (
    MatmulDequantizeScheduler,
    MatmulDequantizeFineGrainedScheduler,
    MatmulDequantizeWeightPropagationScheduler,
    MatmulINT4DequantizeFineGrainedScheduler,
    MatmulINT4DequantizeWeightPropagationScheduler,
)

from bitblas.ops.general_matmul.tilelang.dense.matmul_tensorcore import (
    MatmulINT4FineGrainScheduler,
    MatmulINT4WeightPropagationScheduler,
)

import torch
import torch.backends

torch.manual_seed(0)

verbose = False


def assert_matmul_fine_grained_dequant_with_default_correctness(
    M,
    N,
    K,
    trans_A=False,
    trans_B=True,
    in_dtype="float16",
    out_dtype="float16",
    accum_dtype="float16",
    bit=4,
    storage_dtype="int8",
    source_format="uint",
    with_scaling=False,
    with_zeros=False,
    group_size=-1,
    fast_decoding=False,
    zeros_mode="original",
    with_bias=False,
    split_k_factor=1,
):
    import numpy as np
    from bitblas.quantization import general_compress, interleave_weight

    matmul = MatmulDequantizeFineGrainedScheduler(
        M=M,
        N=N,
        K=K,
        trans_A=trans_A,
        trans_B=trans_B,
        in_dtype=in_dtype,
        out_dtype=out_dtype,
        accum_dtype=accum_dtype,
        num_bits=bit,
        storage_dtype=storage_dtype,
        source_format=source_format,
        with_scaling=with_scaling,
        with_zeros=with_zeros,
        group_size=group_size,
        fast_decoding=fast_decoding,
        zeros_mode=zeros_mode,
        with_bias=with_bias,
    ).apply_config(
        block_row_warps=2,
        block_col_warps=2,
        warp_row_tiles=32,
        warp_col_tiles=32,
        chunk=32,
        num_stages=0,
        enable_rasterization=False,
        split_k_factor=split_k_factor,
    )

    mod, params = tl.lower(matmul)
    src_code = mod.imported_modules[0].get_source()
    # src_code is the generated cuda source
    assert src_code is not None
    input_shape = (M, K)
    weight_shape = (N, K)
    output_shape = (M, N)
    bias_shape = (N, )
    inputs = []
    inputs.append(torch.rand(input_shape, dtype=torch.float16).cuda() - 0.5)
    maxq = 2**(bit - 1)
    zeros = maxq
    if source_format == "uint":
        inputs.append(torch.randint(0, maxq, weight_shape, dtype=torch.int8).cuda())
    elif source_format == "int":
        inputs.append(torch.randint(-maxq, maxq, weight_shape, dtype=torch.int8).cuda())
    else:
        raise NotImplementedError
    bias = torch.ones(bias_shape, dtype=torch.float16).cuda()

    inputs.append(torch.rand(output_shape, dtype=torch.float16).cuda())

    intweight = inputs[1]
    intweight = intweight.cpu().to(torch.int8)
    if source_format == "int":
        intweight = intweight + maxq
    if with_zeros:
        inputs[1] = inputs[1] - zeros

    permuted_inputs = []
    permuted_inputs.append(inputs[0])
    qw = general_compress(intweight.cpu().numpy(), source_bits=bit, storage_dtype=np.int8)
    # lop3 transformation
    if fast_decoding:
        qw = interleave_weight(qw, bit, target_dtype=in_dtype)
    permuted_inputs.append(torch.from_numpy(qw).cuda())
    if with_scaling:
        if group_size == -1:
            group_size = K
        permuted_inputs.append(torch.ones([N, K // group_size], dtype=torch.float16).cuda())
    if with_zeros:
        if zeros_mode == "original":
            permuted_inputs.append(torch.randn((N, K // group_size), dtype=torch.float16).cuda())
        elif zeros_mode == "rescale":
            original_zeros = (torch.ones([N, K // group_size], dtype=torch.float16).cuda() * zeros)
            scaled_zeros = original_zeros * permuted_inputs[-1]
            permuted_inputs.append(scaled_zeros)
        elif zeros_mode == "quantized":
            original_zeros = (torch.ones([K // group_size, N], dtype=torch.int8).cuda() * zeros)
            qzeros = general_compress(
                original_zeros.cpu().numpy(), source_bits=bit, storage_dtype=np.int8)
            permuted_inputs.append(torch.from_numpy(qzeros).cuda())
        else:
            raise NotImplementedError
    
    if with_bias:
        permuted_inputs.append(bias)

    permuted_inputs.append(inputs[2])

    mod = tl.Profiler(mod, params, [], tl.TensorSupplyType.Integer)

    mod(*permuted_inputs)

    print(permuted_inputs[-1])

    args = [inputs[0]]
    b = inputs[1]
    if with_scaling:
        scale = permuted_inputs[2]
        rescale_b = torch.empty_like(b, dtype=torch.float16)
        for i in range(N):
            for j in range(K):
                if with_zeros:
                    if zeros_mode == "original":
                        rescale_b[i, j] = (b[i, j] - zeros) * scale[i, j // group_size]
                    elif zeros_mode == "rescale":
                        rescale_b[i, j] = (b[i, j] * scale[i, j // group_size] + zeros)
                    else:
                        raise NotImplementedError
                else:
                    rescale_b[i, j] = b[i, j] * scale[i, j // group_size]
        args.append(rescale_b.t().cuda())
    else:
        args.append(b.t().cuda().to(torch.float16))

    ref_result = torch.matmul(*args)
    if with_bias:
        ref_result = ref_result + bias
    
    print(ref_result)
    bitblas.testing.torch_assert_close(permuted_inputs[-1], ref_result, rtol=1e-1, atol=1e-1)


def test_matmul_fine_grained_dequant_with_default():
    assert_matmul_fine_grained_dequant_with_default_correctness(
        1024, 1024, 1024, source_format="uint", bit=4)
    assert_matmul_fine_grained_dequant_with_default_correctness(
        1024, 1024, 1024, source_format="uint", bit=2)
    assert_matmul_fine_grained_dequant_with_default_correctness(
        1024, 1024, 1024, source_format="uint", bit=4, with_scaling=True)
    assert_matmul_fine_grained_dequant_with_default_correctness(
        1024,
        1024,
        1024,
        source_format="uint",
        bit=4,
        with_scaling=True,
        with_zeros=True,
    )
    assert_matmul_fine_grained_dequant_with_default_correctness(
        1024, 1024, 1024, source_format="uint", bit=4, fast_decoding=True)
    assert_matmul_fine_grained_dequant_with_default_correctness(
        1024,
        1024,
        1024,
        source_format="uint",
        bit=4,
        with_scaling=True,
        fast_decoding=True,
    )
    assert_matmul_fine_grained_dequant_with_default_correctness(
        1024,
        1024,
        1024,
        source_format="uint",
        bit=4,
        with_scaling=True,
        with_zeros=True,
        fast_decoding=True,
    )



if __name__ == "__main__":
    # non-splitk + non-bias
    assert_matmul_fine_grained_dequant_with_default_correctness(
        128, 128, 128, source_format="int", bit=4, with_bias=False, split_k_factor=1)
    # non-splitk + bias
    assert_matmul_fine_grained_dequant_with_default_correctness(
        128, 128, 128, source_format="int", bit=4, with_bias=True, split_k_factor=1)
    # atomicAdd + non-bias
    assert_matmul_fine_grained_dequant_with_default_correctness(
        128, 128, 128, source_format="int", bit=4, with_bias=False, split_k_factor=2)
    # atomicAdd + bias
    assert_matmul_fine_grained_dequant_with_default_correctness(
        1024, 1024, 1024, source_format="uint", bit=4, with_bias=True, split_k_factor=2)

Interesting bug, now resolved: blockIdx.z indexes the k-dimension in the split-k implementation, so the bias add must be performed by only a single block along the z dimension (one value of blockIdx.z); otherwise the bias is added once per k-split.
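A minimal PyTorch sketch of the failure mode and of the fix (illustrative only, not the generated kernel): each k-split below plays the role of one blockIdx.z, which adds the bias to its partial tile before accumulating into C.

import torch

torch.manual_seed(0)
M, N, K, split_k_factor = 4, 4, 8, 2
A = torch.randn(M, K)
B = torch.randn(K, N)
bias = torch.randn(N)

# Each k-split (one blockIdx.z) computes a partial product over its slice of K.
partials = [A[:, i::split_k_factor] @ B[i::split_k_factor, :] for i in range(split_k_factor)]

# Buggy: every k-split adds the bias before the atomicAdd, so the bias ends up
# accumulated split_k_factor times in C.
wrong = sum(p + bias for p in partials)

# Fixed: only one k-split (e.g. blockIdx.z == 0) performs the bias add.
right = sum(partials) + bias

print(torch.allclose(A @ B + bias, right, atol=1e-5))                         # True
print(torch.allclose(wrong, right + (split_k_factor - 1) * bias, atol=1e-5))  # True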

Closed as resolved in PR #270.