
Training LLMs with QLoRA + FSDP

Primary language: Jupyter Notebook. License: Apache License 2.0 (Apache-2.0).

fsdp_qlora

Training LLMs with Quantized LoRA + FSDP.

Read our announcement blog post.

You should treat this script as an alpha/preview release. If you’re not comfortable with testing and debugging models, we’d suggest holding off for a few months while the community more fully tests the approach.

Integrations

FSDP+QLoRA has been integrated into:

Installation

The following steps should work (tested on CUDA 11.7, 11.8 and 12.1):

  • Clone https://github.com/AnswerDotAI/fsdp_qlora
  • Run pip install llama-recipes fastcore "transformers!=4.38.*,!=4.39.*" --extra-index-url https://download.pytorch.org/whl/test/cu118 as an easy way to get most dependencies (replace 118 with your desired CUDA version)
  • Install bitsandbytes: pip install "bitsandbytes>=0.43.0" (quote the requirement so the shell does not treat > as a redirect)
  • Run huggingface-cli login (to access Llama 2)
  • Optional libraries:
    • HQQ quantization: follow the HQQ installation instructions. Our training script uses HQQBackend.ATEN_BACKPROP, so also make sure to build the custom kernels with cd hqq/kernels && python setup_cuda.py install. Pin HQQ to commit 72b2b641aadc44a7ded6b243915f90df3b3be385 for FSDP compatibility until the to_empty() method is fixed.
    • Weights and Biases logging: pip install wandb
  • PyTorch >= 2.2 is recommended to make use of the native FlashAttention-2 kernel.
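
You can confirm that an environment meets the version constraints above with a small check like the one below. This is a hypothetical helper and is not part of the repository. The names check_versions and installed_versions are illustrative. The version parsing is deliberately simplistic: it compares only the leading numeric components of each version string.

```python
from importlib.metadata import PackageNotFoundError, version


def parse_version(v: str) -> tuple:
    """Parse the leading numeric components, e.g. '2.2.1+cu118' -> (2, 2, 1)."""
    parts = []
    for piece in v.split("+")[0].split("."):
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        if not digits:  # stop at suffixes such as 'dev0' or 'rc1'
            break
        parts.append(int(digits))
    return tuple(parts)


def check_versions(installed: dict) -> list:
    """Return human-readable problems for the constraints listed in Installation."""
    problems = []
    torch_v = installed.get("torch")
    if torch_v is None or parse_version(torch_v) < (2, 2):
        problems.append("torch >= 2.2 recommended for native flash-attention 2")
    bnb_v = installed.get("bitsandbytes")
    if bnb_v is None or parse_version(bnb_v) < (0, 43, 0):
        problems.append("bitsandbytes >= 0.43.0 required")
    tf_v = installed.get("transformers")
    if tf_v is None:
        problems.append("transformers not installed")
    elif parse_version(tf_v)[:2] in ((4, 38), (4, 39)):
        problems.append("transformers 4.38.* and 4.39.* are excluded")
    return problems


def installed_versions(names=("torch", "bitsandbytes", "transformers")) -> dict:
    """Look up installed distribution versions, skipping missing packages."""
    found = {}
    for name in names:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            pass
    return found
```

Run print(check_versions(installed_versions()) or "all checks passed") in the environment you plan to train in.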

Finetune Llama-2 70B on Dual 24GB GPUs

Once installed, run cd fsdp_qlora and then run the following command to begin finetuning Llama-2 70B on Alpaca at a maximum sequence length of 512 tokens.

python train.py \
--model_name meta-llama/Llama-2-70b-hf \
--batch_size 2 \
--context_length 512 \
--precision bf16 \
--train_type qlora \
--use_gradient_checkpointing true \
--use_cpu_offload true \
--dataset alpaca \
--reentrant_checkpointing true

This example command currently uses just over 128GB of CPU RAM. If you only have 128GB available, we recommend making a 10-20GB swap file to accommodate the initial spike in usage.
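
A back-of-the-envelope estimate shows why a 70B model fits on two 24GB cards at all. The sketch makes three assumptions:

- Each quantized weight takes 4 bits.
- Quantization constants add about 0.5 bits per weight (for example, one fp32 scale per 64-weight block).
- FSDP splits the weights evenly across ranks.

It ignores LoRA weights, activations, optimizer state, and layers such as embeddings that are typically left unquantized. Treat the numbers as rough.

```python
def quantized_weight_gb(n_params: float, bits_per_param: float = 4.0, overhead_bits: float = 0.5) -> float:
    """Approximate weight storage in GB (1e9 bytes).

    overhead_bits models per-weight quantization constants (an assumption, see lead-in).
    """
    return n_params * (bits_per_param + overhead_bits) / 8 / 1e9


def per_gpu_shard_gb(n_params: float, world_size: int, **kwargs) -> float:
    """Weight storage per rank if FSDP shards the parameters evenly."""
    return quantized_weight_gb(n_params, **kwargs) / world_size


total_4bit = quantized_weight_gb(70e9)                   # about 39.4 GB in total
shard_4bit = per_gpu_shard_gb(70e9, world_size=2)        # about 19.7 GB per GPU
total_bf16 = quantized_weight_gb(70e9, bits_per_param=16, overhead_bits=0)  # 140 GB
print(f"4-bit total {total_4bit:.1f} GB, per GPU {shard_4bit:.1f} GB, bf16 total {total_bf16:.0f} GB")
```

At roughly 19.7 GB per GPU, the quantized shards leave little room on a 24GB card for activations. That is one plausible reason the example command sets --use_cpu_offload true.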

Training Options

For quantization we support HQQ and bitsandbytes. We're currently doing benchmarking to help you decide which to use. If you do use bitsandbytes, be sure to pass --reentrant_checkpointing True to avoid triggering a bug in bitsandbytes which results in high memory usage (a fix is in progress).

--train_type full

Full-parameter fine-tuning.

export CUDA_VISIBLE_DEVICES=4,5 # optionally set devices
# --world_size is optional; on a single machine it is set automatically
# --master_port is optional; it defaults to 12355
python train.py \
--world_size 2 \
--master_port 12356 \
--model_name meta-llama/Llama-2-7b-hf \
--gradient_accumulation_steps 4 \
--batch_size 8 \
--context_length 512 \
--precision bf16 \
--train_type full \
--use_gradient_checkpointing true \
--use_cpu_offload false \
--use_activation_cpu_offload false \
--log_to wandb \
--dataset alpaca

--train_type lora

LoRA fine-tuning using the HF PEFT library.

- --train_type full \
+ --train_type lora \

--train_type custom_lora

LoRA fine-tuning using a custom LoRA module.

- --train_type full \
+ --train_type custom_lora \

--train_type qlora

4-bit quantized LoRA fine-tuning using the bitsandbytes Linear4bit layer with NF4 quantization and the HF PEFT library.

- --train_type full \
+ --train_type qlora \
+ --reentrant_checkpointing true \

--train_type custom_qlora

4-bit quantized LoRA fine-tuning using the bitsandbytes Linear4bit layer with NF4 quantization and a custom LoRA module.

- --train_type full \
+ --train_type custom_qlora \
+ --reentrant_checkpointing true \

--train_type hqq_lora

4-bit quantized LoRA fine-tuning using the HQQ library and a custom LoRA module.

- --train_type full \
+ --train_type hqq_lora \

--train_type bnb_dora

4-bit quantized DoRA fine-tuning using the bitsandbytes Linear4bit layer with NF4 quantization and a custom DoRA module.

- --train_type full \
+ --train_type bnb_dora \

--train_type hqq_dora

4-bit quantized DoRA fine-tuning using the HQQ library and a custom DoRA module.

- --train_type full \
+ --train_type hqq_dora \

--train_type bnb_llama_pro

4-bit quantized Llama-Pro fine-tuning using the bitsandbytes Linear4bit layer with NF4 quantization.

To create llama-pro weights, run the following command:

python scripts/block_expansion.py \
--model_name meta-llama/Llama-2-7b-hf \
--output_dir /path/to/llama_pro_weights_directory \
--expansion_rate 0.1

Then modify the --train_type full command as follows:

- --train_type full \
+ --train_type bnb_llama_pro \
+ --llama_pro_path /path/to/llama_pro_weights_directory \
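
This README does not spell out how block_expansion.py turns --expansion_rate into a layer layout. The sketch below assumes the script follows the LLaMA Pro recipe:

- Copy existing decoder blocks and interleave the copies with the originals.
- Zero-initialize each copy's output projections (o_proj and down_proj). Each new block then starts out as an identity through the residual connection.

Under that assumption, one plausible layout looks like this. Two details are assumptions made for illustration: the formula n_new = round(num_layers * expansion_rate), and splitting the layers into equal-sized groups.

```python
def expansion_layout(num_layers: int, expansion_rate: float) -> list:
    """Return the expanded block order as (source_block_index, is_new_copy) pairs.

    Hypothetical mapping: n_new copies, each inserted after the last block of one of
    n_new near-equal groups. In the LLaMA Pro recipe each copy would have its o_proj
    and down_proj weights zeroed so that x + block(x) == x at initialization.
    """
    n_new = max(1, round(num_layers * expansion_rate))
    layout = []
    start = 0
    for g in range(n_new):
        end = (g + 1) * num_layers // n_new  # integer group boundary
        layout.extend((i, False) for i in range(start, end))
        layout.append((end - 1, True))       # identity-initialized copy of block end-1
        start = end
    return layout


layout = expansion_layout(num_layers=32, expansion_rate=0.1)
print(len(layout), [i for i, is_new in layout if is_new])
```

With 32 layers and --expansion_rate 0.1, this sketch inserts 3 copies and produces 35 blocks. The script's actual count and placement may differ.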

--train_type hqq_llama_pro

4-bit quantized Llama-Pro fine-tuning using the HQQ library.

To create llama-pro weights, run the following command:

python scripts/block_expansion.py \
--model_name meta-llama/Llama-2-7b-hf \
--output_dir /path/to/llama_pro_weights_directory \
--expansion_rate 0.1

Then modify the --train_type full command as follows:

- --train_type full \
+ --train_type hqq_llama_pro \
+ --llama_pro_path /path/to/llama_pro_weights_directory \

Low Memory Loading

During quantized LoRA training, we use custom quantization and loading code to avoid loading the entire model into GPU memory before sharding it across GPUs. This is the default behavior of our training script when any of the training options "qlora", "custom_qlora" or "hqq_lora" is used. The other training options are already optimized for low-memory loading as far as possible.

We load the weights iteratively, a few layers at a time, quantize them on the GPU, and place them back on the CPU or the meta device (depending on the rank), processing several layers concurrently. We do this on all GPUs so that every rank initializes the quantization parameters, such as zero and scale. Meanwhile, sync_module_states=True syncs the model parameters and buffers across all GPUs during FSDP initialization.
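
The loading scheme above can be simulated in plain Python to show what each rank ends up holding. This is a hypothetical sketch with several stand-ins:

- quantize_absmax is a toy absmax quantizer in place of bitsandbytes or HQQ.
- Plain lists stand in for tensors.
- Rank 0 keeps CPU copies while the other ranks keep shape-only (meta) entries. This split is an assumption about how the rank-based placement works.

Every rank computes the quantization parameters (here just a scale), but only rank 0 keeps the quantized data. FSDP's sync_module_states=True would later broadcast that data to the other ranks.

```python
from concurrent.futures import ThreadPoolExecutor


def quantize_absmax(weights, bits=4):
    """Toy symmetric absmax quantizer returning (integer codes, scale).

    Stand-in for the real bitsandbytes/HQQ quantization (assumption).
    """
    qmax = 2 ** (bits - 1) - 1                              # 7 for 4 bits
    scale = max(abs(w) for w in weights) / qmax or 1.0     # avoid a zero scale
    codes = [round(w / scale) for w in weights]
    return codes, scale


def load_layer(name, weights, rank):
    codes, scale = quantize_absmax(weights)                 # done on every rank
    if rank == 0:
        # rank 0 keeps the real quantized data (on CPU in the real script)
        return name, {"device": "cpu", "codes": codes, "scale": scale}
    # other ranks keep only the shape plus the quantization parameters; the data
    # itself would arrive from rank 0 via sync_module_states=True
    return name, {"device": "meta", "shape": (len(weights),), "scale": scale}


def low_memory_load(state_dict, rank, layers_at_a_time=2):
    """Process layers in small batches, quantizing each batch concurrently."""
    names = list(state_dict)
    loaded = {}
    with ThreadPoolExecutor(max_workers=layers_at_a_time) as pool:
        for i in range(0, len(names), layers_at_a_time):
            batch = names[i:i + layers_at_a_time]
            for name, entry in pool.map(lambda n: load_layer(n, state_dict[n], rank), batch):
                loaded[name] = entry
    return loaded
```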

Mixed Precision Training

--precision bf16 (pure bfloat16)

This will cast all the model parameters to torch.bfloat16 before training and won't use FSDP mixed precision. As a result, sharded and unsharded params will be stored in bf16, forward and backward passes will be done in bf16, and gradient reduction and updates will be done in bf16.

--precision fp32 (pure float32)

This will cast all the model parameters to torch.float32 before training and won't use FSDP mixed precision. As a result, sharded and unsharded params will be stored in fp32, forward and backward passes will be done in fp32, and gradient reduction and updates will be done in fp32.

--precision mp_fp16_autocast (mixed float16 with autocast)

This will cast all the model parameters to torch.float32 before training and will use FSDP mixed precision with

mp_policy = MixedPrecision(param_dtype=torch.float32, reduce_dtype=torch.float32, buffer_dtype=torch.float32)

As a result, sharded and unsharded params will be stored in fp32. It will use autocast(torch.float16) for the forward and backward passes. Gradient reduction (reduce_dtype=torch.float32) and optimizer updates will be done in fp32.

--precision mp_bf16_autocast (mixed bfloat16 with autocast)

This will cast all the model parameters to torch.float32 before training and will use FSDP mixed precision with

mp_policy = MixedPrecision(param_dtype=torch.float32, reduce_dtype=torch.float32, buffer_dtype=torch.float32)

As a result, sharded and unsharded params will be stored in fp32. It will use autocast(torch.bfloat16) for the forward and backward passes. Gradient reduction (reduce_dtype=torch.float32) and optimizer updates will be done in fp32.

--precision mp_bf16_buffers_autocast (bfloat16 params and float32 buffers with autocast)

This will cast all the model parameters to torch.bfloat16 before training but will keep the buffers in torch.float32 and will use FSDP mixed precision with

mp_policy = MixedPrecision(param_dtype=torch.bfloat16, reduce_dtype=torch.bfloat16, buffer_dtype=torch.float32)

As a result, sharded and unsharded params will be stored in bf16. It will use autocast(torch.bfloat16) for the forward and backward passes, and gradient reduction (reduce_dtype=torch.bfloat16) and optimizer updates will be done in bf16. Buffers are kept in fp32, and autocast only runs eligible operations in bf16.

This option is important for the RoPE layer, whose buffers give incorrect results when cast to lower precision, especially with longer context lengths.
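
To see why low-precision RoPE buffers hurt at long context, the sketch below rounds values to bfloat16 in pure Python and compares rotation angles. It assumes that once the buffer is cast, both the positions and the inverse frequencies end up in bf16. The inverse frequencies are inv_freq = base ** (-2i / head_dim). Whether this happens depends on the Transformers version. bf16 keeps only 8 significant bits, so integers above 256 are no longer represented exactly.

```python
import struct


def to_bf16(x: float) -> float:
    """Round a float to the nearest bfloat16 value (round half to even).

    Assumes x is finite and well inside the float32 range.
    """
    bits = struct.unpack("<I", struct.pack("<f", x))[0]    # float32 bit pattern
    bits += 0x7FFF + ((bits >> 16) & 1)                     # round to nearest even
    return struct.unpack("<f", struct.pack("<I", bits & 0xFFFF0000))[0]


def rope_angle(pos: int, i: int, head_dim: int = 128, base: float = 10000.0, bf16: bool = False) -> float:
    """Rotation angle pos * inv_freq[i], with inv_freq[i] = base ** (-2i / head_dim)."""
    inv_freq = base ** (-2 * i / head_dim)
    if not bf16:
        return pos * inv_freq
    # assumption: position and inv_freq are both held in bf16 once the buffer is cast
    return to_bf16(to_bf16(float(pos)) * to_bf16(inv_freq))


for pos in (100, 1001, 4001):
    err = abs(rope_angle(pos, 0) - rope_angle(pos, 0, bf16=True))
    print(f"position {pos}: angle error {err} rad")
```

For the highest-frequency dimension (i = 0), position 4001 rounds to 4000 in bf16. That is a 1-radian error in the rotation angle, while a short position such as 100 is still exact.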

Comparison to an existing trainer

hf_train.py uses TRL's SFTTrainer for a comparison run. To match our script, modify the data-loading code to train on everything (not just completions) and then run train.py --train_type qlora --dataset guanaco --batch_size 8 --lr_scheduler cosine --log_to wandb --save_model True --output_dir guanaco_7B --gradient_accumulation_steps 2 --lr 2e-4. The SFTTrainer version has to run with a lower batch size (4 vs 8). It therefore uses 4 gradient accumulation steps, while the QLoRA+FSDP version uses 2, so both runs have the same effective batch size.
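
The two runs are matched on effective batch size, the number of samples per optimizer step. The sketch below shows the arithmetic. It assumes a single device for each run (world_size=1) and 4 accumulation steps for the SFTTrainer run, so that the totals agree.

```python
def effective_batch_size(batch_size: int, grad_accum_steps: int, world_size: int = 1) -> int:
    """Samples per optimizer step: per-device batch x accumulation steps x devices."""
    return batch_size * grad_accum_steps * world_size


qlora_fsdp = effective_batch_size(batch_size=8, grad_accum_steps=2)   # train.py command
sft_trainer = effective_batch_size(batch_size=4, grad_accum_steps=4)  # assumed SFTTrainer setup
print(qlora_fsdp, sft_trainer)
```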

Converting Saved Models

If you specify --save_model True, the adapter layers will be saved as a state dict. To convert them to the regular Hugging Face format and upload them to the Hub, see: Converting the State Dict.ipynb

If the "custom_qlora" or "hqq_lora" training options are used, only the trainable LoRA parameters will be saved. Before inference, you need to load and quantize the base model again, and separately load the saved LoRA parameters.

Alternatively, you can test whether merging the base model weights with the trained LoRA weights and then quantizing the result performs similarly to keeping the parameters separate, as done during training. To make use of torch.compile with HQQ, see mobiusml/hqq#18.
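
Merging follows the standard LoRA update W' = W + (alpha / r) B A. The sketch below uses nested lists in place of tensors. The names alpha and r follow the usual PEFT naming and scaling convention, which is an assumption about the custom LoRA module. For a quantized base, W would be the dequantized weight. The merged result would then be quantized again, and that re-quantization is the step whose effect on quality you would want to test.

```python
def matmul(a, b):
    """Plain nested-list matrix product."""
    return [[sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]


def merge_lora(W, A, B, alpha: float, r: int):
    """Return W + (alpha / r) * B @ A.

    Shapes: W is d_out x d_in, B is d_out x r, A is r x d_in.
    The input W is not modified.
    """
    scale = alpha / r
    BA = matmul(B, A)
    return [[W[i][j] + scale * BA[i][j] for j in range(len(W[0]))] for i in range(len(W))]


W = [[1.0, 0.0], [0.0, 1.0]]
merged = merge_lora(W, A=[[0.5, -0.5]], B=[[1.0], [2.0]], alpha=2, r=1)
print(merged)  # [[2.0, -1.0], [2.0, -1.0]]
```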

Limitations

While QLoRA finetuning works with FSDP, there are some rough edges to be aware of with this alpha release and our example script.

First, the current release of Transformers' AutoModel.from_pretrained cannot be used to load models into quantized weights, as it does not support the new quant_storage or quantization flags. Loading pretrained models requires writing or using custom model loading code. We provide an example of how to load and quantize a QLoRA model for finetuning in our demo script.

We are actively working with Hugging Face to resolve this incompatibility in future Transformers and PEFT releases.

Second, while FSDP’s Mixed Precision works with QLoRA, practitioners need to be careful to set MixedPrecision.param_dtype to match the Linear4bit quant_storage dtype. Otherwise, FSDP’s Mixed Precision could cast the quantized weights to a different precision, essentially turning them into random weights. Our example script shows how to avoid this potential pitfall, and we will be happy to assist model training libraries in correctly exposing FSDP’s Mixed Precision options to users when training with QLoRA.
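
The failure mode can be illustrated without a GPU. Linear4bit packs several 4-bit codes into each element of the quant_storage dtype, so the stored bits are not meaningful as numbers. If FSDP casts them to a different param_dtype, it converts values numerically rather than copying bits, and the packed codes are lost. The sketch assumes bf16 storage cast to fp16. Its nibble layout (four codes per 16-bit word, first code in the high nibble) is hypothetical and is not the actual bitsandbytes layout.

```python
import struct


def pack_nibbles(codes):
    """Pack four 4-bit codes (0..15) into one 16-bit word, first code in the high nibble."""
    assert len(codes) == 4 and all(0 <= c < 16 for c in codes)
    word = 0
    for c in codes:
        word = (word << 4) | c
    return word


def unpack_nibbles(word):
    return [(word >> shift) & 0xF for shift in (12, 8, 4, 0)]


def bf16_bits_to_float(word):
    """Interpret a 16-bit word as a bf16 value (bf16 is the top half of a float32)."""
    return struct.unpack("<f", struct.pack("<I", word << 16))[0]


def float_to_fp16_bits(x):
    """Numerically cast a float to fp16 and return the resulting 16-bit pattern."""
    return struct.unpack("<H", struct.pack("<e", x))[0]


word = pack_nibbles([1, 2, 3, 4])      # 0x1234, stored as if it were a bf16 element
value = bf16_bits_to_float(word)       # the tiny number a value cast "sees"
recast = float_to_fp16_bits(value)     # what a bf16 -> fp16 cast would store
print(unpack_nibbles(word))            # [1, 2, 3, 4]
print(unpack_nibbles(recast))          # [0, 0, 0, 0]: the codes are gone
```

Keeping param_dtype equal to quant_storage avoids this cast entirely, which is the point of the advice above.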

Example: Llama 70B Training on 4x A100 40GB

# BnB QLoRA
export CUDA_VISIBLE_DEVICES=4,5,6,7
python train.py \
--world_size 4 \
--master_port 12356 \
--model_name meta-llama/Llama-2-70b-hf \
--gradient_accumulation_steps 4 \
--batch_size 2 \
--context_length 512 \
--precision bf16_buffers_autocast \
--train_type custom_qlora \
--use_gradient_checkpointing true \
--reentrant_checkpointing true \
--use_cpu_offload false \
--log_to stdout \
--dataset alpaca

# HQQ QLoRA
export CUDA_VISIBLE_DEVICES=4,5,6,7
python train.py \
--world_size 4 \
--master_port 12356 \
--model_name meta-llama/Llama-2-70b-hf \
--gradient_accumulation_steps 4 \
--batch_size 2 \
--context_length 512 \
--precision bf16_buffers_autocast \
--train_type hqq_lora \
--use_gradient_checkpointing true \
--use_cpu_offload false \
--log_to stdout \
--dataset alpaca

Note: For large-batch or long-context training, HQQ LoRA is a bit more memory-efficient than BnB LoRA with re-entrant checkpointing. If you run into OOM issues, try HQQ LoRA.

SLURM Training

See fsdp_multi_node.sh for an example training script using multi-node training with SLURM.

Add support for a new model

First, import the new model's transformer, attention, and MLP layers from Transformers:

from transformers.models.mistral.modeling_mistral import MistralDecoderLayer, MISTRAL_ATTENTION_CLASSES, MistralMLP

Then in the get_wrapping_policy function, add the attention, MLP, and transformer layers to the self_attn_policy_fn, mlp_policy_fn, and transformer_wrap_policy wrapping policy methods:

def get_wrapping_policy(custom_policy:bool=False):

    def self_attn_policy_fn(module):
        # build one tuple of classes; tuple(*a, *b) would raise a TypeError
        return isinstance(module, (*LLAMA_ATTENTION_CLASSES.values(), *MISTRAL_ATTENTION_CLASSES.values()))

    def mlp_policy_fn(module):
        return isinstance(module, (LlamaMLP, MistralMLP))

    transformer_wrap_policy = functools.partial(
        transformer_auto_wrap_policy,
        transformer_layer_cls=(LlamaDecoderLayer, MistralDecoderLayer),
    )

Finally, add gradient checkpointing support by adding the transformer layer to check_fn:

if args["use_gradient_checkpointing"]:
    check_fn = lambda submodule: isinstance(submodule, (LlamaDecoderLayer, MistralDecoderLayer))