hidet-org/hidet

[Bug] Hidet requires more GPU memory than native torch.compile

eric8607242 opened this issue · 2 comments

Hi,
Thanks again for the great work!

Describe the bug
I am trying to compile LLaMA-7B with Hidet on an NVIDIA RTX 4090 GPU to speed up inference.
However, Hidet requires more GPU memory than the native torch.compile backend.

For example, given input_ids with shape [2, 128], native torch.compile compiles the model successfully and produces the correct output, but with Hidet an out-of-memory (OOM) exception is raised during the first inference step.

Why does Hidet require more GPU memory?
Is there additional configuration I need to reduce GPU memory usage?

To Reproduce
The following script compiles LLaMA:

import torch
from transformers import LlamaModel, LlamaConfig
import hidet

print("Initialize the model")
configuration = LlamaConfig()
model = LlamaModel(configuration).half().eval().cuda()

BATCH_SIZE = 2
SEQ_LEN = 128
input_ids = torch.zeros(BATCH_SIZE, SEQ_LEN, dtype=torch.long, device="cuda")

with torch.no_grad():
    hidet.torch.dynamo_config.use_tensor_core(True)
    hidet.torch.dynamo_config.search_space(2) 
    hidet.torch.dynamo_config.use_fp16(True)
    hidet.torch.dynamo_config.use_fp16_reduction(True)

    print("Start to compile")
    # Compile the model using Hidet
    model_opt = torch.compile(model, backend='hidet')
    #model_opt = torch.compile(model, mode="reduce-overhead")
    print("Start to inference")
    model_opt(
        input_ids=input_ids,
        output_hidden_states=True
    )
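For scale, here is a rough estimate of the activation memory the script above keeps alive. This is a sketch, not a measurement: it assumes the LlamaConfig() defaults (hidden_size=4096, num_hidden_layers=32), FP16 storage (2 bytes per element), and that output_hidden_states=True returns one hidden state for the embeddings plus one per decoder layer. Transient buffers such as attention scores and MLP intermediates are ignored.

```python
# Rough activation estimate for the repro script (an assumption-based sketch).
# Assumes LlamaConfig() defaults and FP16; only the returned hidden states
# are counted, transient buffers are ignored.

BYTES_FP16 = 2


def hidden_states_bytes(batch, seq_len, hidden_size=4096, num_layers=32):
    """Bytes held by output_hidden_states=True: embeddings + one per layer."""
    per_state = batch * seq_len * hidden_size * BYTES_FP16
    return (num_layers + 1) * per_state


nbytes = hidden_states_bytes(batch=2, seq_len=128)
print(f"returned hidden states: {nbytes / 2**20:.0f} MiB")
```

At roughly 66 MiB for a [2, 128] input, the returned activations are tiny next to the multi-gigabyte FP16 weights, which suggests the OOM is driven by weight memory rather than by the input size.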

Note that the current Hidet release does not support some operations used by the Hugging Face LLaMA implementation.
I forked the repo and added them on top of the latest version in this branch, which compiles the model successfully.

Environment

  • Python Package Requirements
    • transformers == 4.28.1
    • torch == 2.0.0
    • sentencepiece == 0.1.99
  • System Environment
    • OS: Ubuntu 22.04.1 LTS
    • GPU: RTX 4090
    • CUDA Version: 11.8
    • Driver Version: 520.61.05

Hi @eric8607242,

Glad to see that you tried this model and added some operators yourself!

I can reproduce this out-of-memory error. After digging deeper, I found that the main cause is that the weights are duplicated between PyTorch and Hidet. When we use torch.compile(...), PyTorch's Dynamo module extracts torch.fx.Graph objects and dispatches them to the backend. The Hidet backend converts the PyTorch tensors to Hidet tensors; at this step, Hidet can share memory with PyTorch thanks to the DLPack protocol. However, in the graph optimization stage, we convert the weights into other forms (e.g., transposing the weight of a linear layer, or concatenating the weights of the Q, K, and V linear layers). Unfortunately, this stage has to allocate memory for the newly created tensors. Although we no longer need the original weights, there is currently no mechanism to free them. A 7B FP16 model takes about 14 GB of GPU DRAM (7B parameters × 2 bytes), and since most of the weights are duplicated, this leads to an out-of-memory error on a 3090 with 24 GB of DRAM.
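The duplication described above can be put into rough numbers. The sketch below is a back-of-the-envelope model, not Hidet's actual memory manager. It assumes the LlamaConfig() default shapes, FP16 weights, and that every linear-layer weight gets exactly one transformed copy during graph optimization while embeddings and RMSNorm weights are reused as-is. The `release_originals=True` branch models a hypothetical mechanism that frees each original weight right after its copy is made, which, per the explanation above, does not exist today.

```python
# Back-of-the-envelope model of weight duplication during graph optimization.
# Assumptions (not Hidet internals): LlamaConfig() default shapes, FP16 weights,
# one transformed copy per linear weight, embeddings/norms reused as-is.

BYTES_FP16 = 2
GiB = 2**30

HIDDEN = 4096   # LlamaConfig().hidden_size
INTER = 11008   # LlamaConfig().intermediate_size
LAYERS = 32     # LlamaConfig().num_hidden_layers
VOCAB = 32000   # LlamaConfig().vocab_size


def linear_params_per_layer():
    attn = 4 * HIDDEN * HIDDEN  # q_proj, k_proj, v_proj, o_proj
    mlp = 3 * HIDDEN * INTER    # gate_proj, up_proj, down_proj
    return attn + mlp


def total_params():
    norms = LAYERS * 2 * HIDDEN + HIDDEN  # two RMSNorms per layer + final norm
    return VOCAB * HIDDEN + LAYERS * linear_params_per_layer() + norms


def peak_bytes(release_originals):
    """Peak weight memory while the optimization pass rewrites each layer."""
    live = total_params() * BYTES_FP16  # PyTorch weights, shared via DLPack
    peak = live
    per_layer_copy = linear_params_per_layer() * BYTES_FP16
    for _ in range(LAYERS):
        live += per_layer_copy          # new storage for transformed weights
        peak = max(peak, live)
        if release_originals:
            live -= per_layer_copy      # hypothetical: free the originals
    return peak


print(f"original weights:      {total_params() * BYTES_FP16 / GiB:.2f} GiB")
print(f"peak, originals kept:  {peak_bytes(False) / GiB:.2f} GiB")
print(f"peak, originals freed: {peak_bytes(True) / GiB:.2f} GiB")
```

Under these assumptions the weights alone come to about 12.3 GiB (roughly 13.2 GB), keeping every original pushes the peak to about 24.4 GiB, already above a 24 GB card before activations or the CUDA context, while freeing originals layer by layer would keep the peak near 12.7 GiB.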

You could also try the ONNX frontend (i.e., export the model to ONNX format and load it with Hidet's ONNX frontend). In the worst case, you can also try a method similar to this example.

But I still want you to know that torch.compile(...) is the most important frontend for Hidet, and we will spend more time addressing the above issues once we have better support for dynamic shapes and LLM inference.

Hi @yaoyaoding,

Thanks for your clear answer!
Looking forward to better support for dynamic shapes and LLM inference.

Closing the issue!