# MaTS
This repository contains the official code for the paper: "Merging by Matching Models in Task Subspaces".
## Setup

- Create a virtual environment and activate it.

  ```bash
  python3.8 -m venv env
  source env/bin/activate
  ```

- Install the dependencies.

  ```bash
  python -m pip install -r requirements.txt -f https://download.pytorch.org/whl/cu113/torch_stable.html
  ```

- Set the environment variables. This step has to be done every session, and the `HUGGINGFACE_HUB_CACHE` path can change depending on your setup.

  ```bash
  source bin/setup.sh {DEVICE_ID}
  export HUGGINGFACE_HUB_CACHE=~/.cache
  ```
## Training

Run the training script with:

- `-c`: the list of configs for the model.
- `-td {key}={value}`: any training dataset config parameters to update.
- `-ed {key}={value}`: any evaluation dataset config parameters to update.
- `-tr {key}={value}`: any training run config parameters to update.
- `-er {key}={value}`: any evaluation run config parameters to update.
- `-m {key}={value}`: any model config parameters to update.
- `-w {number of GPUs}`: for distributed training.
### Full-Model Fine-Tuning

T5-Large on a single task (`paws`):

```bash
python src/training.py -c configs/models/t5_large.json configs/training_run/individual_task_T0_run.json configs/training_dataset/p3_individual_task.json configs/evaluation_dataset/p3_validation.json configs/evaluation_run/individual_task.json -td train_dataset=paws -ed evaluation_dataset=paws -tr micro_train_batch_size=16 -er eval_batch_size=32
```
T5-Large on multiple tasks (the `p3_eight_qa` dataset mixture):

```bash
python src/training.py -c configs/models/t5_large.json configs/training_run/p3_eight_qa_T0_run.json configs/training_dataset/p3_eight_qa.json configs/evaluation_dataset/p3_validation.json configs/evaluation_run/p3_eight_qa.json -tr micro_train_batch_size=16 -er eval_batch_size=32
```
### Parameter-Efficient Fine-Tuning

T5-Large with $(IA)^3$ on a single task (`paws`):

```bash
python src/training.py -c configs/models/t5_large.json configs/models/ia3.json configs/training_run/individual_task_T0_run_ia3.json configs/training_dataset/p3_individual_task.json configs/evaluation_dataset/p3_validation.json configs/evaluation_run/individual_task.json -td train_dataset=paws -ed evaluation_dataset=paws -tr micro_train_batch_size=16 -er eval_batch_size=32
```
T5-Large with $(IA)^3$ on multiple tasks (the `p3_eight_qa` dataset mixture):

```bash
python src/training.py -c configs/models/t5_large.json configs/models/ia3.json configs/training_run/p3_eight_qa_T0_run_ia3.json configs/training_dataset/p3_eight_qa.json configs/evaluation_dataset/p3_validation.json configs/evaluation_run/p3_eight_qa.json -tr micro_train_batch_size=16 -er eval_batch_size=32
```
## Merging

All the examples below are for merging models that were trained on datasets from the `p3_eight_qa` dataset mixture.

The results will be saved in `exp_out/merging/{instruction_format}/{dataset_mixture}/{pretrained_model}/{checkpoint_descriptor}/{merging_method}/`.

For example, the conjugate gradient method with the blockwise Fisher merging objective saves its results in
`exp_out/merging/p3/p3_eight_qa/google-t5-large-lm-adapt/ia3/conjugate_gradients/fisher_merging_blockwise_empirical_fisher_validation_initialize_task_vectors_model_lambda_1.0_iterations_100/`.
The arguments are the same as the arguments for training, along with the following generic arguments.

- `--model_lambda {model_lambda}`: the lambda for merging. The exact purpose of lambda depends on the merging method. Defaults to iterating over values from 0 to 1 with a step size of 0.1.
- `-d {dataset_mixture}`: the dataset mixture of datasets to evaluate.
- `--checkpoint_descriptor {checkpoint_descriptor}`: the key that stores the filepaths of the checkpoint to use for each model. The filepaths are in `src/merging/utils/checkpoint_filepaths.py` in a dictionary of the format:

```
{pretrained_model_name: {
    instruction_format: {
        checkpoint_descriptor: {
            dataset: filepath,
            dataset: filepath,
            ...
        }
    }
}}
```
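As a concrete sketch of that format, here is a hypothetical entry and a lookup helper. Every model name, descriptor, dataset, and filepath below is an illustrative placeholder, not actual repository contents:

```python
# Hypothetical checkpoint-filepath dictionary in the nested format described above.
# All keys and paths here are illustrative placeholders.
checkpoint_filepaths = {
    "google-t5-large-lm-adapt": {
        "p3": {
            "ia3": {
                "paws": "exp_out/p3/paws/placeholder_run/checkpoint.pt",
                "quartz": "exp_out/p3/quartz/placeholder_run/checkpoint.pt",
            }
        }
    }
}

def get_checkpoint(pretrained_model, instruction_format, checkpoint_descriptor, dataset):
    """Resolve the checkpoint filepath for one fine-tuned model."""
    return checkpoint_filepaths[pretrained_model][instruction_format][checkpoint_descriptor][dataset]
```

With this layout, `--checkpoint_descriptor ia3` selects the inner `"ia3"` dictionary, and each dataset in the mixture maps to one checkpoint.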
Methods that involve computing metadata (i.e., diagonal Fisher merging, RegMean, and the conjugate gradient method with the RegMean objective or the blockwise Fisher merging objective) also take the following arguments.

- `--split {split}`: which split to compute the Fisher on.
- `--use_true_fisher`: whether to use the true or the empirical Fisher. Defaults to the empirical Fisher.
- `-f {fisher_approximation}`: which Fisher approximation to use, either `diagonal` or `blockwise`. Note that this argument is only for diagonal Fisher merging or the conjugate gradient method with the blockwise Fisher merging objective.
### Evaluating Each Individual Dataset

Full-Model Fine-tuning:

```bash
python src/merging/individual_models.py -c configs/models/t5_large.json configs/evaluation_run/individual_task.json configs/evaluation_dataset/p3_validation.json -d p3_eight_qa --checkpoint_descriptor full_model -er eval_batch_size=32
```

$(IA)^3$:

```bash
python src/merging/individual_models.py -c configs/models/t5_large.json configs/models/ia3.json configs/evaluation_run/individual_task.json configs/evaluation_dataset/p3_validation.json -d p3_eight_qa --checkpoint_descriptor ia3 -er eval_batch_size=32
```
### Average

Full-Model Fine-tuning:

```bash
python src/merging/average.py -c configs/models/t5_large.json configs/evaluation_run/individual_task.json configs/evaluation_dataset/p3_validation.json -d p3_eight_qa --checkpoint_descriptor full_model -er eval_batch_size=32
```

$(IA)^3$:

```bash
python src/merging/average.py -c configs/models/t5_large.json configs/models/ia3.json configs/evaluation_run/individual_task.json configs/evaluation_dataset/p3_validation.json -d p3_eight_qa --checkpoint_descriptor ia3 -er eval_batch_size=32
```
### Diagonal Fisher Merging

First, we save the diagonal Fisher for each model. Then, we compute the merged model using diagonal Fisher merging.
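In outline, diagonal Fisher merging is a per-parameter weighted average of the task models, with each model's diagonal Fisher as the weights. A minimal NumPy sketch of that average (not the repository's implementation, which operates on full checkpoints):

```python
import numpy as np

def diagonal_fisher_merge(params, fishers, eps=1e-8):
    """Per-parameter weighted average: theta* = sum_i F_i * theta_i / sum_i F_i.

    params, fishers: lists of same-shape arrays, one pair per task model.
    eps guards against parameters with zero Fisher mass in every model.
    """
    numerator = sum(f * p for f, p in zip(fishers, params))
    denominator = sum(fishers) + eps
    return numerator / denominator

# Two toy "models": each parameter leans toward the model that weights it more.
p1, p2 = np.array([1.0, 0.0]), np.array([0.0, 1.0])
f1, f2 = np.array([3.0, 1.0]), np.array([1.0, 3.0])
merged = diagonal_fisher_merge([p1, p2], [f1, f2])  # ≈ [0.75, 0.75]
```

Saving the per-model diagonal Fishers first (the step below) provides exactly the `fishers` inputs this average needs.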
#### Save the Diagonal Fisher

Full-Model Fine-tuning:

```bash
python src/merging/save_metadata/save_fisher.py -c configs/models/t5_large.json configs/evaluation_run/fisher.json configs/evaluation_dataset/p3_validation.json -d p3_eight_qa --checkpoint_descriptor full_model --split validation -f diagonal
```

$(IA)^3$:

```bash
python src/merging/save_metadata/save_fisher.py -c configs/models/t5_large.json configs/models/ia3.json configs/evaluation_run/fisher.json configs/evaluation_dataset/p3_validation.json -d p3_eight_qa --checkpoint_descriptor ia3 --split validation -f diagonal
```
#### Merge the Models Using the Diagonal Fisher

Full-Model Fine-tuning:

```bash
python src/merging/diagonal_fisherMerging.py -c configs/models/t5_large.json configs/evaluation_run/individual_task.json configs/evaluation_dataset/p3_validation.json -d p3_eight_qa --checkpoint_descriptor full_model -er eval_batch_size=32 --split validation
```

$(IA)^3$:

```bash
python src/merging/diagonal_fisherMerging.py -c configs/models/t5_large.json configs/models/ia3.json configs/evaluation_run/individual_task.json configs/evaluation_dataset/p3_validation.json -d p3_eight_qa --checkpoint_descriptor ia3 -er eval_batch_size=32 --split validation
```
### Task Vectors

Full-Model Fine-tuning:

```bash
python src/merging/task_vectors.py -c configs/models/t5_large.json configs/evaluation_run/individual_task.json configs/evaluation_dataset/p3_validation.json -d p3_eight_qa --checkpoint_descriptor full_model -er eval_batch_size=32
```

$(IA)^3$:

```bash
python src/merging/task_vectors.py -c configs/models/t5_large.json configs/models/ia3.json configs/evaluation_run/individual_task.json configs/evaluation_dataset/p3_validation.json -d p3_eight_qa --checkpoint_descriptor ia3 -er eval_batch_size=32
```
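Task-vector merging adds the scaled sum of task vectors (fine-tuned minus pretrained parameters) back to the pretrained model, with `--model_lambda` as the scale. A minimal NumPy sketch, assuming flattened parameter vectors:

```python
import numpy as np

def merge_task_vectors(pretrained, finetuned_list, model_lambda):
    """Task arithmetic: theta* = theta_pre + lambda * sum_i (theta_i - theta_pre)."""
    task_vector_sum = sum(ft - pretrained for ft in finetuned_list)
    return pretrained + model_lambda * task_vector_sum

pre = np.array([1.0, 1.0])
ft_a = np.array([2.0, 1.0])   # task vector [1, 0]
ft_b = np.array([1.0, 3.0])   # task vector [0, 2]
merged = merge_task_vectors(pre, [ft_a, ft_b], model_lambda=0.5)  # -> [1.5, 2.0]
```

Without a fixed `--model_lambda`, the script's default behavior of sweeping lambda from 0 to 1 amounts to calling this with each value in the sweep.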
### TIES

Full-Model Fine-tuning:

```bash
python src/merging/ties.py -c configs/models/t5_large.json configs/evaluation_run/individual_task.json configs/evaluation_dataset/p3_validation.json -d p3_eight_qa --checkpoint_descriptor full_model -er eval_batch_size=32
```

$(IA)^3$:

```bash
python src/merging/ties.py -c configs/models/t5_large.json configs/models/ia3.json configs/evaluation_run/individual_task.json configs/evaluation_dataset/p3_validation.json -d p3_eight_qa --checkpoint_descriptor ia3 -er eval_batch_size=32
```
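For reference, TIES-Merging combines task vectors in three steps: trim each task vector to its largest-magnitude entries, elect a per-parameter sign, and average only the entries that agree with the elected sign; the result is then scaled and added to the pretrained model as in task-vector merging. A simplified NumPy sketch on flattened vectors (the repository's implementation handles full checkpoints):

```python
import numpy as np

def ties_merge(task_vectors, k=0.5):
    """Simplified TIES: trim, elect sign, disjoint mean over 1-D task vectors."""
    tvs = np.stack(task_vectors)
    # Trim: keep only the top-k fraction of entries (by magnitude) per vector.
    trimmed = np.zeros_like(tvs)
    for i, tv in enumerate(tvs):
        n_keep = max(1, int(round(k * tv.size)))
        keep_idx = np.argsort(np.abs(tv))[-n_keep:]
        trimmed[i, keep_idx] = tv[keep_idx]
    # Elect sign: per parameter, the sign with the larger total mass wins.
    elected = np.sign(trimmed.sum(axis=0))
    # Disjoint mean: average only nonzero entries that match the elected sign.
    agree = (np.sign(trimmed) == elected) & (trimmed != 0)
    counts = np.maximum(agree.sum(axis=0), 1)
    return (trimmed * agree).sum(axis=0) / counts

tv1 = np.array([1.0, -2.0, 0.1, 3.0])
tv2 = np.array([1.0, 1.5, -0.2, -0.1])
merged = ties_merge([tv1, tv2], k=0.5)  # -> [1.0, -2.0, 0.0, 3.0]
```

In the example, the small entries of each vector are trimmed away, the sign conflict at the second parameter is resolved toward the larger-magnitude side, and the untouched entries pass through unchanged.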
### RegMean

First, we save the Gram matrix for each model. Then, we compute the merged model using RegMean. Note that RegMean cannot be applied to $(IA)^3$.
#### Save the Gram Matrices

```bash
python src/merging/save_metadata/save_gram_matrix.py -c configs/models/t5_large.json configs/evaluation_run/fisher.json configs/evaluation_dataset/p3_validation.json -d p3_eight_qa --checkpoint_descriptor full_model --split validation
```

Note that this saves both the Gram matrix of the input activations and the Gram matrix of the output activation gradients. RegMean only uses the Gram matrix of the input activations; the conjugate gradient method with the blockwise Fisher merging objective uses both.
#### Merge the Models Using RegMean

```bash
python src/merging/regmean.py -c configs/models/t5_large.json configs/evaluation_run/individual_task.json configs/evaluation_dataset/p3_validation.json -d p3_eight_qa --checkpoint_descriptor full_model -er eval_batch_size=32 --split validation
```
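For one linear layer with per-model weights $W_i$ and input-activation Gram matrices $G_i = X_i^T X_i$, RegMean has the closed-form solution $W^* = (\sum_i G_i)^{-1} \sum_i G_i W_i$. A minimal NumPy sketch of that solve (the repository applies it layer by layer, and the RegMean paper additionally scales the off-diagonal Gram entries, which is omitted here):

```python
import numpy as np

def regmean_merge_layer(weights, grams):
    """RegMean closed form for a single linear layer.

    weights: list of (d_in, d_out) weight matrices W_i.
    grams:   list of (d_in, d_in) input-activation Gram matrices G_i.
    """
    lhs = sum(grams)                                  # sum_i G_i
    rhs = sum(G @ W for G, W in zip(grams, weights))  # sum_i G_i W_i
    return np.linalg.solve(lhs, rhs)                  # (sum_i G_i)^-1 sum_i G_i W_i

# Toy 1x1 layer: the Gram matrices act as importance weights on each model.
W1, W2 = np.array([[1.0]]), np.array([[3.0]])
G1, G2 = np.array([[3.0]]), np.array([[1.0]])
merged = regmean_merge_layer([W1, W2], [G1, G2])  # -> [[1.5]]
```

Because the solution depends only on the saved Gram matrices and the checkpoints, the merge itself needs no additional data passes.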
### Save the Blockwise Fisher

$(IA)^3$:

```bash
python src/merging/save_metadata/save_fisher.py -c configs/models/t5_large.json configs/models/ia3.json configs/evaluation_run/fisher.json configs/evaluation_dataset/p3_validation.json -d p3_eight_qa --checkpoint_descriptor ia3 --split validation -f blockwise
```

Full-Model Fine-tuning:

```bash
python src/merging/save_metadata/save_gram_matrix.py -c configs/models/t5_large.json configs/evaluation_run/fisher.json configs/evaluation_dataset/p3_validation.json -d p3_eight_qa --checkpoint_descriptor full_model --split validation
```

This is the same command as the one for saving the Gram matrices for RegMean.
### Conjugate Gradient Method

#### Average Objective

$(IA)^3$:

```bash
python src/merging/conjugateGradient_average.py -c configs/models/t5_large.json configs/models/ia3.json configs/evaluation_run/individual_task.json configs/evaluation_dataset/p3_validation.json -d p3_eight_qa --checkpoint_descriptor ia3 -er eval_batch_size=32 --split validation --num_iterations 100 --model_lambda 1.0 --initialization average
```

Full-Model Fine-tuning:

```bash
python src/merging/conjugateGradient_average.py -c configs/models/t5_large.json configs/evaluation_run/individual_task.json configs/evaluation_dataset/p3_validation.json -d p3_eight_qa --checkpoint_descriptor full_model -er eval_batch_size=32 --split validation --num_iterations 100 --model_lambda 1.0 --initialization average
```
#### Diagonal Fisher Merging Objective

$(IA)^3$:

```bash
python src/merging/conjugateGradient_diagonalFisher.py -c configs/models/t5_large.json configs/models/ia3.json configs/evaluation_run/individual_task.json configs/evaluation_dataset/p3_validation.json -d p3_eight_qa --checkpoint_descriptor ia3 -er eval_batch_size=32 --split validation --num_iterations 100 --model_lambda 1.0
```

Full-Model Fine-tuning:

```bash
python src/merging/conjugateGradient_diagonalFisher.py -c configs/models/t5_large.json configs/evaluation_run/individual_task.json configs/evaluation_dataset/p3_validation.json -d p3_eight_qa --checkpoint_descriptor full_model -er eval_batch_size=32 --split validation --num_iterations 100 --model_lambda 1.0
```
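The stationary point of a Fisher merging objective solves the linear system $(\sum_i F_i)\,\theta = \sum_i F_i \theta_i$, and the conjugate gradient method approaches it iteratively from a chosen initialization rather than inverting $\sum_i F_i$. A minimal dense NumPy sketch of that idea (the repository works with per-layer Fisher and Gram blocks, not one dense matrix):

```python
import numpy as np

def conjugate_gradient_merge(fishers, params, x0, num_iterations=100, tol=1e-12):
    """Solve (sum_i F_i) theta = sum_i F_i theta_i with plain conjugate gradients.

    fishers: list of (d, d) symmetric positive-definite Fisher matrices.
    params:  list of (d,) fine-tuned parameter vectors theta_i.
    x0:      initialization, e.g. a parameter average or a task-vector merge.
    """
    A = sum(fishers)
    b = sum(F @ p for F, p in zip(fishers, params))
    x = np.array(x0, dtype=float)
    r = b - A @ x          # residual
    p = r.copy()           # search direction
    for _ in range(num_iterations):
        rr = r @ r
        if rr < tol:
            break
        Ap = A @ p
        alpha = rr / (p @ Ap)
        x = x + alpha * p
        r = r - alpha * Ap
        p = r + ((r @ r) / rr) * p
    return x

# With diagonal Fishers this reproduces the diagonal Fisher merging average.
fishers = [np.diag([3.0, 1.0]), np.diag([1.0, 3.0])]
params = [np.array([1.0, 0.0]), np.array([0.0, 1.0])]
theta = conjugate_gradient_merge(fishers, params, x0=np.zeros(2))  # ≈ [0.75, 0.75]
```

This is why the commands take `--num_iterations` and `--initialization`: they control the iteration budget and the starting point `x0` of the solve.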
#### RegMean Objective

This objective only applies to full-model fine-tuning.

```bash
python src/merging/conjugateGradient_fisherMerging.py -c configs/models/t5_large.json configs/evaluation_run/individual_task.json configs/evaluation_dataset/p3_validation.json -d p3_eight_qa --checkpoint_descriptor full_model -er eval_batch_size=32 --split validation --model_lambda 1.0 --num_iterations 100 --initialization exp_out/merging/p3/p3_eight_qa/google-t5-large-lm-adapt/full_model/task_vectors/model_lambda_0.3/merged_model.pt
```
#### Blockwise Fisher Merging Objective

$(IA)^3$:

```bash
python src/merging/conjugateGradient_fisherMerging.py -c configs/models/t5_large.json configs/models/ia3.json configs/evaluation_run/individual_task.json configs/evaluation_dataset/p3_validation.json -d p3_eight_qa --checkpoint_descriptor ia3 -er eval_batch_size=32 --split validation --model_lambda 1.0 --num_iterations 10 --initialization exp_out/merging/p3/p3_eight_qa/google-t5-large-lm-adapt/ia3/task_vectors/model_lambda_0.2/merged_model.pt -f blockwise
```

Full-Model Fine-tuning:

```bash
python src/merging/conjugateGradient_fisherMerging.py -c configs/models/t5_large.json configs/evaluation_run/individual_task.json configs/evaluation_dataset/p3_validation.json -d p3_eight_qa --checkpoint_descriptor full_model -er eval_batch_size=32 --split validation --model_lambda 1.0 --num_iterations 100 --use_backward --initialization exp_out/merging/p3/p3_eight_qa/google-t5-large-lm-adapt/full_model/task_vectors/model_lambda_0.3/merged_model.pt
```
## Evaluation

Run the inference script with:

- `-c`: the config of the experiment with the model.
- `-k {key}={value}`: any evaluation config parameters to update.
- `-e`: the experiment directory with the model to evaluate.
- `--merged_model`: the merged model to evaluate.

```bash
python src/inference.py -e exp_out/p3/quartz/facebook-opt-1.3b/full_model/2023-10-21-11-34-24 --checkpoint_idx 299 -ed split=train -er eval_batch_size=32
```

```bash
python src/inference.py --merged_model exp_out/merging/p3/p3_eight_qa/google-t5-large-lm-adapt/ia3/average/merged_model.pt -ed evaluation_split=test -i p3_eight_qa -er eval_batch_size=32
```
## Checkpoints

The models for `p3_eight_qa` can be found at this Google Cloud Storage bucket. This includes:

- checkpoints for merging under `exp_out/p3`
- the multitask-trained checkpoint under `exp_out/p3_eight_qa`
- merged models from the various methods under `exp_out/merging`

When downloading models, the directory structure should match the structure in the bucket, with `exp_out` under `mms`.
## Citation

If you find this repo helpful, feel free to cite our work:

```bibtex
@article{tam2023merging,
  title={Merging by Matching Models in Task Subspaces},
  author={Tam, Derek and Bansal, Mohit and Raffel, Colin},
  journal={arXiv preprint arXiv:2312.04339},
  year={2023}
}
```

and the following work whose code we use in our work:

```bibtex
@article{yadav2023resolving,
  title={Resolving Interference When Merging Models},
  author={Yadav, Prateek and Tam, Derek and Choshen, Leshem and Raffel, Colin and Bansal, Mohit},
  journal={arXiv preprint arXiv:2306.01708},
  year={2023}
}
```