
Robust Adaptation (RoSA)

This repository includes the code for the paper "RoSA: Accurate Parameter-Efficient Fine-Tuning via Robust Adaptation." Below is an illustration of RoSA and a brief comparison with Full Fine-Tuning (FFT) and Low-Rank Adaptation (LoRA).

Installation

  1. Create a clean environment and activate it:
conda create --name rosa python=3.10 -y
conda activate rosa
  2. Install a version of PyTorch (>=2.1.2) compatible with your CUDA version (please use conda instead of pip to ensure all the dependencies are installed properly). For example, if you have CUDA 11.8, run:
conda install pytorch torchvision torchaudio pytorch-cuda=11.8 -c pytorch -c nvidia
  3. Install this repository, which is a fork of MosaicML's llm-foundry and includes the experiments presented in the paper:
git clone https://github.com/IST-DASLab/RoSA.git && cd RoSA
pip install -e .
  4. Install the spops library, which we use under the hood to perform sparse operations:
pip install git+https://github.com/IST-DASLab/spops.git
  5. Install RoSA's integration into Hugging Face's Parameter-Efficient Fine-Tuning (PEFT) library:
pip install git+https://github.com/IST-DASLab/peft-rosa.git
  6. For evaluation, we use lm-evaluation-harness. Run the following commands to install the compatible version (a quick sanity check of the full setup follows this list):
git clone https://github.com/EleutherAI/lm-evaluation-harness.git
cd lm-evaluation-harness
git checkout 2c18e367c6ded428863cd1fd4cf9558ca49d68dc
pip install -e .
cd ..
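
Once everything is installed, a minimal sanity check along the following lines should run without errors. This is a sketch under two assumptions: that peft-rosa keeps the upstream package name peft and that spops installs as spops.

# Minimal sanity check for the environment set up above.
# Assumptions: peft-rosa keeps the upstream package name `peft`,
# and spops installs as `spops`.
import torch
import spops  # sparse kernels RoSA uses under the hood
import peft   # the peft-rosa fork of Hugging Face PEFT

print("torch:", torch.__version__)
print("CUDA available:", torch.cuda.is_available())  # should print True on a GPU machine
print("peft:", peft.__version__)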

Quick Start

Training

First things first, activate the environment and cd into scripts/train/

conda activate rosa
cd scripts/train/

We provide scripts for training LLaMA-2 models on three datasets: GSM8k, ViGGO, and SQL. These datasets were chosen because they are highly specialized and therefore require fine-tuning for good performance: on GSM8k, for example, the pre-trained LLaMA-2 model has 0% one-shot accuracy, and its multi-shot accuracy is also very poor (around 6%). To run quick experiments, simply run any of the following commands, each of which corresponds to one of the single-epoch experiments in the paper:

# RoSA on gsm8k
CUDA_VISIBLE_DEVICES=0 bash scripts/llama2-7b/restart_7b_gsm_bf16.sh

# RoSA on viggo
CUDA_VISIBLE_DEVICES=0 bash scripts/llama2-7b/restart_7b_viggo_bf16.sh

# RoSA on sql
CUDA_VISIBLE_DEVICES=0 bash scripts/llama2-7b/restart_7b_sql_bf16.sh

# QRoSA on gsm8k
CUDA_VISIBLE_DEVICES=0 bash scripts/llama2-7b/restart_7b_gsm_4bit.sh

# QRoSA on viggo
CUDA_VISIBLE_DEVICES=0 bash scripts/llama2-7b/restart_7b_viggo_4bit.sh

# QRoSA on sql
CUDA_VISIBLE_DEVICES=0 bash scripts/llama2-7b/restart_7b_sql_4bit.sh

Training on GSM8k, ViGGO, and SQL should take roughly one, one, and three hours, respectively. These scripts essentially run scripts/restarter_llama2.sh with different hyper-parameters; scripts/restarter_llama2.sh takes care of the low-rank adapter warmup and of restarting training after mask generation. Feel free to tweak the hyper-parameters in any of these scripts.
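
For orientation, the same adapter setup is also exposed programmatically through the peft-rosa integration installed in step 5. The sketch below is illustrative only and is not the exact entry point used by the training scripts: RosaConfig and the argument names shown (r for the low-rank rank, d for the sparse density, schedule for the warmup/restart behavior) follow the peft-rosa README at the time of writing, so double-check that repository for the current API.

# Illustrative sketch: wrapping a Hugging Face model with RoSA adapters
# via peft-rosa. Argument names are assumptions taken from the peft-rosa
# README; verify them against that repository.
from transformers import AutoModelForCausalLM
from peft import RosaConfig, get_peft_model  # RosaConfig comes from the peft-rosa fork

model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")

config = RosaConfig(
    r=16,             # rank of the low-rank adapter
    d=0.006,          # density of the sparse adapter
    lora_alpha=16,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
    schedule="wl64",  # illustrative value: warm up the low-rank adapter before mask generation
)
model = get_peft_model(model, config)
model.print_trainable_parameters()  # only the adapters should be trainable

The restarter script drives the warmup / mask-generation / restart cycle for you, so this snippet is mainly useful if you want to integrate RoSA into your own training loop.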

Evaluation

The training scripts will run the evaluation right after the training is finished and store the results in the evals folder. Look at the final few lines of scripts/restarter_llama2.sh.

Evaluation on ViGGO and SQL takes only a few minutes. Evaluation on GSM8k, however, takes around 45 minutes for bf16 models and about 3 hours for 4-bit models: merging the RoSA adapters into 4-bit weights is tricky and not supported in the current version of the code, so the adapters stay unmerged at inference time, which slows generation. If you want to invoke the harness manually, see the sketch below.
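
If you prefer to re-run the harness by hand rather than through the training scripts, something along the following lines should work. This is a hedged sketch: the pinned pre-0.4 harness exposes evaluator.simple_evaluate, but argument names have changed across harness versions, and the checkpoint path here is a placeholder.

# Hedged sketch: calling the pinned lm-evaluation-harness commit directly.
# Argument names may differ on other harness versions.
from lm_eval import evaluator

results = evaluator.simple_evaluate(
    model="hf-causal",                                 # Hugging Face causal-LM backend
    model_args="pretrained=/path/to/your/checkpoint",  # placeholder: point at your fine-tuned model
    tasks=["gsm8k"],
    num_fewshot=0,  # number of in-context examples
    batch_size=8,
)
print(results["results"])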

RoSA Results

Below is a comparison between Full Fine-Tuning (FFT), Low-Rank Adaptation (LoRA), pure Sparse Adaptation (SpA), and Robust Adaptation (RoSA). The first table shows results for the case where the pre-trained parameters are stored in bf16 format, while the second one presents results for 4-bit double-quantized pre-trained parameters.

[Table: Summary of RoSA results (bf16)]

[Table: Summary of QRoSA results (4-bit quantized)]

Citation

If you plan to use our work in your projects, please consider citing our paper:

@article{nikdan2024rosa,
  title={RoSA: Accurate Parameter-Efficient Fine-Tuning via Robust Adaptation},
  author={Nikdan, Mahdi and Tabesh, Soroush and Crnčević, Elvir and Alistarh, Dan},
  journal={arXiv preprint arXiv:2401.04679},
  year={2024}
}