dvmazur/mixtral-offloading

Can it run on multi-GPU?


drdh commented

Thanks for your contributions. I would like to know whether it can be deployed on multiple GPUs to make use of more VRAM.

@dvmazur @lavawolfiee Could you please address this question? I don't think it's currently possible, but I'd be happy to implement it myself if you could point me to where I'd need to make changes.

Hi!

Sorry for the late reply.

Running the model on multiple GPUs is not currently supported: at the moment, all active experts are sent to cuda:0. You can send an expert to a different GPU by specifying a different device when initializing MixtralExpertWrapper.

Keep in mind that you would need to balance the number of active experts between your GPUs. This logic could be added to the ExpertCache class.
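A minimal sketch of the balancing idea, assuming each GPU-resident ("active") expert can be identified by a (layer, expert) pair and that its target device is simply the device string passed when the expert wrapper is created. `assign_experts_to_devices` and `experts_per_device` are hypothetical helpers, not part of mixtral-offloading, and the real ExpertCache bookkeeping (eviction, offloaded copies) is not modeled:

```python
from collections import Counter


def assign_experts_to_devices(num_layers, experts_per_layer, devices):
    """Hypothetical helper: spread GPU-resident experts evenly over devices.

    Experts are enumerated layer by layer and dealt out round-robin,
    so the per-device counts differ by at most one.
    """
    if not devices:
        raise ValueError("need at least one device")
    assignment = {}
    slot = 0
    for layer in range(num_layers):
        for expert in range(experts_per_layer):
            assignment[(layer, expert)] = devices[slot % len(devices)]
            slot += 1
    return assignment


def experts_per_device(assignment):
    """Count how many experts each device would hold."""
    return Counter(assignment.values())


if __name__ == "__main__":
    # 32 layers, 8 experts, offload_per_layer = 5 -> 3 experts per layer stay on GPU
    plan = assign_experts_to_devices(32, 3, ["cuda:0", "cuda:1"])
    print(experts_per_device(plan))
```

Any forward pass that uses an expert placed on another GPU would also need to move its input hidden states to that device and the output back, since PyTorch operations require all operands on the same device.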

By the way, one of our quantization setups compressed the model to 17 GB. This would fit into the VRAM of two T4 GPUs, which you can get for free on Kaggle.
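As a rough sanity check on the 17 GB figure, the sketch below estimates the weight footprint from Mixtral-8x7B's published shape (hidden size 4096, FFN size 14336, 32 layers, 8 experts, 32 attention heads, 8 KV heads, 32k vocabulary). The effective bits per weight are assumptions, not values from the tech report: about 3 bits for 2-bit experts with group size 16 (one 8-bit scale and one 8-bit zero per group of 16) and about 4.25 bits for 4-bit attention with group size 64. Embeddings and the LM head are assumed to stay in fp16, and the small router and norm weights are ignored:

```python
# Mixtral-8x7B shapes (from the public model config)
HIDDEN = 4096
FFN = 14336
LAYERS = 32
EXPERTS = 8
HEADS = 32
KV_HEADS = 8
VOCAB = 32000
HEAD_DIM = HIDDEN // HEADS


def param_counts():
    """Return (expert, attention, embedding) parameter counts."""
    expert = 3 * HIDDEN * FFN  # w1, w2, w3 of one expert
    experts_total = LAYERS * EXPERTS * expert
    # q and o are HIDDEN x HIDDEN; k and v project to KV_HEADS * HEAD_DIM
    attn_layer = 2 * HIDDEN * HIDDEN + 2 * HIDDEN * KV_HEADS * HEAD_DIM
    attn_total = LAYERS * attn_layer
    embed_total = 2 * VOCAB * HIDDEN  # input embeddings + LM head
    return experts_total, attn_total, embed_total


def estimate_gib(expert_bits=3.0, attn_bits=4.25, embed_bits=16.0):
    """Approximate weight size in GiB for the given (assumed) bits per weight."""
    experts, attn, embed = param_counts()
    total_bits = experts * expert_bits + attn * attn_bits + embed * embed_bits
    return total_bits / 8 / 2**30


if __name__ == "__main__":
    print(f"total params: {sum(param_counts()) / 1e9:.1f}B")
    print(f"fp16 weights: {estimate_gib(16, 16, 16):.1f} GiB")
    print(f"4-bit attn / 2-bit experts: {estimate_gib():.1f} GiB")
```

Under these assumptions the quantized estimate comes out at roughly 17 GiB, in the same ballpark as the figure above; the exact size depends on how the quantization metadata is actually stored.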

Have you looked into running a quantized version (possibly ours) of the model using tensor_parallel?

Hi @dvmazur!

Thank you for your reply.

Unfortunately (or fortunately), I have eight 1080 Ti GPUs on my machine, each of which seems unable to handle the model on its own, even with quantization and with offload_per_layer = 5 or offload_per_layer = 6. What I am ultimately trying to achieve is to run a single model on two 1080 Ti GPUs (~22.5 GB of VRAM in total), so that I can run 4 separate instances of the model across my GPUs simultaneously.
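The layout described above (eight GPUs, two per model instance) can be planned by giving each instance its own CUDA_VISIBLE_DEVICES value. A minimal sketch: `plan_instances` and `fits_per_gpu` are hypothetical helpers, the even split of the weights across two cards is an assumption (the repository does not currently split the model across GPUs), and the headroom for activations, KV cache and CUDA context is a guess you would tune:

```python
def plan_instances(num_gpus, gpus_per_instance):
    """Return one CUDA_VISIBLE_DEVICES string per model instance."""
    if gpus_per_instance <= 0 or num_gpus % gpus_per_instance:
        raise ValueError("num_gpus must be a positive multiple of gpus_per_instance")
    return [
        ",".join(str(gpu) for gpu in range(start, start + gpus_per_instance))
        for start in range(0, num_gpus, gpus_per_instance)
    ]


def fits_per_gpu(model_gb, gpus_per_instance, per_gpu_gb, headroom_gb):
    """Assumes the weights split evenly and each GPU also needs `headroom_gb`."""
    return model_gb / gpus_per_instance + headroom_gb <= per_gpu_gb


if __name__ == "__main__":
    print(plan_instances(8, 2))
    print(fits_per_gpu(model_gb=17, gpus_per_instance=2, per_gpu_gb=11, headroom_gb=1.5))
```

Each string would be exported as CUDA_VISIBLE_DEVICES for a separate process, for example via `subprocess.Popen(cmd, env={**os.environ, "CUDA_VISIBLE_DEVICES": value})`. It has to be set before CUDA is first initialized in that process, which is why the script below sets it at the very top.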

Thank you for your suggestions. I'll have a look at MixtralExpertWrapper, but the tensor_parallel option with your specific quantization setup seems like a great workaround to try first.

May I ask which quantization setup allowed compression down to 17 GB, or if you could point me to a file that contains that setup please?

Currently, when I set offload_per_layer = 5 the model seems to only occupy ~11 GB on a single GPU without an OOM error, but then at inference there's no utilization of the GPU cores throughout (though the VRAM is occupied) until the kernel crashes. Here's the code:

import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"
import sys

sys.path.append("mixtral-offloading")
import torch
from torch.nn import functional as F
from hqq.core.quantize import BaseQuantizeConfig
from huggingface_hub import snapshot_download
from IPython.display import clear_output
from tqdm.auto import trange
from transformers import AutoConfig, AutoTokenizer
from transformers.utils import logging as hf_logging

from src.build_model import OffloadConfig, QuantConfig, build_model

model_name = "mistralai/Mixtral-8x7B-Instruct-v0.1"
quantized_model_name = "lavawolfiee/Mixtral-8x7B-Instruct-v0.1-offloading-demo"
state_path = "Mixtral-8x7B-Instruct-v0.1-offloading-demo"

config = AutoConfig.from_pretrained(quantized_model_name)

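# Resolves to the first visible GPU. Multi-GPU is not supported, so the whole
# model is placed here and GPU 1 from CUDA_VISIBLE_DEVICES stays unused.
device = torch.device("cuda")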

##### Change this to 5 if you have only 12 GB of GPU VRAM #####
# offload_per_layer = 4
offload_per_layer = 5
###############################################################

num_experts = config.num_local_experts

offload_config = OffloadConfig(
    main_size=config.num_hidden_layers * (num_experts - offload_per_layer),
    offload_size=config.num_hidden_layers * offload_per_layer,
    buffer_size=4,
    offload_per_layer=offload_per_layer,
)


attn_config = BaseQuantizeConfig(
    nbits=4,
    group_size=64,
    quant_zero=True,
    quant_scale=True,
)
attn_config["scale_quant_params"]["group_size"] = 256


ffn_config = BaseQuantizeConfig(
    nbits=2,
    group_size=16,
    quant_zero=True,
    quant_scale=True,
)
quant_config = QuantConfig(ffn_config=ffn_config, attn_config=attn_config)


model = build_model(
    device=device,
    quant_config=quant_config,
    offload_config=offload_config,
    state_path=state_path,
)


tokenizer = AutoTokenizer.from_pretrained(model_name)
conversations_texts = ["can you summarise the book Love in the Time of Cholera in 500 words?", 
                       "can you summarise the book The Picture of Dorian Gray in 500 words?"]
batched_prompts = [f"User: {text} Assistant:" for text in conversations_texts]  # Prepare prompts
tokenizer.padding_side = "left" 
tokenizer.pad_token = tokenizer.eos_token  # Mixtral's tokenizer has no pad token; reuse EOS so padding=True works
# Tokenize all prompts as a batch
batch_inputs = tokenizer(batched_prompts, padding=True, return_tensors="pt", add_special_tokens=True).to(device)

# Generate responses for each prompt in the batch
outputs = model.generate(**batch_inputs, max_new_tokens=1000) #kernel dies!

The screenshot below shows the GPU utilization right before the kernel dies.

[Screenshot 2024-04-02 153827: GPU utilization right before the kernel dies]
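For reference, this is what the OffloadConfig arithmetic in the script works out to for different offload_per_layer values, together with a rough size for the GPU-resident and offloaded experts. The 3 effective bits per expert weight (2-bit values plus an 8-bit scale and zero per group of 16) is an assumption, and the sketch ignores the buffer_size buffers, attention weights, embeddings and the KV cache, so it undercounts the real VRAM use:

```python
LAYERS = 32
EXPERTS = 8
EXPERT_PARAMS = 3 * 4096 * 14336  # w1, w2, w3 of one Mixtral expert
ASSUMED_BITS_PER_WEIGHT = 3.0     # assumption: 2-bit + group-16 scale/zero overhead


def offload_sizes(offload_per_layer, num_layers=LAYERS, num_experts=EXPERTS):
    """Same formulas as the OffloadConfig(...) call in the script."""
    main_size = num_layers * (num_experts - offload_per_layer)
    offload_size = num_layers * offload_per_layer
    return main_size, offload_size


def experts_gib(num_experts, bits=ASSUMED_BITS_PER_WEIGHT):
    """Approximate size in GiB of `num_experts` quantized experts."""
    return num_experts * EXPERT_PARAMS * bits / 8 / 2**30


if __name__ == "__main__":
    for k in (4, 5, 6):
        main, offloaded = offload_sizes(k)
        print(f"offload_per_layer={k}: {main} experts on GPU (~{experts_gib(main):.1f} GiB), "
              f"{offloaded} offloaded to CPU (~{experts_gib(offloaded):.1f} GiB)")
```

With offload_per_layer = 5 this gives 96 GPU-resident and 160 offloaded experts.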

> May I ask which quantization setup allowed compression down to 17 GB, or if you could point me to a file that contains that setup please?

It's the 4-bit attention and 2-bit expert setup from our tech-report. I suppose the weights can be found here. Let's summon @lavawolfiee just in case I'm mistaken.

> the model seems to only occupy ~11 GB on a single GPU without an OOM error, but then at inference there's no utilization of the GPU cores throughout (though the VRAM is occupied) until the kernel crashes

Could you provide a bit more detail? I'll look into it as soon as I have the time to.

> It's the 4-bit attention and 2-bit expert setup from our tech-report. I suppose the weights can be found here.

Yes, you're right

> May I ask which quantization setup allowed compression down to 17 GB, or if you could point me to a file that contains that setup please?

> It's the 4-bit attention and 2-bit expert setup from our tech-report. I suppose the weights can be found here. Let's summon @lavawolfiee just in case I'm mistaken.

This seems to be the same setup I used in the code I provided, which occupies ~11 GB of VRAM and ~23 GB of CPU RAM and then crashes the kernel at inference.

> the model seems to only occupy ~11 GB on a single GPU without an OOM error, but then at inference there's no utilization of the GPU cores throughout (though the VRAM is occupied) until the kernel crashes

> Could you provide a bit more detail? I'll look into it as soon as I have the time to.

Absolutely, what information are you looking for?

> Absolutely, what information are you looking for?

A stacktrace would be helpful.
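When the kernel dies outright, no Python traceback reaches the notebook. One way to get something to post, assuming the crash is a native fault (segfault, abort) rather than the process being killed, is to run the same code as a plain script with the standard library's faulthandler writing to a file (`enable_crash_log` and the log path are illustrative names; `python -X faulthandler script.py` does the same but prints to stderr). Note that faulthandler cannot report a SIGKILL, such as the one the Linux OOM killer sends when host RAM runs out; in that case the system log (e.g. `dmesg`) is the place to look:

```python
import faulthandler


def enable_crash_log(path="crash_traceback.log"):
    """Dump Python tracebacks of all threads to `path` on SIGSEGV, SIGFPE,
    SIGABRT, SIGBUS or SIGILL. Keep the returned file open until exit."""
    log_file = open(path, "w")
    faulthandler.enable(file=log_file, all_threads=True)
    return log_file
```

Call `enable_crash_log()` at the top of the script, before building the model, and keep the returned file object referenced until the process ends so its descriptor stays valid.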