A library for exploring the domain adaptation problem in the alignment of LLMs.

Alignment of LLMs - from a Domain Adaptation Perspective

This repo is forked from the HALOs project, with several components modified to suit our research purposes.

In this project, we explore various domain adaptation methods for the alignment of LLMs. In our setup, the "source domain" refers to the distribution of the preference data $\mathbb{D}=\{(x_i, y_i^+, y_i^-)\}_{i=1}^{N}$ used to train either reward models (in RLHF) or the LLM policies directly (e.g. by DPO). The "target domain", on the other hand, refers to the distribution over the responses generated by the LLM (a.k.a. the "behavioral policy"), which we aim to align with $\mathbb{D}$.
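
For concreteness, here is a minimal sketch of one way a source-domain preference record $(x_i, y_i^+, y_i^-)$ could be represented in Python; the class and field names are illustrative, not the repo's actual data schema.

```python
from dataclasses import dataclass

@dataclass
class PreferenceExample:
    """One source-domain record (x_i, y_i^+, y_i^-): a prompt together with a
    preferred and a dispreferred response. Field names are illustrative."""
    prompt: str    # x_i
    chosen: str    # y_i^+, the preferred response
    rejected: str  # y_i^-, the dispreferred response

# A toy source-domain dataset D with N = 2 examples.
source_data = [
    PreferenceExample(
        prompt="Explain overfitting in one sentence.",
        chosen="Overfitting is when a model memorizes its training data and fails to generalize.",
        rejected="Overfitting means the model is too small.",
    ),
    PreferenceExample(
        prompt="Is the Earth flat?",
        chosen="No, the Earth is approximately an oblate spheroid.",
        rejected="Yes.",
    ),
]
```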

Unified Trainers

At the moment, the two most popular methods for aligning LLMs are

  • Reinforcement Learning from Human Feedback (RLHF): in such methods, we can train a reward model on the preference data $\mathbb{D}$ and use it to guide the training of the LLM policy.
  • Direct Preference Learning (DPL): in such methods, we can train the LLM policy directly on the preference data $\mathbb{D}$, e.g. by DPO or IPO.
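
For reference, below is a minimal, generic sketch of the DPO objective on a batch of preference triples. It is an illustration of the loss, not the trainer code in this repo, and it assumes per-example log-probabilities that have already been summed over response tokens.

```python
import torch
import torch.nn.functional as F

def dpo_loss(policy_chosen_logps, policy_rejected_logps,
             ref_chosen_logps, ref_rejected_logps, beta=0.1):
    """Generic DPO loss: -log sigmoid(beta * (policy margin - reference margin)),
    where the margin is the log-prob gap between the chosen and rejected response.
    All inputs are 1-D tensors of per-example summed log-probabilities."""
    policy_margin = policy_chosen_logps - policy_rejected_logps
    ref_margin = ref_chosen_logps - ref_rejected_logps
    losses = -F.logsigmoid(beta * (policy_margin - ref_margin))
    return losses.mean()

# Toy usage with random log-probabilities for a batch of 4 triples.
if __name__ == "__main__":
    b = 4
    loss = dpo_loss(torch.randn(b), torch.randn(b), torch.randn(b), torch.randn(b))
    print(loss.item())
```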

Domain Adaptation Methods

The domain adaptation methods we plan to support include the following.

  • Density Ratio Estimation: in such methods, we re-weight the source domain data to bring it closer to the target domain distribution. To do so, we need to support re-weighting functions that change the weights of individual samples when calculating the loss (a minimal sketch follows this list).

  • Pseudo Labeling: in such methods, we assign pseudo labels to data sampled from the target domain (i.e. from the behavioral policy) and use them to align the policy LLM. These methods will be implemented in preference_functions.py later, and the pseudo labels will be generated in several different ways, as follows.

    • By reward model: we can use a reward model trained on the source domain data to generate pseudo labels for the target domain data. To do so, we need to support reward model training and inference (a minimal sketch follows this list).
    • By language model: we can prompt a language model, e.g. GPT-4, to label the responses generated by the target domain policy. To do so, we need to support language model inference, or API calls to language model service providers.
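
As a rough illustration of the re-weighting idea above, the sketch below uses the standard classifier-based density ratio estimate $w(x)=p_{\text{target}}(x)/p_{\text{source}}(x)$ and applies it to a per-example alignment loss. The function names and the assumption of a balanced domain classifier are illustrative, not the repo's interface.

```python
import torch

def density_ratio_weights(domain_logits):
    """Classifier-based density ratio estimation.

    domain_logits: logits of a binary domain classifier p(target | x), evaluated on
    source-domain examples x. With balanced source/target training sets,
    w(x) = p_target(x) / p_source(x) = p(target | x) / (1 - p(target | x)) = exp(logit).
    """
    return torch.exp(domain_logits)

def reweighted_loss(per_example_loss, weights):
    """Re-weight a per-example alignment loss (e.g. a DPO loss before reduction)
    by the estimated density ratios, then average."""
    weights = weights / (weights.mean() + 1e-8)  # keep the loss scale comparable
    return (weights * per_example_loss).mean()

# Toy usage: random per-example losses and domain-classifier logits for 8 samples.
if __name__ == "__main__":
    per_example_loss = torch.rand(8)
    logits = torch.randn(8)
    print(reweighted_loss(per_example_loss, density_ratio_weights(logits)).item())
```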
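
Similarly, here is a minimal sketch of reward-model pseudo-labeling: two responses sampled from the behavioral policy are scored by a reward model trained on the source domain, and the higher-scoring one is labeled as preferred. The `score` callable and the data layout are hypothetical placeholders, not the interface of preference_functions.py.

```python
from dataclasses import dataclass
from typing import Callable

@dataclass
class PseudoLabeledExample:
    prompt: str
    chosen: str    # pseudo-labeled preferred response
    rejected: str  # pseudo-labeled dispreferred response

def pseudo_label_with_reward_model(
    prompt: str,
    response_a: str,
    response_b: str,
    score: Callable[[str, str], float],  # hypothetical: reward model score for (prompt, response)
) -> PseudoLabeledExample:
    """Label two target-domain responses by comparing their reward-model scores."""
    if score(prompt, response_a) >= score(prompt, response_b):
        chosen, rejected = response_a, response_b
    else:
        chosen, rejected = response_b, response_a
    return PseudoLabeledExample(prompt=prompt, chosen=chosen, rejected=rejected)

# Toy usage with a dummy scorer that prefers longer responses.
if __name__ == "__main__":
    dummy_score = lambda prompt, response: float(len(response))
    ex = pseudo_label_with_reward_model(
        "Summarize the plot of Hamlet.",
        "A prince avenges his father's murder, and nearly everyone dies.",
        "It's about a prince.",
        dummy_score,
    )
    print(ex.chosen)
```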