Code for Aligning LLM Agents by Learning Latent Preference from User Edits.
- Installation
- Implementation of PRELUDE Framework
- Reproduce Our Experiments
- Implement Your Own Agents
This project is developed in Python 3.6. Using Conda to set up a virtual environment is recommended.
Install the required dependencies.
pip install -r requirements.txt
Install PyTorch from http://pytorch.org/.
The PRELUDE implementation contains the following main concepts: task, user, and agent.
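To make the division of labor concrete, here is a minimal sketch of one pass of the interaction protocol, assuming each concept is reduced to a plain callable. The function name `run_interaction` and the toy lambdas are hypothetical, not the repository's API. Note that the agent's learning step only sees the context and the (completion, edit) pair, never the true preference $f^\star_t$.

```python
from typing import Callable, Dict, List, Tuple


def run_interaction(
    data: List[Tuple[str, str]],                  # task data: sequence of (x_t, f*_t) pairs
    complete: Callable[[str], str],               # agent: x_t -> y_t
    edit: Callable[[str, str, str], str],         # user: (x_t, y_t, f*_t) -> y'_t
    learn: Callable[[str, Tuple[str, str]], Dict[str, float]],  # agent update from (y_t, y'_t)
) -> List[Dict[str, float]]:
    """Hypothetical sketch: run one pass of the agent-user loop and collect per-round metrics."""
    history = []
    for x_t, f_star_t in data:
        y_t = complete(x_t)                       # agent proposes an output
        y_edit_t = edit(x_t, y_t, f_star_t)       # simulated user edits it using the hidden preference
        metrics = learn(x_t, (y_t, y_edit_t))     # agent learns from the (completion, edit) pair only
        history.append(metrics)
    return history


# Toy usage with stand-in callables (all invented for illustration)
data = [("Long article A", "use bullet points"), ("Long article B", "be concise")]
history = run_interaction(
    data,
    complete=lambda x: "Summary of " + x,
    edit=lambda x, y, f: y + " (" + f + ")",
    learn=lambda x, pair: {"edited": float(pair[0] != pair[1])},
)
print(history)  # [{'edited': 1.0}, {'edited': 1.0}]
```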
Task is the class encapsulating the following:
- Access to the dataset, which is a sequence of $(x_t, f^\star_t)$ pairs (context, true user preference).
- Main task prompt (prompts to generate $y_t$ given $x_t$ and, optionally, $f_t$):
def get_task_prompt(self, input: str, preference: Optional[str] = None) -> str:
    ...
- User evaluation prompts (prompts to generate $y'_t$):
def get_edit_prompts(self, input: str, output: str, preference: str) -> Tuple[str, str]:
    ...
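A minimal sketch of a task exposing these two methods, assuming a hypothetical `ToySummarizationTask` whose prompt wording is invented here. It also assumes the two edit prompts are (1) a check of whether the output already satisfies the preference and (2) a request for an edited version; the repository may use the returned pair differently.

```python
from typing import List, Optional, Tuple


class ToySummarizationTask:
    """Hypothetical summarization task; the prompt wording is invented for illustration."""

    def __init__(self, data: List[Tuple[str, str]]):
        # Sequence of (x_t, f*_t) pairs: (context, true user preference)
        self.data = data

    def get_task_prompt(self, input: str, preference: Optional[str] = None) -> str:
        prompt = "Summarize the following article.\n\n" + input
        if preference is not None:
            # Condition generation on a preference f_t when the agent provides one
            prompt += "\n\nFollow this user preference: " + preference
        return prompt

    def get_edit_prompts(self, input: str, output: str, preference: str) -> Tuple[str, str]:
        # Assumption: first prompt checks the output, second asks for an edited version y'_t
        context = "Article:\n" + input + "\n\nSummary:\n" + output
        check = context + "\n\nDoes the summary satisfy the preference '" + preference + "'? Answer yes or no."
        edit = context + "\n\nMinimally edit the summary so that it satisfies: " + preference
        return check, edit


task = ToySummarizationTask([("The city council approved a new park.", "use bullet points")])
x_t, f_star_t = task.data[0]
print(task.get_task_prompt(x_t))            # plain task prompt, no preference
print(task.get_task_prompt(x_t, f_star_t))  # preference appended at the end
```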
Two tasks are currently implemented: content summarization and email writing.
Task specifics can be controlled with TaskConfig, which allows you to:
- Change the number of examples
- Choose random seed
- Specify data source
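The configuration could look roughly like the sketch below. The field names `num_examples`, `seed`, and `data_source`, the in-memory `TOY_SOURCES` table, and `load_examples` are hypothetical stand-ins, not the repository's actual TaskConfig.

```python
import random
from typing import Dict, List, NamedTuple, Tuple


class TaskConfig(NamedTuple):
    # Hypothetical field names mirroring the three options listed above
    num_examples: int = 3
    seed: int = 0
    data_source: str = "toy"


# Invented in-memory stand-in for a real data source of (x_t, f*_t) pairs
TOY_SOURCES: Dict[str, List[Tuple[str, str]]] = {
    "toy": [
        ("Article 1", "use bullet points"),
        ("Article 2", "be concise"),
        ("Article 3", "use a formal tone"),
        ("Article 4", "add a title"),
        ("Article 5", "write in second person"),
    ],
}


def load_examples(config: TaskConfig) -> List[Tuple[str, str]]:
    """Return a reproducible sample of (x_t, f*_t) pairs according to the config."""
    pool = TOY_SOURCES[config.data_source]
    rng = random.Random(config.seed)  # local RNG: the global random state is untouched
    return rng.sample(pool, k=min(config.num_examples, len(pool)))


print(load_examples(TaskConfig(num_examples=2, seed=42)))
```

Seeding a local `random.Random` keeps sampling reproducible for a given seed without side effects on other code that uses the `random` module.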
User encapsulates access to the task and to the LLM resource used to simulate user responses. Initializing a user requires a TaskConfig and a UserConfig (which specifies the LLM model name).
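A hedged sketch of a simulated user, assuming `UserConfig` has a single `model_name` field and the LLM is any prompt-to-text callable. `SimulatedUser`, its `edit` method, and the yes/no check convention are hypothetical; a fake LLM stands in for a real model so the example runs offline.

```python
from typing import Callable, NamedTuple, Tuple


class UserConfig(NamedTuple):
    model_name: str  # hypothetical field: name of the LLM that simulates the user


class SimulatedUser:
    """Hypothetical sketch of a user simulated by an LLM."""

    def __init__(
        self,
        config: UserConfig,
        get_edit_prompts: Callable[[str, str, str], Tuple[str, str]],
        llm: Callable[[str], str],
    ):
        self.config = config
        self.get_edit_prompts = get_edit_prompts  # e.g. a task's get_edit_prompts method
        self.llm = llm  # prompt -> text; a real setup would query config.model_name

    def edit(self, input: str, output: str, preference: str) -> str:
        check_prompt, edit_prompt = self.get_edit_prompts(input, output, preference)
        # Assumption: a "yes" answer means the output already satisfies the preference
        if self.llm(check_prompt).strip().lower().startswith("yes"):
            return output  # no edit: y'_t equals y_t
        return self.llm(edit_prompt)  # edited output y'_t


def fake_llm(prompt: str) -> str:
    # Stand-in LLM: rejects every output at the check step and returns a fixed edit
    return "no" if prompt.startswith("CHECK") else "- bullet summary"


user = SimulatedUser(
    UserConfig(model_name="some-model"),
    get_edit_prompts=lambda x, y, f: ("CHECK " + y, "EDIT " + y + " per " + f),
    llm=fake_llm,
)
print(user.edit("Article", "Plain summary", "use bullet points"))  # - bullet summary
```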
Agents are the classes responsible for accomplishing the tasks; they encapsulate access to the LLM and the learning algorithm implementations.
All agents mentioned in our paper are located in the agent folder. You can find the instructions and scripts to reproduce our experiments in the experiments folder.
Every agent should inherit from the base Agent class and implement the following methods:
def complete(self, text) -> LLMOutput
- Task completion method returning an LLMOutput object containing the output text and (optionally) debug token information.
def learn(self, message, correction: Correction) -> Dict
- Learning method taking the context text and a pair of (agent completion, user edits) as inputs. It returns a dictionary of metrics to be logged.
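The interface above can be sketched as an abstract base class plus a dummy agent. This is a minimal sketch assuming hypothetical field names for `LLMOutput` and `Correction`; `DummyAgent` simply echoes its input and, as an illustrative metric of our choosing (not necessarily what the repository logs), reports the character-level Levenshtein distance between its completion and the user's edit.

```python
from abc import ABC, abstractmethod
from typing import Any, Dict, List, NamedTuple, Optional


class LLMOutput(NamedTuple):
    output: str                                  # hypothetical field: output text
    debug_info: Optional[Dict[str, Any]] = None  # optional debug token information


class Correction(NamedTuple):
    agent_completion: str  # y_t (hypothetical field names)
    user_edits: str        # y'_t


class Agent(ABC):
    @abstractmethod
    def complete(self, text) -> LLMOutput: ...

    @abstractmethod
    def learn(self, message, correction: Correction) -> Dict: ...


def edit_distance(a: str, b: str) -> int:
    """Character-level Levenshtein distance via two-row dynamic programming."""
    prev = list(range(len(b) + 1))
    for i, ca in enumerate(a, start=1):
        curr = [i]
        for j, cb in enumerate(b, start=1):
            curr.append(min(prev[j] + 1,               # delete ca
                            curr[j - 1] + 1,           # insert cb
                            prev[j - 1] + (ca != cb)))  # substitute (free if equal)
        prev = curr
    return prev[-1]


class DummyAgent(Agent):
    """Hypothetical agent: echoes the input and tracks the user's edit cost."""

    def __init__(self):
        self.costs: List[int] = []

    def complete(self, text) -> LLMOutput:
        return LLMOutput(output=text, debug_info={"num_updates": len(self.costs)})

    def learn(self, message, correction: Correction) -> Dict:
        cost = edit_distance(correction.agent_completion, correction.user_edits)
        self.costs.append(cost)
        return {"edit_distance": cost, "cumulative_cost": sum(self.costs)}


agent = DummyAgent()
out = agent.complete("kitten")
metrics = agent.learn("kitten", Correction(out.output, "sitting"))
print(metrics)  # {'edit_distance': 3, 'cumulative_cost': 3}
```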
Please check the notebook example of a dummy agent implementation and an end-to-end experiment run here.