
STaR-GATE

This repository contains code for STaR-GATE: Teaching Language Models to Ask Clarifying Questions.

When prompting language models to complete a task, users often leave important aspects unsaid. While asking questions could resolve this ambiguity (GATE; Li et al., 2023), models often struggle to ask good questions. We explore a language model's ability to self-improve (STaR; Zelikman et al., 2022) by rewarding the model for generating useful questions, a simple method we dub STaR-GATE. We generate a synthetic dataset of 25,500 unique persona-task prompts to simulate conversations between a pretrained language model (the Questioner) and a Roleplayer whose preferences are unknown to the Questioner. By asking questions, the Questioner elicits preferences from the Roleplayer. The Questioner is iteratively finetuned on questions that increase the probability of high-quality responses to the task, which are generated by an Oracle with access to the Roleplayer's latent preferences. After two iterations of self-improvement, the Questioner asks better questions, allowing it to generate responses that are preferred over responses from the initial model on 72% of tasks. Our results indicate that teaching a language model to ask better questions leads to better personalized responses.


(Figure 3 from the paper.)

The final model checkpoint for the main experiment is posted here on the HuggingFace hub. Reach out over email or X (linked in my GitHub profile) if you have any questions.

Setup

After creating a conda environment for the project, navigate to the root directory and run the following commands:

  1. pip install -e .
  2. pip install flash-attn --no-build-isolation
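If you have not yet created the environment itself, a minimal end-to-end setup might look like the sketch below; the environment name and Python version are illustrative assumptions, not pinned by the repository:

  conda create -n assistant-gate python=3.10   # name and version are assumptions
  conda activate assistant-gate
  cd <path to repository root>
  pip install -e .
  pip install flash-attn --no-build-isolation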

Set up local directories

By default, all data and model checkpoints are saved in subdirectories under a directory called /scr/andukuri/assistant-gate-hgx. Make sure you adjust file paths as necessary. The final contents of the directory, corresponding to the paper's results, can be accessed here.
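If you keep the default layout, the only up-front step is creating the base path; substitute your own scratch directory and adjust the paths in the scripts accordingly. A minimal sketch:

  # default base path used by the scripts; replace with your own scratch directory
  mkdir -p /scr/andukuri/assistant-gate-hgx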

Set up Weights & Biases (training only)

To train models (not just use them), set up a Weights & Biases account for experiment logging and create a project called assistant-gate. Once you have created the project and logged in to Weights & Biases using the command-line interface as described here, the training runs described below will be logged automatically.
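A minimal sketch of the logging setup, assuming the standard Weights & Biases CLI; the assistant-gate project can be created from the web interface, or it is typically created automatically the first time a run logs to it:

  pip install wandb   # if not already pulled in by the package requirements
  wandb login         # paste your API key when prompted

After this, training runs launched by the scripts below should appear under the assistant-gate project.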

How to do STaR-GATE

Assuming you have reconfigured the directory paths as described above and navigated to experiments/star-gate, STaR-GATE can be run end-to-end with a series of shell scripts that invoke the corresponding Python files. Note that depending on whether you use a SLURM job scheduler (as in the existing scripts) or run directly on your machine, you may need to adjust the setup in these shell scripts.

One-Time Procedures

  1. Extract initial tasks from the source dataset. Run the shell script instruct-questions/scripts/extract-all.sh.
  2. Generate personas via few-shot prompting. Run the shell script persona-generation/scripts/generate-personas.sh. Before this step, make sure your OpenAI API key is set in your environment using export OPENAI_API_KEY=<your api key>.
  3. Construct the oracle responses by giving GPT-4 access to both personas and tasks for each split. Run the shell scripts build-gold-responses/scripts/generate-all-gold-responses.sh and build-gold-responses/scripts/generate-all-gold-responses-test.sh. You may have to run build-gold-responses/scripts/check-content-violations.sh beforehand; in some cases GPT-4 generates personas that another copy of itself considers offensive, and flagging these helps you track down the offending persona. (The full sequence is sketched after this list.)
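Taken together, the one-time procedures might be run as in the sketch below. It assumes you are in experiments/star-gate and shows direct execution with bash; if you keep the SLURM setup in the scripts, submit them with sbatch instead.

  export OPENAI_API_KEY=<your api key>

  bash instruct-questions/scripts/extract-all.sh                          # 1. extract tasks
  bash persona-generation/scripts/generate-personas.sh                    # 2. generate personas
  bash build-gold-responses/scripts/check-content-violations.sh           # optional: flag offensive personas
  bash build-gold-responses/scripts/generate-all-gold-responses.sh        # 3. oracle responses (train splits)
  bash build-gold-responses/scripts/generate-all-gold-responses-test.sh   # 3. oracle responses (test split)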

Important

The oracle-response step above is very expensive. We make one GPT-4 call for each persona-task pair in (A) training split A, with 250 tasks and 50 personas; (B) training split B, with 250 tasks and 50 personas; and (C) the test split, with 50 tasks and 10 personas. This step only happens once, so be careful: make sure your directories and output paths are configured correctly so the oracle responses are saved the first time around.
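For reference, these counts work out to $(250 \times 50) + (250 \times 50) + (50 \times 10) = 12{,}500 + 12{,}500 + 500 = 25{,}500$ GPT-4 calls, one per persona-task prompt in the dataset.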

Expert Iteration

At each iteration $t \in \{0, 1, 2\}$:

  1. Simulate conversations to generate training data. Run the shell script simulate-conversations/scripts/m{t}-{split}.sh, and pool the conversations by running simulate-conversations/scripts/m{t}-{split}-pool.sh. The appropriate split is A for even $t$, and B for odd $t$; we alternate splits to ensure that $m_t$'s high-quality generated conversations are not memorized from the previous iteration.
  2. Calculate log-probabilities of oracle responses to filter the best questions. Run the shell script log-probs/scripts/m{t}-{split}-log-probs.sh, and filter the conversations by running log-probs/scripts/m{t}-{split}-filter.sh. In the paper, we keep the top $k = 1$ conversation out of 10 for each persona-task combination in the training set.
  3. Generate regularizer responses. Run the shell script sft/preprocess/scripts/m{t}-{split}-model-responses.sh.
  4. Preprocess the data for training. Run the shell script sft/preprocess/scripts/m{t}-{split}-split.sh.
  5. Train the initial model $m_0$ to produce the weights $m_{t + 1}$. Run the shell script sft/train/scripts/train-sft-m{t}-{split}.sh. (One full iteration is sketched after this list.)
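A minimal sketch of one full iteration, assuming the literal file names follow the m{t}-{split} pattern above; it shows direct execution with bash, so submit with sbatch instead if you keep the SLURM setup:

  t=0                                                              # current iteration
  if [ $((t % 2)) -eq 0 ]; then split=A; else split=B; fi          # alternate splits by parity of t

  bash simulate-conversations/scripts/m${t}-${split}.sh            # 1. simulate conversations
  bash simulate-conversations/scripts/m${t}-${split}-pool.sh
  bash log-probs/scripts/m${t}-${split}-log-probs.sh               # 2. score oracle log-probs
  bash log-probs/scripts/m${t}-${split}-filter.sh                  #    and filter top questions
  bash sft/preprocess/scripts/m${t}-${split}-model-responses.sh    # 3. regularizer responses
  bash sft/preprocess/scripts/m${t}-${split}-split.sh              # 4. preprocess for training
  bash sft/train/scripts/train-sft-m${t}-${split}.sh               # 5. finetune m_0 into m_{t+1}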

Evaluation: Oracle Response Log Probabilities

At each iteration $t \in \{0, 1, 2, 3\}$:

  1. Simulate conversations for the test split. Run the shell script simulate-conversations/scripts/m{t}-test.sh, and pool the conversations by running simulate-conversations/scripts/m{t}-test-pool.sh.
  2. Calculate log-probabilities of oracle responses for the test split. Run the shell script log-probs/scripts/m{t}-test-log-probs.sh.
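Both steps can be repeated for every checkpoint, for example (again assuming direct execution rather than sbatch):

  for t in 0 1 2 3; do
    bash simulate-conversations/scripts/m${t}-test.sh
    bash simulate-conversations/scripts/m${t}-test-pool.sh
    bash log-probs/scripts/m${t}-test-log-probs.sh
  done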

Evaluation: Win Rates

  1. Generate responses from all $m_t$ for $t \in [0, 1, 2, 3]$, conditioned on a randomly sampled conversation for each persona-task combination. Run the shell script response-win-rates-randomized-zero-shot/scripts/get-responses.sh.
  2. Generate win rates by prompting GPT-4 to select the more apt response between $m_t$ and $m_0$. Run the shell script response-win-rates-randomized-zero-shot/scripts/get-ratings.sh.
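As a reading aid, the reported metric here is $\text{win rate}(m_t) = \frac{\#\{\text{combinations where GPT-4 prefers } m_t\}}{\#\{\text{combinations evaluated}\}}$, computed against $m_0$; the 72% figure quoted in the abstract is this quantity after two iterations of self-improvement.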