- (2024.2) Code for GPT3.5 and TigerBot is under internal code review and will be released.
- (2024.1) Prompt-OIRL has been accepted by ICLR'2024. We look forward to talking with you in Vienna!
- (2023.12) Prompt-OIRL has been presented at the NeurIPS conference. Thanks for all the invaluable feedback!
- (2023.10) Code with LLaMA2 has been released.
- (2023.10) Prompt-OIRL has been featured in a position paper as an example of inverse alignment.
- (2023.9) Prompt-OIRL has been selected as an oral presentation at the ENLSP workshop at NeurIPS'2023.
In this study, we aim to enhance the arithmetic reasoning ability of Large Language Models (LLMs) through zero-shot prompt optimization. We identify a previously overlooked objective of query dependency in such optimization and elucidate two ensuing challenges that impede the successful and economical design of prompt optimization techniques. One primary issue is the absence of an effective method to evaluate prompts during inference when the golden answer is unavailable. Concurrently, learning via interactions with the LLMs to navigate the expansive natural language prompting space proves to be resource-intensive. To address these challenges, we introduce Prompt-OIRL, which harnesses offline inverse reinforcement learning to draw insights from offline prompting demonstration data. Such data exists as a by-product when diverse prompts are benchmarked on openly accessible datasets. With Prompt-OIRL, the query-dependent prompt optimization objective is achieved by first learning an offline reward model. This model can evaluate any query-prompt pair without accessing LLMs. Subsequently, a best-of-N strategy is deployed to recommend the optimal prompt. Our experimental evaluations across various LLM scales and arithmetic reasoning datasets underscore both the efficacy and economic viability of the proposed approach.
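The core inference-time idea, scoring each candidate prompt for a given query with a reward model that never calls the LLM and keeping the highest-scoring one, can be sketched as follows. The reward function here is a toy heuristic standing in for the learned offline reward model, and the candidate prompts are hypothetical examples; neither is taken from Prompt-OIRL.

```python
from typing import Callable, Sequence

# Hypothetical candidate prompts; not the prompt set used in the paper.
PROMPTS = ["Let's think step by step.", "Answer directly with the final number."]


def toy_reward(query: str, prompt: str) -> float:
    """Toy stand-in for a learned offline reward model (assumption, not Prompt-OIRL's model).

    Scores a (query, prompt) pair without calling any LLM: step-by-step prompts
    score 1.0 on long queries, direct prompts score 1.0 on short queries,
    and every other combination scores 0.0.
    """
    step_by_step = "step by step" in prompt.lower()
    long_query = len(query.split()) > 12
    return 1.0 if step_by_step == long_query else 0.0


def best_of_n(query: str, candidates: Sequence[str],
              reward_fn: Callable[[str, str], float]) -> str:
    """Return the candidate prompt with the highest predicted reward for this query."""
    return max(candidates, key=lambda prompt: reward_fn(query, prompt))


SHORT_QUERY = "What is 3 + 4?"
LONG_QUERY = ("A baker sells 12 muffins each morning and 7 each afternoon; "
              "how many muffins does she sell in a week?")
print(best_of_n(SHORT_QUERY, PROMPTS, toy_reward))  # Answer directly with the final number.
print(best_of_n(LONG_QUERY, PROMPTS, toy_reward))   # Let's think step by step.
```

Because the selection depends on the query, different queries can receive different prompts, which is the query-dependent objective described above.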
Figure 1. No single prompt works perfectly for all queries: the optimal prompt is query-dependent. Yet seeking such prompts can be costly and inefficient. Prompt-OIRL optimizes prompts at inference time on a per-query level, effectively and cost-efficiently. (The original chat logs with GPT4 for these motivating examples can be found at: Left, Right.)
To reproduce our results (e.g., using LLaMA2)
- Get the license to use LLaMA-2.
- Clone the repository
git clone git@github.com:holarissun/Prompt-OIRL.git
- Create a new virtual environment with Python 3.10, e.g.,
conda create --name prompt-oirl python==3.10
conda activate prompt-oirl
cd Prompt-OIRL
- Install the requirements
pip install -r requirements.txt
Step 1. (Optional, as we also released the offline dataset) Generate an offline dataset by interacting with the LLMs.
This step takes a long time (typically a few days). To avoid repeating such a computationally expensive (when running LLMs on local machines) or costly (when calling commercial APIs like GPT3.5 or TigerBot) process, we have released all the interaction logs with those LLMs collected in our experiments.
If you would like to reproduce the offline dataset, for example with the LLaMA2 model, you need to work inside the llama repository (git@github.com:facebookresearch/llama.git), move Prompt-OIRL/llama_exps/llama_step1_gen_offline.py to the llama folder, and then run
torchrun --nproc_per_node 1 llama_step1_gen_offline.py \
--ckpt_dir llama-2-7b-chat/ \
--tokenizer_path tokenizer.model \
--max_seq_len 512 --max_batch_size 8 --prompt_idx 0 --dataset_eval gsm8k
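The command above runs a single prompt (--prompt_idx 0) on a single dataset; building an offline dataset from diverse prompts presumably means repeating it for every candidate prompt. Below is a dry-run sketch that only builds and prints the commands, reusing the flags from the command above. NUM_PROMPTS and DATASETS are hypothetical placeholders, since the number of prompts and the other dataset names are not given here.

```python
import shlex
from typing import List

# Hypothetical placeholders: the real number of prompts and dataset names may differ.
NUM_PROMPTS = 3
DATASETS = ["gsm8k"]


def build_step1_commands(num_prompts: int, datasets: List[str]) -> List[List[str]]:
    """Build one torchrun invocation (as an argv list) per (dataset, prompt_idx) pair."""
    commands = []
    for dataset in datasets:
        for prompt_idx in range(num_prompts):
            commands.append([
                "torchrun", "--nproc_per_node", "1", "llama_step1_gen_offline.py",
                "--ckpt_dir", "llama-2-7b-chat/",
                "--tokenizer_path", "tokenizer.model",
                "--max_seq_len", "512", "--max_batch_size", "8",
                "--prompt_idx", str(prompt_idx),
                "--dataset_eval", dataset,
            ])
    return commands


if __name__ == "__main__":
    # Dry run: print the shell commands instead of launching them.
    for cmd in build_step1_commands(NUM_PROMPTS, DATASETS):
        print(shlex.join(cmd))
```

To actually launch the runs, each argv list could be passed to subprocess.run(cmd, check=True) from inside the llama folder; running them sequentially keeps a single GPU from being oversubscribed.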
Step 2. Reorganize the data. This step takes a few seconds; it renames some files, performs the training-test split, and saves the corresponding files to a new folder, LMllama2.
python3 llama_step2_reorg_data.py
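A minimal sketch of one way to do such a training-test split, assuming each logged record is a dict with hypothetical query_id, prompt_idx and correct keys. Splitting over query IDs rather than over individual (query, prompt) records keeps every query entirely on one side, so the reward model is tested on queries it has never seen. The 20% test fraction and the seed are arbitrary assumptions, not the values used by llama_step2_reorg_data.py.

```python
import random
from typing import Dict, List, Tuple

Record = Dict[str, int]  # hypothetical schema: query_id, prompt_idx, correct


def split_by_query(records: List[Record], test_fraction: float = 0.2,
                   seed: int = 0) -> Tuple[List[Record], List[Record]]:
    """Split records so that all prompts for a given query land on the same side."""
    if not 0.0 < test_fraction < 1.0:
        raise ValueError("test_fraction must be in (0, 1)")
    query_ids = sorted({r["query_id"] for r in records})
    random.Random(seed).shuffle(query_ids)  # seeded, so the split is reproducible
    n_test = round(len(query_ids) * test_fraction)
    test_ids = set(query_ids[:n_test])
    train = [r for r in records if r["query_id"] not in test_ids]
    test = [r for r in records if r["query_id"] in test_ids]
    return train, test


# Toy example: 10 queries x 3 prompts with made-up correctness labels.
records = [{"query_id": q, "prompt_idx": p, "correct": (q + p) % 2}
           for q in range(10) for p in range(3)]
train, test = split_by_query(records)
print(len(train), len(test))
```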
Step 3. Process the data. This step takes a few seconds; it processes the data and stores the embeddings and labels for the different experiment settings (i.e., with different availability of training prompts) as .npy files.
python3 llama_step3_data_processing.py
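The sketch below shows how one setting (a given subset of available training prompts) could be assembled into features and labels. It assumes, hypothetically, that the reward-model input is the concatenation of a query embedding and a prompt embedding, and that each record carries query_id, prompt_idx and correct fields; build_setting is an illustrative name, not a function from the repository.

```python
from typing import Dict, List, Sequence, Set, Tuple


def build_setting(records: Sequence[dict],
                  query_emb: Dict[int, Sequence[float]],
                  prompt_emb: Dict[int, Sequence[float]],
                  available_prompts: Set[int]) -> Tuple[List[List[float]], List[int]]:
    """Build reward-model features and labels for one prompt-availability setting.

    Each row concatenates the query embedding and the prompt embedding; the label
    is 1 if the logged LLM answer for that (query, prompt) pair was correct.
    Records whose prompt is not in `available_prompts` are dropped.
    """
    X, y = [], []
    for r in records:
        if r["prompt_idx"] not in available_prompts:
            continue
        X.append(list(query_emb[r["query_id"]]) + list(prompt_emb[r["prompt_idx"]]))
        y.append(int(r["correct"]))
    return X, y


# Toy embeddings and labels, made up for illustration.
query_emb = {0: [0.1, 0.2], 1: [0.3, 0.4]}
prompt_emb = {0: [1.0], 1: [0.0], 2: [0.5]}
records = [{"query_id": q, "prompt_idx": p, "correct": (q + p) % 2 == 0}
           for q in range(2) for p in range(3)]
X, y = build_setting(records, query_emb, prompt_emb, available_prompts={0, 1})
print(len(X), len(X[0]), y)
```

Writing such arrays in the .npy format could then be done with NumPy, e.g. np.save("X_train.npy", np.asarray(X)); the file name here is hypothetical.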
Step 4. Offline evaluation. This step takes from a few minutes to a few hours, depending on the chosen algorithms and the processor. In general, training an XGBoost reward model takes longer, while a LightGBM reward model is faster.
python3 llama_step4_offline_evaluation.py
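This step trains a reward model on query-prompt features. The sketch below swaps in a plain logistic-regression classifier trained by batch gradient descent, in place of the XGBoost or LightGBM models named above, so that it runs with the standard library only. It assumes the stored label marks whether the logged LLM answer for a pair was correct; the learning rate, epoch count and toy data are arbitrary assumptions, and this is not the repository's model.

```python
import math
from typing import List, Sequence, Tuple


def sigmoid(z: float) -> float:
    """Numerically stable logistic function."""
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def train_logistic_reward(X: Sequence[Sequence[float]], y: Sequence[int],
                          lr: float = 0.5, epochs: int = 500) -> Tuple[List[float], float]:
    """Fit weights and bias by batch gradient descent on mean binary cross-entropy."""
    n, d = len(X), len(X[0])
    w, b = [0.0] * d, 0.0
    for _ in range(epochs):
        grad_w, grad_b = [0.0] * d, 0.0
        for xi, yi in zip(X, y):
            # Gradient of the cross-entropy w.r.t. the logit is (p - y).
            err = sigmoid(sum(wj * xj for wj, xj in zip(w, xi)) + b) - yi
            for j in range(d):
                grad_w[j] += err * xi[j]
            grad_b += err
        w = [wj - lr * gj / n for wj, gj in zip(w, grad_w)]
        b -= lr * grad_b / n
    return w, b


def predict_reward(w: Sequence[float], b: float, x: Sequence[float]) -> float:
    """Predicted probability that the (query, prompt) pair yields a correct answer."""
    return sigmoid(sum(wj * xj for wj, xj in zip(w, x)) + b)


# Toy, made-up features and labels for illustration only.
X_toy = [[0.0, 1.0], [0.2, 0.9], [0.8, 0.1], [1.0, 0.0]]
y_toy = [0, 0, 1, 1]
w, b = train_logistic_reward(X_toy, y_toy)
print(round(predict_reward(w, b, [1.0, 0.0]), 3), round(predict_reward(w, b, [0.0, 1.0]), 3))
```

Gradient-boosted trees such as XGBoost or LightGBM can capture non-linear interactions between query and prompt features that this linear model cannot; the sketch only illustrates the input/output contract of a reward model.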
Step 5. Offline optimization. This step takes a few minutes. Evaluating the algorithms by interacting with the LLMs is also an option, but it can be slower. Results under the different settings are all saved to .csv files.
python3 llama_step5_offline_optimization.py
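One way to read the outcome of offline optimization, sketched under stated assumptions: on held-out queries, pick a prompt per query with the reward model (best-of-N), score that choice with the already-logged correctness labels so that no LLM call is needed, and compare against always using the single best prompt. compare_policies, the dictionary schema and the oracle reward below are hypothetical illustrations, not the logic of llama_step5_offline_optimization.py.

```python
from typing import Callable, Dict, Tuple


def compare_policies(logged: Dict[Tuple[int, int], int],
                     reward_fn: Callable[[int, int], float]) -> Tuple[float, float]:
    """Compare query-dependent best-of-N selection with the best single prompt.

    `logged` maps every (query_id, prompt_idx) pair on the test split to 1 if the
    logged LLM answer was correct and 0 otherwise; every pair must be present.
    Returns (accuracy of per-query best-of-N, accuracy of the best fixed prompt).
    """
    queries = sorted({q for q, _ in logged})
    prompts = sorted({p for _, p in logged})
    hits = 0
    for q in queries:
        # Best-of-N: choose the prompt the reward model scores highest for this query.
        chosen = max(prompts, key=lambda p, q=q: reward_fn(q, p))
        hits += logged[(q, chosen)]
    per_prompt = [sum(logged[(q, p)] for q in queries) / len(queries) for p in prompts]
    return hits / len(queries), max(per_prompt)


# Toy logs: 3 queries x 2 prompts, made up for illustration.
logged = {(0, 0): 1, (0, 1): 0, (1, 0): 0, (1, 1): 1, (2, 0): 1, (2, 1): 0}


def oracle_reward(q: int, p: int) -> float:
    # Upper bound only: a real reward model would not see the test labels.
    return float(logged[(q, p)])


print(compare_policies(logged, oracle_reward))
```

Replacing oracle_reward with a trained reward model's predictions gives an offline estimate of how much query-dependent selection gains over the best fixed prompt.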
- Note: you may need to download a missing embedding file from this link (~230 MB, too large for GitHub).
Prompt-OIRL addresses the prompting problems in LLMs using an RLAIF approach. For readers who are also interested in RLHF and RLAIF, and in the intersection of RL and LLM research, we recommend our related position paper on RL in LLM research: RL in the Era of LLMs: What is Essential? What is Needed? RLHF, Prompting, and Beyond.
If you would like to cite our code or paper, please use
@inproceedings{sun2023query,
title={Query-Dependent Prompt Evaluation and Optimization with Offline Inverse RL},
author={Sun, Hao and H{\"u}y{\"u}k, Alihan and van der Schaar, Mihaela},
booktitle={The Twelfth International Conference on Learning Representations},
year={2024}
}
@article{sun2023reinforcement,
title={Reinforcement Learning in the Era of LLMs: What is Essential? What is needed? An RL Perspective on RLHF, Prompting, and Beyond},
author={Sun, Hao},
journal={arXiv preprint arXiv:2310.06147},
year={2023}
}