Implementation for "Step-DPO: Step-wise Preference Optimization for Long-chain Reasoning of LLMs"

Step-DPO: Step-wise Preference Optimization for Long-chain Reasoning of LLMs

Xin Lai, Zhuotao Tian, Yukang Chen, Senqiao Yang, Xiangru Peng, Jiaya Jia

This repo provides the implementation of Step-DPO, a simple, effective, and data-efficient method for boosting the long-chain reasoning ability of LLMs, with a data construction pipeline that yields a high-quality dataset containing 10K step-wise preference pairs.

Notably, Step-DPO improves Qwen2-7B-Instruct from 53.0% to 58.6% on MATH and from 85.5% to 87.9% on GSM8K, using only about 10K preference pairs and a few hundred training steps.

Moreover, when applied to Qwen2-72B-Instruct, Step-DPO achieves 70.8% and 94.0% on the MATH and GSM8K test sets, respectively. Without bells and whistles, it surpasses a series of closed-source models, including GPT-4-1106, Claude-3-Opus, and Gemini-1.5-Pro.

TABLE OF CONTENTS

  1. News
  2. Datasets
  3. Models
  4. Installation
  5. Training
  6. Evaluation
  7. Data Construction Pipeline
  8. Deployment
  9. Examples
  10. Acknowledgement
  11. Citation

News

  • [2024.7.7] We release the scripts for the Data Construction Pipeline! You can now construct a dataset on your own with these scripts.
  • [2024.7.1] We release the demo of the model Qwen2-7B-Instruct-Step-DPO. You are welcome to try it on the Demo!
  • [2024.6.28] We release the pre-print of Step-DPO and this GitHub repo, including training/evaluation scripts, pre-trained models and data.

Datasets

We build a 10K math preference dataset for Step-DPO, which can be downloaded from the link below.

| Dataset | Size | Link |
| --- | --- | --- |
| xinlai/Math-Step-DPO-10K | 10,795 | 🤗 Hugging Face |
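
To illustrate how a single step-wise preference pair from such a dataset might enter training, here is a minimal sketch of a DPO-style loss evaluated on one pair. It is not the code in train.py: it assumes the standard DPO objective applied to one reasoning step, where each log-probability is that of the candidate step given the problem and the preceding correct steps. The function names and the numeric log-probabilities are hypothetical; beta=0.4 mirrors the --beta flag used in the Training section.

```python
import math


def log_sigmoid(x: float) -> float:
    """Numerically stable log(sigmoid(x))."""
    if x >= 0:
        return -math.log1p(math.exp(-x))
    return x - math.log1p(math.exp(x))


def step_dpo_loss(
    policy_chosen_logp: float,
    policy_rejected_logp: float,
    ref_chosen_logp: float,
    ref_rejected_logp: float,
    beta: float = 0.4,
) -> float:
    """DPO-style loss for one (chosen step, rejected step) pair.

    Each argument is the summed token log-probability of a candidate step,
    conditioned on the problem and the preceding correct steps (assumed).
    """
    chosen_logratio = policy_chosen_logp - ref_chosen_logp
    rejected_logratio = policy_rejected_logp - ref_rejected_logp
    margin = beta * (chosen_logratio - rejected_logratio)
    return -log_sigmoid(margin)


# Hypothetical log-probabilities, for illustration only.
loss_at_init = step_dpo_loss(-12.0, -15.0, -12.0, -15.0)    # policy == reference
loss_after_shift = step_dpo_loss(-10.0, -18.0, -12.0, -15.0)  # policy favors chosen step
print(loss_at_init, loss_after_shift)
```

When the policy matches the reference model, the margin is zero and the loss equals log 2 (about 0.693). The loss drops below that value as the policy shifts probability toward the preferred step.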

Models

Notably, Qwen2-72B-Instruct + Step-DPO achieves 70.8% and 94.0% on the MATH and GSM8K test sets, respectively. Step-DPO also brings considerable improvements across various models, as shown below. All released models are available for download.

| Models | Size | MATH | GSM8K | Odyssey-MATH | Link |
| --- | --- | --- | --- | --- | --- |
| Qwen2-7B-Instruct | 7B | 53.0 | 85.5 | - | - |
| Qwen2-7B-Instruct + Step-DPO | 7B | 58.6 (+5.6) | 87.9 (+2.4) | - | 🤗 HF |
| DeepSeekMath-RL | 7B | 51.7 | 88.2 | - | - |
| DeepSeekMath-RL + Step-DPO | 7B | 53.2 (+1.5) | 88.7 (+0.5) | - | 🤗 HF |
| Qwen2-7B-SFT | 7B | 54.8 | 88.2 | - | 🤗 HF |
| Qwen2-7B-SFT + Step-DPO | 7B | 55.8 (+1.0) | 88.5 (+0.3) | - | 🤗 HF |
| Qwen1.5-32B-SFT | 32B | 54.9 | 90.0 | - | 🤗 HF |
| Qwen1.5-32B-SFT + Step-DPO | 32B | 56.9 (+2.0) | 90.9 (+0.9) | - | 🤗 HF |
| Qwen2-57B-A14B-SFT | 57B | 54.6 | 89.8 | - | 🤗 HF |
| Qwen2-57B-A14B-SFT + Step-DPO | 57B | 56.5 (+1.9) | 90.0 (+0.2) | - | 🤗 HF |
| Llama-3-70B-SFT | 70B | 56.9 | 92.2 | - | 🤗 HF |
| Llama-3-70B-SFT + Step-DPO | 70B | 59.5 (+2.6) | 93.3 (+1.1) | - | 🤗 HF |
| Qwen2-72B-SFT | 72B | 61.7 | 92.9 | 44.2 | 🤗 HF |
| Qwen2-72B-SFT + Step-DPO | 72B | 64.7 (+3.0) | 93.9 (+1.0) | 47.0 (+2.8) | 🤗 HF |
| Qwen2-72B-Instruct | 72B | 69.4 | 92.4 | 47.0 | - |
| Qwen2-72B-Instruct + Step-DPO | 72B | 70.8 (+1.4) | 94.0 (+1.6) | 50.1 (+3.1) | 🤗 HF |

Note: Odyssey-MATH contains competition-level math problems.

Installation

conda create -n step_dpo python=3.10
conda activate step_dpo

pip install -r requirements.txt

Training

Pre-trained weights

We use Qwen2, Qwen1.5, Llama-3, and DeepSeekMath models as the pre-trained weights and fine-tune them with Step-DPO. Download the ones you need.

Pre-trained weights
Qwen/Qwen2-7B-Instruct
deepseek-ai/deepseek-math-7b-rl
xinlai/Qwen2-7B-SFT
xinlai/Qwen1.5-32B-SFT
xinlai/Qwen2-57B-A14B-SFT
xinlai/Llama-3-70B-SFT
xinlai/Qwen2-72B-SFT
Qwen/Qwen2-72B-Instruct

Note: models with the '-SFT' suffix are open-source base models that we supervised fine-tuned on our 299K SFT data. You can perform Step-DPO on either our SFT models or existing open-source instruct models.

Here is a script example to perform Step-DPO on Qwen/Qwen2-72B-Instruct:

ACCELERATE_LOG_LEVEL=info accelerate launch --config_file accelerate_configs/deepspeed_zero3_cpu.yaml --mixed_precision bf16 \
    --num_processes 8 \
    train.py configs/config_full.yaml \
    --model_name_or_path="Qwen/Qwen2-72B-Instruct" \
    --data_path="xinlai/Math-Step-DPO-10K" \
    --per_device_train_batch_size=2 \
    --gradient_accumulation_steps=8 \
    --torch_dtype=bfloat16 \
    --bf16=True \
    --beta=0.4 \
    --num_train_epochs=4 \
    --save_strategy='steps' \
    --save_steps=200 \
    --save_total_limit=1 \
    --output_dir=outputs/qwen2-72b-instruct-step-dpo \
    --hub_model_id=qwen2-72b-instruct-step-dpo \
    --prompt=qwen2-boxed
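
As a rough sanity check on the "hundreds of training steps" figure, the sketch below estimates the number of optimizer steps implied by the flags above. It assumes that all 10,795 pairs are used for training and that a final partial batch still counts as one step. The exact count depends on how the trainer rounds, so treat the result as an estimate. The function name is hypothetical.

```python
import math


def estimate_optimizer_steps(
    num_pairs: int,
    num_processes: int,
    per_device_batch: int,
    grad_accum: int,
    epochs: int,
) -> tuple[int, int, int]:
    """Return (effective batch size, steps per epoch, total steps)."""
    effective_batch = num_processes * per_device_batch * grad_accum
    steps_per_epoch = math.ceil(num_pairs / effective_batch)
    return effective_batch, steps_per_epoch, steps_per_epoch * epochs


effective_batch, steps_per_epoch, total_steps = estimate_optimizer_steps(
    num_pairs=10_795,    # dataset size from the Datasets table
    num_processes=8,     # --num_processes
    per_device_batch=2,  # --per_device_train_batch_size
    grad_accum=8,        # --gradient_accumulation_steps
    epochs=4,            # --num_train_epochs
)
print(effective_batch, steps_per_epoch, total_steps)  # 128 85 340
```

Under these assumptions, each optimizer step sees 128 preference pairs and the full run takes roughly 340 steps.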

Evaluation

Here are script examples to evaluate fine-tuned models on both GSM8K and MATH test sets:

python eval_math.py \
    --model outputs/qwen2-72b-instruct-step-dpo \
    --data_file ./data/test/GSM8K_test_data.jsonl \
    --save_path 'eval_results/gsm8k/qwen2-72b-instruct-step-dpo.json' \
    --prompt 'qwen2-boxed' \
    --tensor_parallel_size 8

python eval_math.py \
    --model outputs/qwen2-72b-instruct-step-dpo \
    --data_file ./data/test/MATH_test_data.jsonl \
    --save_path 'eval_results/math/qwen2-72b-instruct-step-dpo.json' \
    --prompt 'qwen2-boxed' \
    --tensor_parallel_size 8
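
The prompt name qwen2-boxed suggests that the model is asked to wrap its final answer in \boxed{...}. The sketch below is an illustration only, not the logic of eval_math.py: it extracts the last \boxed{...} expression with brace matching and scores predictions by exact string match. Real math evaluators usually normalize expressions before comparing, and both function names are hypothetical.

```python
from typing import Optional


def extract_last_boxed(text: str) -> Optional[str]:
    """Return the contents of the last \\boxed{...} in text, or None."""
    marker = r"\boxed{"
    start = text.rfind(marker)
    if start == -1:
        return None
    depth = 1
    chars = []
    for ch in text[start + len(marker):]:
        if ch == "{":
            depth += 1
        elif ch == "}":
            depth -= 1
            if depth == 0:
                return "".join(chars).strip()
        chars.append(ch)
    return None  # braces never closed


def accuracy(predictions: list[str], references: list[str]) -> float:
    """Fraction of predictions whose boxed answer exactly matches the reference."""
    if not references or len(predictions) != len(references):
        raise ValueError("predictions and references must be non-empty and aligned")
    correct = sum(
        extract_last_boxed(pred) == ref for pred, ref in zip(predictions, references)
    )
    return correct / len(references)


# Made-up model outputs, for illustration only.
outputs = [
    r"3 * 14 = 42, so the answer is \boxed{42}.",
    r"I am not sure about the final answer.",
    r"The probability is \boxed{\frac{1}{2}}.",
]
answers = ["42", "7", r"\frac{1}{2}"]
print(accuracy(outputs, answers))
```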

Data Construction Pipeline

We release the scripts to construct the Step-DPO data, as shown in the data_pipeline/ directory. Please follow the instructions below.

cd Step-DPO

# Step 1: Error Collection
# Before executing, please set the MODEL_PATH, PRED_PATH, EVAL_PROMPT
bash data_pipeline/step1.sh

# Step 2: Locate Erroneous Step by GPT-4o
# Before executing, please set the OPENAI_BASE_URL, OPENAI_API_KEY
bash data_pipeline/step2.sh

# Step 3: Rectify by the model itself
# Before executing, please set the MODEL_PATH, EVAL_PROMPT, JSON_FILE, PRED_PATH, SAVE_PATH
bash data_pipeline/step3.sh

# Finally, get the resulting dataset
# Before executing, please set the EVAL_PROMPT, JSON_FILE, PRED_PATH, SAVE_PATH
bash data_pipeline/merge.sh
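
The three stages above can be pictured as producing one preference pair per located error. The sketch below is a hedged illustration of that idea, not the logic of merge.sh: it assumes the erroneous step becomes the rejected sample, the model's own rectified step becomes the chosen sample, and the steps before the error are kept as a shared prefix. The class, function, and field names are hypothetical and do not reflect the released dataset's schema.

```python
from dataclasses import dataclass


@dataclass
class StepPreferencePair:
    prompt: str              # the math problem
    prefix_steps: list[str]  # steps before the first error (assumed correct)
    chosen: str              # rectified step from the model itself (Step 3)
    rejected: str            # erroneous step located in Step 2


def build_pair(
    problem: str,
    wrong_steps: list[str],
    error_index: int,
    rectified_step: str,
) -> StepPreferencePair:
    """Combine a wrong solution, its first error position, and a corrected step."""
    if not 0 <= error_index < len(wrong_steps):
        raise IndexError("error_index must point at a step in wrong_steps")
    return StepPreferencePair(
        prompt=problem,
        prefix_steps=list(wrong_steps[:error_index]),
        chosen=rectified_step,
        rejected=wrong_steps[error_index],
    )


# Made-up example, for illustration only.
pair = build_pair(
    problem="What is 3 * (4 + 5)?",
    wrong_steps=[
        "Step 1: 4 + 5 = 9.",
        "Step 2: 3 * 9 = 28.",
        "Step 3: The answer is 28.",
    ],
    error_index=1,
    rectified_step="Step 2: 3 * 9 = 27.",
)
print(pair)
```

Steps after the located error are dropped in this sketch, since the preference is expressed only on the first incorrect step.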

Deployment

For deployment, please directly use the following command:

python3 app.py --model_path_or_name xinlai/Qwen2-7B-Instruct-Step-DPO

Examples

[Images: examples]

Acknowledgement

This repository is based on alignment-handbook, DeepSeekMath, and MetaMath.

Many thanks for their efforts!

Citation

If you find this project useful in your research, please consider citing us:

@article{lai2024stepdpo,
  title={Step-DPO: Step-wise Preference Optimization for Long-chain Reasoning of LLMs},
  author={Xin Lai and Zhuotao Tian and Yukang Chen and Senqiao Yang and Xiangru Peng and Jiaya Jia},
  journal={arXiv:2406.18629},
  year={2024}
}