This library aims to be an allround toolkit for attaching, training, saving and loading of new heads for transformer models.
A new head could be:
- A linear probe used to get an understanding of the information processing in a transformer architecture
- A head to be finetuned jointly with the weights of a pretrained transformer model to perform a completely different kind of task.
- E.g. a transformer pretrained to do causal language modelling could get a sequence classification head attached and be finetuned to do sentiment classification.
- Or one could attach a regression head to turn a large language model into a value function for a reinforcement learning problem.
On top of that, attaching multiple heads at once can make multi-task learning easy, making it possible to train very general models.
Install from pypi: pip install transformer-heads
.
Or, clone this repo and from the root of this repository:
pip install -e .
Create head configurations
head_config = HeadConfig(
name=f"imdb_head_3",
layer_hook=-3, # Attach at the output of the third-to-last transformer-block
in_size=hidden_size,
output_activation="linear",
pred_for_sequence=True,
loss_fct="cross_entropy",
num_outputs=2,
target="label" # The name of the ground-truth column in the dataset
)
Create a model with your head from a pretrained transformer model
model = load_headed(
LlamaForCausalLM,
"meta-llama/Llama-2-7b-hf",
head_configs=[heads_config],
)
Train you model using (for example) the simple to use huggingface Trainer interface:
trainer = Trainer(
model,
args=args,
train_dataset=imdb_dataset["train"],
data_collator=collator,
)
For a more in-depth introduction and a fully working example, check the linear probe notebook.
This repository contains multiple jupyter notebooks for a tutorial/illustration of how do do certain things with this library. Here is an overview of which notebook you should check out depending on the use you are interested in.
- Linear Probes (understanding the inner workings of transformers)
- Basic example with one probe for causal LM: notebooks/gpt2/linear_probe.ipynb
- Train many probes for causal LM at once: notebooks/gpt2/multi_linear_probe.ipynb
- Train many probes for text classification at once: notebooks/gpt2/text_classification_linear_probe.ipynb
- Finetuning on a new type of task (with a new head)
- QLoRA: notebooks/gpt2/text_classification_qlora.ipynb
- Full finetuning: notebooks/gpt2/text_classification_full_finetune.ipynb
- Joint multi-task learning
- Many heads doing completely different tasks + QLoRA, all trained at the same time: notebooks/gpt2/joint_multitask_learning.ipynb
- Regression with pretrained transformers
- Check the regression heads of this notebook: notebooks/gpt2/joint_multitask_learning.ipynb
- Saving and loading
At the state of writing, only a subset of loss functions are supported out of the box. Check transformer_heads/constants.py for more up to date info.
However, it is not so hard to add/use different loss functions/models. You'll just need to add their respective information to loss_fct_map
and model_type_map
. Just import from transformer_heads.constants
. To add a loss function, add a mapping from string to torch class. To add a model add a mapping from model type to a 2 tuple out of attribute name of the base model in the Model Class and Base model class. That may sound confusing, but what that means is just the following:
from transformer_heads.constants import model_type_map, loss_fct_map
import torch.nn as nn
from transformers import MistralModel
loss_fct_map["bce"] = nn.BCELoss()
model_type_map["mistral"] = ("model",MistralModel)
One of the basic assumtions of my library is that there is a transformer class such as the LlamaForCausalLM class of huggingface that has an attribute pointing to a base model that outputs raw hidden state. If your transformers model is built up in a similar way, adding support may be as easy as adding an entry to the model_type_map with the name of the attribute and the class of the base model. You can either do that by importing from constants.py or by adding it directly and creating a pull request.