Official implementation for the paper "CLIPTrans: Transferring Visual Knowledge with Pre-trained Models for Multimodal Machine Translation", published at ICCV'23. The paper leverages existing pre-trained models (the multilingual mBART and the multimodal M-CLIP) for multimodal machine translation. More generally, it proposes a framework (pictured above) for multilingual generative tasks using multimodal data.
Set up the repository with the following commands:
git clone --recursive git@github.com:devaansh100/CLIPTrans.git
conda env create -n cliptrans -f environment.yml
conda activate cliptrans
pip install -r requirements.txt
cd transformers
pip install -e .
All data should be organised in the `data/` directory.
Download the images for Flickr30k and the MSCOCO test set into the `data/multi30k` folder. You should have three tarballs for the images: `flickr30k-images.tar.gz`, which contains the training, validation and Test2016 sets; `test_2017-flickr-images.gz` for the Test2017 set; and `images_mscoco.task1.tar.gz` for the MSCOCO images.
Now, simply execute `setup_multi30k.sh` to download the text data and organise the folders.
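For reference, here is a minimal sketch of this step, assuming the three tarballs have already been downloaded into `data/multi30k` and that `setup_multi30k.sh` sits in the repository root (the exact archive layouts may differ):

```bash
# Run from the repository root; assumes the tarballs are already in data/multi30k.
cd data/multi30k
tar -xvzf flickr30k-images.tar.gz        # training, validation and Test2016 images
tar -xvzf test_2017-flickr-images.gz     # Test2017 images
tar -xvzf images_mscoco.task1.tar.gz     # MSCOCO images
cd ../..
bash setup_multi30k.sh                   # downloads the text data and organises the folders
```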
The WIT dataset used in our paper can be downloaded from here into the `data/wit` folder. Once downloaded, unzip it with `tar -xvzf wit_mmt.tar.gz`. The images for each language pair can then be downloaded by running the following command:
python download_images.py $FOLDER_NAME $SPLIT
where `FOLDER_NAME` can be one of [`es_fr`, `en_ro`, `en_fr`, `en_es`, `en_de`, `en_af`, `de_es`] and `SPLIT` can be one of [`train`, `test`, `valid`]. This will take a while. Also note that downloading depends on the images still being available on the hosting service, so some variance in the scores is to be expected.
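For example, to fetch all splits of the English-German pair (`en_de` is just an illustrative choice; any of the pairs above works the same way):

```bash
# Download the WIT images for the en_de pair, one split at a time.
python download_images.py en_de train
python download_images.py en_de valid
python download_images.py en_de test
```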
The sharded pre-trained models can be found here. For a particular model, download both shards and place them in `models/multi30k-en-de` (for example). Check the second point below for more details on the flags and the naming scheme. While loading these models, the code will automatically detect the shards and use them. Inference can be run with the command:
python src/main.py --num_gpus 1 --mn multi30k --src_lang en --tgt_lang fr --prefix_length 10 --bs 32 --test_ds 2016 flickr --stage translate --test --lm model_best_test.pth
or, for a multi-GPU setup:
python -m torch.distributed.run --nproc_per_node 4 src/main.py --num_gpus 4 --mn multi30k --src_lang en --tgt_lang fr --prefix_length 10 --bs 32 --test_ds 2016 flickr --stage translate --test --lm model_best_test.pth
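As an illustration, a downloaded en-fr checkpoint can be evaluated on all three Multi30k test sets with a small loop like the one below; this simply re-runs the single-GPU inference command above with different `--test_ds` values:

```bash
# Evaluate the en-fr checkpoint on every Multi30k test set (illustrative loop).
for TEST_DS in "2016 flickr" "2017 flickr" "2017 mscoco"; do
    python src/main.py --num_gpus 1 --mn multi30k --src_lang en --tgt_lang fr \
        --prefix_length 10 --bs 32 --test_ds $TEST_DS \
        --stage translate --test --lm model_best_test.pth
done
```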
Training is done in two stages. To run the first stage (captioning), the following commands can be used, depending on the number of available GPUs:
python src/main.py --num_gpus 1 --mn multi30k --prefix_length 10 --bs 32 --update_count 4 --lr 1e-5 --test_ds 2016 val --stage caption --tgt_lang fr
python -m torch.distributed.run --nproc_per_node 4 src/main.py --num_gpus 4 --mn multi30k --prefix_length 10 --bs 32 --update_count 4 --lr 1e-5 --test_ds 2016 val --stage caption --tgt_lang fr
For stage 2, use the following commands:
python src/main.py --num_gpus 1 --mn multi30k --prefix_length 10 --bs 32 --update_count 4 --lr 1e-5 --test_ds 2016 val --stage translate --tgt_lang fr --lm model_pretrained.pth
python -m torch.distributed.run --nproc_per_node 4 src/main.py --num_gpus 4 --mn multi30k --prefix_length 10 --bs 32 --update_count 4 --lr 1e-5 --test_ds 2016 val --stage translate --tgt_lang fr --lm model_pretrained.pth
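Putting the two stages together, a minimal single-GPU sketch for Multi30k en-fr looks like the following; it simply chains the commands above, with stage 2 warm-started from the stage-1 checkpoint and a final inference run on Test2016:

```bash
# Stage 1: captioning pre-training (saves model_pretrained.pth in the model folder).
python src/main.py --num_gpus 1 --mn multi30k --prefix_length 10 --bs 32 --update_count 4 \
    --lr 1e-5 --test_ds 2016 val --stage caption --tgt_lang fr

# Stage 2: translation, warm-started from the stage-1 weights.
python src/main.py --num_gpus 1 --mn multi30k --prefix_length 10 --bs 32 --update_count 4 \
    --lr 1e-5 --test_ds 2016 val --stage translate --tgt_lang fr --lm model_pretrained.pth

# Inference on the Test2016 Flickr split with the best checkpoint.
python src/main.py --num_gpus 1 --mn multi30k --src_lang en --tgt_lang fr --prefix_length 10 \
    --bs 32 --test_ds 2016 flickr --stage translate --test --lm model_best_test.pth
```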
Here is a quick guide to some specifics about the flags:
- `--stage` denotes the training task. The four available choices are detailed in the table below; inference is modified appropriately for each.
| `--stage` | CLIP input | mBART input | mBART output |
|---|---|---|---|
| `caption` | image | trivial | image caption |
| `translate` | source text | source text | target text |
| `text_recon` | source text | trivial | source text |
| `triplet` | image | source text | target text |
- `--mn` sets the name of the job. It is used to create a unique model folder where the weights are stored and can be loaded from. The source and target languages are appended to this name. It must remain the same across stage 1 and stage 2.
- `--lm` is the name of the weights file to be loaded (which gets saved in the aforementioned folder). For final results, load `model_best_test.pth`. Stage 1 models are saved as `model_pretrained.pth`. To continue training from a saved model, load its weights and add the flag `--ct`.
- `--test_ds` sets the dataset to be used in validation/test. While training, pass `2016 val` (Multi30k) or `valid` (WIT). For inference, pass `2016 flickr`, `2017 flickr` or `2017 mscoco` for Multi30k and `test` for WIT. Also add the flag `--test` so that only inference is run for a saved model.
- To finetune an mBART on a dataset, simply pass `--prefix_length 0`.
- To use images in inference, add the flags `--noise_test --mask_prob 0` to the inference command (see the example commands after this list).
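For instance, the last two points translate into commands such as the following (illustrative en-fr settings; the choice of `--stage translate` for the text-only finetuning is an assumption, so adapt the remaining flags to your run):

```bash
# Text-only mBART finetuning: disable the visual prefix entirely.
python src/main.py --num_gpus 1 --mn multi30k --prefix_length 0 --bs 32 --update_count 4 \
    --lr 1e-5 --test_ds 2016 val --stage translate --tgt_lang fr

# Multimodal inference that actually uses the images.
python src/main.py --num_gpus 1 --mn multi30k --src_lang en --tgt_lang fr --prefix_length 10 \
    --bs 32 --test_ds 2016 flickr --stage translate --test --lm model_best_test.pth \
    --noise_test --mask_prob 0
```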
If the code and/or method was useful to your work, please consider citing us!
@inproceedings{gupta2023cliptrans,
title={CLIPTrans: Transferring Visual Knowledge with Pre-trained Models for Multimodal Machine Translation},
author={Gupta, Devaansh and Kharbanda, Siddhant and Zhou, Jiawei and Li, Wanhua and Pfister, Hanspeter and Wei, Donglai},
booktitle={Proceedings of the IEEE/CVF International Conference on Computer Vision},
year={2023}
}