microsoft/BitBLAS

CUDA error: an illegal memory access was encountered when using BitBlas on multiple GPUs

mobicham opened this issue · 12 comments

There seems to be an issue with BitBLAS when using multiple GPUs; even allocating a new tensor triggers the error:

In [8]: model_loaded.model.layers[0].self_attn.q_proj.device
Out[8]: 0

In [11]: model_loaded.model.layers[16].self_attn.q_proj.device
Out[11]: 1

In [12]: l = model_loaded.model.layers[16].self_attn.q_proj
In [13]: x = torch.randn((1, 4096), device=l.device, dtype=torch.float16)
In [14]: out = l(x) #Runs OK

In [15]: l = model_loaded.model.layers[0].self_attn.q_proj
In [16]: x = torch.randn((1, 4096), device=l.device, dtype=torch.float16)
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[16], line 1
----> 1 x = torch.randn((1, 4096), device=l.device, dtype=torch.float16)

RuntimeError: CUDA error: an illegal memory access was encountered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.


Hi @mobicham, which framework are you using for this issue? Would you mind providing a script for us to reproduce it? :)

Using BitBlas via hqq on transformers models:

  • pip install git+https://github.com/mobiusml/hqq.git
  • pip install bitblas
# Code tested with 2 x 24 GB gpus
import torch, gc
from transformers import AutoModelForCausalLM, AutoTokenizer, HqqConfig

device    = 'auto'
dtype     = torch.float16
model_id  = 'meta-llama/Meta-Llama-3-8B-Instruct'
cache_dir = '.' 

quant_config  = HqqConfig(nbits=4, group_size=64, quant_zero=False, quant_scale=False, axis=1)

model = AutoModelForCausalLM.from_pretrained(
    model_id, 
    torch_dtype=dtype, 
    cache_dir=cache_dir,
    device_map=device, 
    quantization_config=quant_config
)

#Patching
from hqq.utils.patching import *
from hqq.core.quantize import *
HQQLinear.set_backend(HQQBackend.PYTORCH)
prepare_for_inference(model, backend='bitblas', verbose=True) 

# Check if the layers have parameters on the right device
for layer_id in range(0, 14):
    print(layer_id, model.model.layers[layer_id].self_attn.q_proj.device) #cuda:0
    print(layer_id, model.model.layers[layer_id].self_attn.q_proj.W_q.device) #cuda:0
    print(layer_id, model.model.layers[layer_id].self_attn.q_proj.scale.device) #cuda:0
    print(layer_id, model.model.layers[layer_id].self_attn.q_proj.zero.device) #cuda:0
    print('--------------------------------------------------------')

for layer_id in range(15, len(model.model.layers)):
    print(layer_id, model.model.layers[layer_id].self_attn.q_proj.device) #cuda:1
    print(layer_id, model.model.layers[layer_id].self_attn.q_proj.W_q.device) #cuda:1
    print(layer_id, model.model.layers[layer_id].self_attn.q_proj.scale.device) #cuda:1
    print(layer_id, model.model.layers[layer_id].self_attn.q_proj.zero.device) #cuda:1
    print('--------------------------------------------------------')


# Allocate example
layer_id = 0
l        = model.model.layers[layer_id].self_attn.q_proj
x        = torch.randn((1, 4096), device=l.device, dtype=torch.float16)
out      = l(x) #Runs OK


layer_id = 16
l        = model.model.layers[layer_id].self_attn.q_proj
x        = torch.randn((1, 4096), device=l.device, dtype=torch.float16)
out      = l(x) #Runs OK

print(out) #RuntimeError: CUDA error: an illegal memory access was encountered


# #Test 
# input_tensor = torch.zeros((1, 8), dtype=torch.int32, device='cuda:0')
# with torch.no_grad():
# 	out_ref = model.forward(input_tensor)
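Note that because CUDA reports kernel errors asynchronously, the failure only surfaces at print(out) rather than at the forward call itself. As a small sketch (continuing the script above and reusing its model), an explicit synchronize should pin the error to the offending launch:

import torch

# Continuing the reproduction script above: run the cuda:1 layer, then
# synchronize its device so the illegal access is raised here instead of
# at a later, unrelated call such as print(out).
layer = model.model.layers[16].self_attn.q_proj  # dispatched to the second GPU
x = torch.randn((1, 4096), device=layer.device, dtype=torch.float16)
out = layer(x)
torch.cuda.synchronize(x.device)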

Thank you!

Hi @mobicham, amazing project and thanks for reporting.

I suspect the cause is related to transform_weight, since we refactored that API to include weight compression in 0.0.1.dev14. However, I'm not certain that's actually what's going on. Let me take a look.

In the meantime, I believe we should consider reverting the change to ensure the bitblas-based project functions correctly.

Thank you @LeiWang1999
I noticed that calling transform_weight on a cuda:1 tensor returns a cuda:0 tensor for some reason, so I move the result back to the right device (https://github.com/mobiusml/hqq/blob/master/hqq/backends/bitblas.py#L115). But even when all the parameters (quantized weights, scale, zero) are on the right device, it still produces that error.
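For reference, the workaround looks roughly like this (an illustrative helper, not the exact hqq code; matmul_eng stands for the bitblas Matmul instance attached to the layer):

import torch

def transform_weight_on_device(matmul_eng, weight: torch.Tensor) -> torch.Tensor:
    # Illustrative helper: transform_weight was observed to return its result
    # on cuda:0, so move the packed tensor back to the weight's original device.
    target_device = weight.device
    packed = matmul_eng.transform_weight(weight)
    return packed.to(target_device)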

Hi @mobicham, it looks like it works in my single-GPU environment. BTW, I'm using the latest release, 0.0.1.dev15. I'll test it on my multi-GPU environment soon.

Thanks for testing! Yeah, it works fine on a single GPU; the issue only appears with multiple GPUs. Let me know how it goes!

@mobicham , yeah I've identified the problem, this issue is similar to #105.

BitBLAS was initialized with the shape n=4096, k=4096, while the sharded model dispatched n=4096 and k=2048.


One possible solution could be to make the initialization of the Linear Layer aware of the shard information, similar to how it’s done in vLLM.

Hmm, I'm not sure how you got that; in my example everything looks correct. The code works perfectly fine on a single GPU for the 7 models I tested.

    layer = model.model.layers[layer_id].self_attn.q_proj
    print(layer_id, layer.name, layer.device)
    print(layer_id, layer.name, layer.W_q.device, layer.W_q.shape, 'N', layer.out_features, 'K', layer.in_features)
    print(layer_id, layer.name, layer.scale.device, layer.scale.shape)
    print(layer_id, layer.name, layer.zero.device, layer.zero.shape)
    print(layer.eng_tag, layer.matmul_eng)
    print('--------------------------------------------------------')

prints

16 model.layers.16.self_attn.q_proj 1
16 model.layers.16.self_attn.q_proj cuda:1 torch.Size([4096, 2048]) N 4096 K 4096
16 model.layers.16.self_attn.q_proj cuda:1 torch.Size([4096, 64])
16 model.layers.16.self_attn.q_proj cuda:1 torch.Size([4096, 64])
torch.Size([4096, 4096])_4_64 <bitblas.ops.general_matmul.Matmul object at 0x7f904b238eb0>

W_q is int4 packed in int8, so its size should be 4096 x 2048 as a packed tensor. The unpacked size is 4096 x 4096, and it should be calling the 4096 x 4096 kernel, so it's all correct.
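To make the packing arithmetic concrete, here is an illustrative sketch (not the actual hqq/bitblas packing routine) of how two 4-bit values share one 8-bit element, which is why the packed K dimension is halved:

import torch

# Two int4 values per 8-bit element: a (4096, 4096) int4 weight is stored as a
# (4096, 2048) packed tensor (uint8 here for simplicity; the real storage dtype
# and bit layout may differ).
W_int4 = torch.randint(0, 16, (4096, 4096), dtype=torch.uint8)
W_packed = W_int4[:, 0::2] | (W_int4[:, 1::2] << 4)
print(W_packed.shape)  # torch.Size([4096, 2048])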

@mobicham , oh yeah thanks, sorry, that's my mistake, let me dig further.

Here's a minimal example:

import torch
from hqq.core.quantize import *
from hqq.backends.bitblas import patch_hqq_to_bitblas
HQQLinear.set_backend(HQQBackend.PYTORCH)

def init_layers(K, N, device, quant_config, compute_dtype=torch.float16):
    layer_0   = torch.nn.Linear(in_features=N, out_features=K, bias=False).to(device=device, dtype=compute_dtype)
    hqq_0     = HQQLinear(layer_0, quant_config, compute_dtype=compute_dtype, device=device, del_orig=False)
    bitblas_0 = patch_hqq_to_bitblas(hqq_0, None)
    layer_0.device = device
    layer_0.compute_dtype = compute_dtype
    return [layer_0, bitblas_0]

def eval(layers):
    for layer in layers:
        x        = torch.randn((1, layer.in_features), device=layer.device, dtype=layer.compute_dtype)
        with torch.no_grad():
            out = layer(x)
        print(layer, layer.device, 'OK')


quant_config = BaseQuantizeConfig(nbits=4, group_size=64, quant_scale=False, quant_zero=False, axis=1)
K, N         = 4096, 4096

layers = []
layers += init_layers(K, N, 'cuda:0', quant_config)
eval(layers)
# Linear(in_features=4096, out_features=4096, bias=False) cuda:0 OK
# HQQLinearBitBlas() cuda:0 OK


layers += init_layers(K, N, 'cuda:1', quant_config)
eval(layers)

for i in range(5):
    eval(layers)
# RuntimeError: CUDA error: an illegal memory access was encountered
# CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
# For debugging consider passing CUDA_LAUNCH_BLOCKING=1
# Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.

Hello @mobicham, I've identified the bug with kernels on multiple GPUs: we need to call torch.cuda.set_device before invoking the kernel.

The device-setting call should ultimately be integrated into the BitBLAS operator's forward function, but in the meantime a hotfix can be applied right before the HQQ forward call :)
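A minimal sketch of such a hotfix (an illustrative wrapper, not the exact change that landed in hqq): pin the current CUDA device to the input's device before launching the kernel.

import torch

def forward_with_device_fix(layer, x: torch.Tensor) -> torch.Tensor:
    # Hotfix sketch: without this, the kernel may launch on the default
    # device (cuda:0) even though the tensors live on another GPU.
    torch.cuda.set_device(x.device)
    return layer(x)  # `layer` is the bitblas-patched HQQ linear layer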

Thank you! I just tried it (mobiusml/hqq@f8c4519) and it works!
Closing this issue as a result.