Original PyTorch implementation of ORCA proposed in the paper "Cross-Modal Fine-Tuning: Align then Refine". ORCA is developed for effectively solving ML problems in diverse modalities using large-scale pretrained transformers. It adapts to a target task via an align-then-refine workflow: given the target input, ORCA first learns an embedding network that aligns the embedded feature distribution with the pretraining modality. The pretrained model is then fine-tuned on the embedded data to exploit the knowledge shared across modalities.
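The align-then-refine workflow can be sketched roughly as follows. Everything in this snippet is illustrative only (the stand-in transformer body, the tensor shapes, and the use of a simple l2 distance between batch means in place of OTDD/MMD); see the repo code for the actual implementation.

```python
# Illustrative sketch of ORCA's two-stage align-then-refine workflow.
# All module names and shapes are stand-ins, not the repo's actual API.
import torch
import torch.nn as nn

torch.manual_seed(0)

tgt_x = torch.randn(64, 16)          # target-task inputs (stand-in)
tgt_y = torch.randint(0, 2, (64,))   # target-task labels (stand-in)
src_feats = torch.randn(64, 32)      # reference features from the pretraining modality

embedder = nn.Linear(16, 32)                        # maps target inputs into the embedding space
body = nn.Sequential(nn.Linear(32, 32), nn.ReLU())  # stand-in for the pretrained transformer body
head = nn.Linear(32, 2)                             # task-specific prediction head

# Stage 1 (align): train only the embedder so the embedded target feature
# distribution approaches the pretraining-modality features. The paper uses
# OTDD/MMD/l2; here a plain l2 distance between batch means keeps it short.
opt = torch.optim.Adam(embedder.parameters(), lr=1e-2)
for _ in range(100):
    opt.zero_grad()
    dist = (embedder(tgt_x).mean(0) - src_feats.mean(0)).pow(2).sum()
    dist.backward()
    opt.step()

# Stage 2 (refine): fine-tune embedder + pretrained body + head on the
# embedded target data with the ordinary task loss.
params = list(embedder.parameters()) + list(body.parameters()) + list(head.parameters())
opt = torch.optim.Adam(params, lr=1e-3)
loss_fn = nn.CrossEntropyLoss()
for _ in range(5):
    opt.zero_grad()
    loss = loss_fn(head(body(embedder(tgt_x))), tgt_y)
    loss.backward()
    opt.step()
```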
This repo specifically supports:
- transferring RoBERTa and Swin transformers (Hugging Face implementation) to downstream tasks;
- minimizing the l2 distance, Maximum Mean Discrepancy (MMD), or optimal transport dataset distance (OTDD) for distributional alignment;
- replicating experiments on NAS-Bench-360, PDEBench, and OpenML tabular tasks.
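For intuition, here is a generic (biased) RBF-kernel MMD estimate between two feature batches — a minimal sketch of the kind of distance minimized during alignment, not the repo's actual implementation:

```python
# Generic biased MMD^2 estimate with a Gaussian (RBF) kernel.
# This is a textbook sketch, not the repo's implementation.
import torch

def rbf_mmd(x, y, sigma=1.0):
    """Biased MMD^2 estimate between batches x and y."""
    def kernel(a, b):
        d2 = torch.cdist(a, b).pow(2)           # pairwise squared distances
        return torch.exp(-d2 / (2 * sigma ** 2))
    return kernel(x, x).mean() + kernel(y, y).mean() - 2 * kernel(x, y).mean()

torch.manual_seed(0)
x = torch.randn(128, 8)
same = rbf_mmd(x, torch.randn(128, 8))          # same distribution -> small value
shifted = rbf_mmd(x, torch.randn(128, 8) + 3)   # shifted distribution -> larger value
```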
- The Docker image needed for each task can be found in the configuration files under the `./src/configs` directory. Then, run `./src/startup-hook.sh` to install the dependencies.
- Download the required datasets and the precomputed language features text_xs.py and text_ys.py (if you are using RoBERTa models) to `./src/datasets`.
- Run the following command:
```
python3 ./src/main.py --config ./src/configs/task.yaml
```
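The exact fields of a task config are defined by the existing files under `./src/configs`; a schematic example might look like the following, where every key and value is illustrative only:

```yaml
# Schematic only -- consult the shipped configs for the real schema.
task: cifar100        # target task name (illustrative)
model: roberta        # pretrained body, e.g. roberta or swin (illustrative)
objective: otdd       # alignment distance: l2 / mmd / otdd (illustrative)
lr: 1.0e-4            # fine-tuning learning rate (illustrative)
```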
- Place the corresponding implementation in `./src/embedders.py` and complete the `get_tgt_model` function.
- Add the data loaders to `./src/data_loaders.py` and complete the `get_data` function in `./src/task_configs.py`.
- Add the loss functions and evaluation metrics to `./src/utils.py` and complete the `get_metric` function in `./src/task_configs.py`.
- Modify the `get_config` function in `./src/task_configs.py`.
- Add the yaml file to `./src/configs`.
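As a starting point for the first step above, a custom embedder might look like the sketch below. The class name, constructor arguments, and shapes are hypothetical; the real interface is whatever `get_tgt_model` in `./src/embedders.py` expects.

```python
# Hypothetical embedder skeleton; names, arguments, and shapes are
# illustrative -- match them to get_tgt_model in ./src/embedders.py.
import torch
import torch.nn as nn

class CustomEmbedder(nn.Module):
    """Projects a raw 1-D target signal into the pretrained transformer's
    token-embedding space, producing (batch, num_patches, embed_dim)."""
    def __init__(self, in_channels, embed_dim, patch_size):
        super().__init__()
        # Patchify with a strided conv, then normalize each token embedding.
        self.patch = nn.Conv1d(in_channels, embed_dim,
                               kernel_size=patch_size, stride=patch_size)
        self.norm = nn.LayerNorm(embed_dim)

    def forward(self, x):                        # x: (batch, channels, length)
        z = self.patch(x)                        # (batch, embed_dim, num_patches)
        return self.norm(z.transpose(1, 2))      # (batch, num_patches, embed_dim)

x = torch.randn(4, 3, 96)
tokens = CustomEmbedder(in_channels=3, embed_dim=768, patch_size=16)(x)
print(tokens.shape)  # torch.Size([4, 6, 768])
```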
If you find this project helpful, please consider citing our paper:
```bibtex
@inproceedings{shen2023orca,
  author    = {Shen, Junhong and Li, Liam and Dery, Lucio M. and Staten, Corey and Khodak, Mikhail and Neubig, Graham and Talwalkar, Ameet},
  title     = {Cross-Modal Fine-Tuning: Align then Refine},
  publisher = {ICML},
  year      = {2023},
  url       = {https://arxiv.org/abs/2302.05738}
}
```
Thanks!