Yixin Chen1, Shuai Zhang2, Boran Han2, Tong He2 and Bo Li2,3.
1The Chinese University of Hong Kong, 2Amazon Web Services, 3University of Chicago
CaMML, a lightweight module, is crafted to seamlessly integrate multimodal contextual samples into large models, thereby empowering the model to derive knowledge from analogous, domain-specific, up-to-date information and make grounded inferences.
- Clone this repository and navigate to the CaMML folder
git clone camml.git
cd camml
- Install Packages
conda create -n camml python=3.10 -y
conda activate camml
pip install --upgrade pip # enable PEP 660 support
bash install.sh
- Install additional packages for training
pip install flash-attn --no-build-isolation
Model | Image Size | LLM | Vision Encoder | CaMML Retrieval Model | CaMML Retrieval Data | Train Data | Finetuning Schedule | Model Download |
---|---|---|---|---|---|---|---|---|
CaMML-7B | 224 | Vicuna-7B-v1.3 | CLIP-ViT-L-14-224px | ImageBind-Huge | ScienceQA train | ScienceQA train | ft_12epochs_2e-5 | checkpoint |
CaMML-13B | 224 | Vicuna-13B-v1.3 | CLIP-ViT-L-14-224px | ImageBind-Huge | ScienceQA train | ScienceQA train | ft_12epochs_2e-5 | checkpoint |
Model | Image Size | LLM | Vision Encoder | CaMML Retrieval Model | CaMML Retrieval Data | Train Data | Finetuning Schedule | Model Download |
---|---|---|---|---|---|---|---|---|
CaMML-7B | 336 | Vicuna-7B-v1.5 | CLIP-ViT-L-14-336px | ImageBind-Huge | LLaVA-v1.5-665K | LLaVA-v1.5-665K | ft_1epoch_2e-5 | checkpoint |
CaMML-13B | 336 | Vicuna-13B-v1.5 | CLIP-ViT-L-14-336px | ImageBind-Huge | LLaVA-v1.5-665K | LLaVA-v1.5-665K | ft_1epoch_2e-5 | checkpoint |
CaMML is finetuned on the ScienceQA dataset.
- Follow the ScienceQA repo to set up the dataset.
- Prepare the data in LLaVA format:
python scripts/convert_sqa_to_llava.py \
convert_to_llava \
--base-dir /path/to/ScienceQA/data/scienceqa \
--prompt-format "QCM-LEPA" \
--split {train,val,minival,test,minitest}
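After conversion, you can quickly sanity-check one of the generated files. The path below assumes the data layout shown later in this README:

```python
# Quick sanity check of a converted LLaVA-format annotation file.
# The path assumes the data layout shown later in this README.
import json

with open("./data/scienceqa/llava_train_QCM-LEPA.json") as f:
    samples = json.load(f)

print(len(samples))       # number of converted training samples
print(samples[0].keys())  # typical LLaVA-format fields, e.g. id / image / conversations
```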
CaMML is instruction-finetuned on the LLaVA-v1.5-665K dataset. Please follow the LLaVA instructions to download the annotation file llava_v1_5_mix665k.json, and download the images from the constituent datasets:
- COCO: train2017
- GQA: images
- OCR-VQA: download script
- TextVQA: train_val_images
- VisualGenome: part1, part2
We build the CaMML Retriever upon ImageBind models and the AutoFaiss indexing tool. Each data entry is encoded with the pre-trained ImageBind-Huge checkpoint and stored in an AutoFaiss index. We provide the processed FAISS index together with the corresponding data JSON file (see the data layout below).
To build your own customized dataset as the source for the CaMML Retriever, we provide scripts for generating your own embeddings and index (a minimal sketch of the pipeline follows the commands below):
python scripts/retriever/retriever_embed_llava665k.py
python scripts/retriever/build_autofaiss_index.py
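For reference, the sketch below shows the general shape of that pipeline with the public ImageBind and AutoFaiss APIs: encode images with the pre-trained ImageBind-Huge checkpoint, save the embeddings, and build an inner-product (flat IP) FAISS index. Paths and file names are illustrative; the two scripts above are the authoritative versions.

```python
# Illustrative sketch of the embed-then-index pipeline; see the scripts above
# for the versions used in this repo. Paths and file names are placeholders.
import os
import numpy as np
import torch
from autofaiss import build_index
from imagebind import data
from imagebind.models import imagebind_model
from imagebind.models.imagebind_model import ModalityType

device = "cuda" if torch.cuda.is_available() else "cpu"

# 1) Encode images with the pre-trained ImageBind-Huge checkpoint.
model = imagebind_model.imagebind_huge(pretrained=True).eval().to(device)
image_paths = ["./data/llava/llava_665k/coco/train2017/000000000009.jpg"]  # placeholder
with torch.no_grad():
    inputs = {ModalityType.VISION: data.load_and_transform_vision_data(image_paths, device)}
    embeddings = model(inputs)[ModalityType.VISION].cpu().numpy().astype("float32")

# Optionally L2-normalize so inner-product search behaves like cosine similarity.
embeddings /= np.linalg.norm(embeddings, axis=1, keepdims=True)

os.makedirs("embeddings", exist_ok=True)
np.save("embeddings/vision_embeddings.npy", embeddings)

# 2) Build a flat inner-product FAISS index over the saved embeddings.
build_index(
    embeddings="embeddings",                      # folder of .npy embedding files
    index_path="llava_665k_vision_flatIP.index",
    index_infos_path="index_infos.json",
    metric_type="ip",
    index_key="Flat",                             # force a flat index; optional
)
```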
We utilize the following models as initialization:
- Vicuna-7B-v1.3
- Vicuna-13B-v1.3
- Vicuna-7B-v1.5
- Vicuna-13B-v1.5
- CLIP-ViT-L-14
- CLIP-ViT-L-14-336px
- LLaVA-MM-Projectors
We follow the LLaVA evaluation preparation to test on 11 tasks (MME, MMBench, GQA, etc.) and organize the data in ./data/eval.
We also provide evaluation on COCO Caption, Flickr30K Caption, OKVQA/A-OKVQA, and RefCOCO/+/g visual grounding; please download these datasets and add them to ./data/eval. A small sanity-check script follows the data layout below.
data
├── llava_665k_vision_flatIP.index
├── llava_665k_memory_metadata.json
├── sqa_vision_flatIP.index
├── sqa_train_post_memory_answer.json
├── llava
│   └── llava_665k
│       ├── coco
│       │   └── train2017
│       ├── gqa
│       │   └── images
│       ├── ocr_vqa
│       │   └── images
│       ├── textvqa
│       │   └── train_images
│       └── vg
│           ├── VG_100K
│           └── VG_100K_2
├── scienceqa
│   ├── images
│   ├── llava_train_QCM-LEPA.json
│   ├── llava_val_QCM-LEPA.json
│   ├── llava_test_QCM-LEPA.json
│   └── llava_test_CQA-A.json
└── eval
    ├── MME
    ├── mm-vet
    ├── mmbench
    ├── mmbench_cn
    ├── pope
    ├── scienceqa
    ├── seed_bench
    ├── vizwiz
    ├── vqav2
    ├── textvqa
    ├── gqa
    ├── cococap
    ├── flickr30k
    ├── okvqa
    ├── aokvqa
    ├── refcoco
    ├── refcocop
    └── refcocog
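Before launching training or evaluation, it can help to verify that the expected files are actually in place. The snippet below is a minimal, optional check over a few of the paths above; adjust it to the subsets you downloaded.

```python
# Optional sanity check that the data layout above is in place.
# Trim or extend the list to match the subsets you actually downloaded.
from pathlib import Path

expected = [
    "data/llava_665k_vision_flatIP.index",
    "data/llava_665k_memory_metadata.json",
    "data/scienceqa/llava_train_QCM-LEPA.json",
    "data/llava/llava_665k/coco/train2017",
    "data/eval/MME",
]

for path in expected:
    status = "ok" if Path(path).exists() else "MISSING"
    print(f"{status:7s} {path}")
```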
To finetune CaMML on ScienceQA, run:
bash scripts/train_camml_7B_sqa.sh
e.g.:
torchrun --nproc_per_node=$GPUS_PER_NODE --master_port=$RANDOM \
llava/train/train_camml_sqa.py \
--deepspeed "zero3.json" \
--model_name_or_path ./checkpoints/vicuna-7b-v1.3 \
--version v1 \
--data_path ./data/scienceqa/llava_train_QCM-LEPA.json \
--image_folder ./data/scienceqa/images/ \
--vision_tower ./checkpoints/clip-vit-large-patch14 \
--pretrain_mm_mlp_adapter ./checkpoints/llava-pretrain-vicuna-7b-v1.3/mm_projector.bin \
--mm_vision_select_layer -2 \
--mm_use_im_start_end False \
--mm_use_im_patch_token False \
--perceiver_hidden_size 768 \
--perceiver_querys 64 \
--perceiver_layers 2 \
--icl_num 1 \
--random_shots_training True \
--image_aspect_ratio pad \
--group_by_modality_length True \
--fp16 True \
--output_dir ./checkpoints/$file \
--num_train_epochs 12 \
--per_device_train_batch_size 8 \
--per_device_eval_batch_size 1 \
--gradient_accumulation_steps 1 \
--evaluation_strategy "no" \
--save_strategy "steps" \
--save_steps 500 \
--save_total_limit 1 \
--learning_rate 2e-5 \
--weight_decay 0. \
--warmup_ratio 0.03 \
--lr_scheduler_type "cosine" \
--logging_steps 1 \
--tf32 False \
--model_max_length 2048 \
--gradient_checkpointing True \
--dataloader_num_workers 0 \
--lazy_preprocess True
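The flags --perceiver_hidden_size, --perceiver_querys, --perceiver_layers, and --icl_num configure the CaMML perceiver that digests the retrieved in-context shots. As a rough, hypothetical picture of what these hyperparameters control (hidden width, number of learned query tokens, number of cross-attention layers, and number of retrieved shots), a minimal perceiver-style resampler might look like the sketch below; this is not the repository's actual implementation.

```python
# Hypothetical sketch of a perceiver-style resampler; NOT the repo's actual code.
# It only illustrates what --perceiver_hidden_size, --perceiver_querys,
# --perceiver_layers, and --icl_num roughly correspond to.
import torch
import torch.nn as nn

class PerceiverResampler(nn.Module):
    def __init__(self, hidden_size=768, num_queries=64, num_layers=2, num_heads=8):
        super().__init__()
        # Learned latent queries that summarize the retrieved multimodal context.
        self.queries = nn.Parameter(torch.randn(num_queries, hidden_size) * 0.02)
        self.attn_layers = nn.ModuleList(
            nn.MultiheadAttention(hidden_size, num_heads, batch_first=True)
            for _ in range(num_layers)
        )
        self.norms = nn.ModuleList(nn.LayerNorm(hidden_size) for _ in range(num_layers))

    def forward(self, context_tokens):
        # context_tokens: (batch, num_context_tokens, hidden_size), e.g. features of
        # icl_num retrieved image-text shots concatenated along the token dimension.
        batch = context_tokens.size(0)
        latents = self.queries.unsqueeze(0).expand(batch, -1, -1)
        for attn, norm in zip(self.attn_layers, self.norms):
            attended, _ = attn(latents, context_tokens, context_tokens)
            latents = norm(latents + attended)
        # Fixed-length summary of the context, to be consumed by the LLM.
        return latents

# Example: one retrieved shot (icl_num=1) contributing 576 visual tokens of width 768.
resampler = PerceiverResampler(hidden_size=768, num_queries=64, num_layers=2)
context = torch.randn(2, 1 * 576, 768)
print(resampler(context).shape)  # torch.Size([2, 64, 768])
```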
To instruction-finetune CaMML on LLaVA-v1.5-665K, run:
bash scripts/train_camml_7B_665K.sh
e.g.:
torchrun --nproc_per_node=$GPUS_PER_NODE --master_port=$RANDOM \
camml/train/train_camml.py \
--deepspeed "zero3.json" \
--model_name_or_path ./checkpoints/vicuna-7b-v1.5 \
--version v1 \
--data_path ./data/llava/llava_665k/llava_v1_5_mix665k.json \
--image_folder ./data/llava/llava_665k/ \
--vision_tower ./checkpoints/clip-vit-large-patch14-336 \
--pretrain_mm_mlp_adapter ./checkpoints/llava-v1.5-mlp2x-336px-pretrain-vicuna-7b-v1.5/mm_projector.bin \
--mm_vision_select_layer -2 \
--mm_use_im_start_end False \
--mm_use_im_patch_token False \
--perceiver_hidden_size 768 \
--perceiver_querys 128 \
--perceiver_layers 2 \
--random_shots_training True \
--image_aspect_ratio pad \
--group_by_modality_length True \
--mm_projector_type mlp2x_gelu \
--fp16 True \
--output_dir ./checkpoints/$file \
--num_train_epochs 1 \
--per_device_train_batch_size 8 \
--per_device_eval_batch_size 1 \
--gradient_accumulation_steps 1 \
--evaluation_strategy "no" \
--save_strategy "steps" \
--save_steps 500 \
--save_total_limit 1 \
--learning_rate 2e-5 \
--weight_decay 0. \
--warmup_ratio 0.03 \
--lr_scheduler_type "cosine" \
--logging_steps 1 \
--tf32 False \
--model_max_length 2048 \
--gradient_checkpointing True \
--dataloader_num_workers 0 \
--lazy_preprocess True
By default, we adopt the LLaVA-v1.5-665K dataset as our retriever source. A simplified sketch of the retrieval step is shown after the command below.
python camml/eval/run_camml.py --query $QUESTION --image-file $IMAGE_PATH
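Under the hood, the retriever embeds the query with ImageBind and searches the provided FAISS index over the 665K pool; the metadata JSON maps index rows back to data entries. The snippet below is a simplified illustration of that lookup (file names follow the data layout above; the exact metadata format and the example image path are assumptions):

```python
# Simplified illustration of CaMML retrieval at inference time.
# File names follow the data layout above; the metadata format is an assumption.
import json
import faiss
import torch
from imagebind import data
from imagebind.models import imagebind_model
from imagebind.models.imagebind_model import ModalityType

device = "cuda" if torch.cuda.is_available() else "cpu"
model = imagebind_model.imagebind_huge(pretrained=True).eval().to(device)

index = faiss.read_index("data/llava_665k_vision_flatIP.index")
with open("data/llava_665k_memory_metadata.json") as f:
    metadata = json.load(f)  # assumed to be aligned with the index rows

def retrieve(image_path, k=1):
    """Embed the query image and return the top-k in-context entries."""
    with torch.no_grad():
        vision = data.load_and_transform_vision_data([image_path], device)
        query = model({ModalityType.VISION: vision})[ModalityType.VISION]
        query = query.cpu().numpy().astype("float32")
    faiss.normalize_L2(query)  # assumes the stored embeddings were L2-normalized
    scores, ids = index.search(query, k)
    return [metadata[int(i)] for i in ids[0]]

print(retrieve("./examples/demo.jpg", k=1))  # hypothetical image path
```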
Model | Average Acc. (%) | IMG Acc. (%) | TXT Acc. (%) |
---|---|---|---|
CaMML-7B-sqa-FT | 91.32 | 89.24 | 93.21 |
CaMML-13B-sqa-FT | 92.03 | 89.94 | 93.84 |
Model | LLM | VQAv2 | GQA | VizWiz | SQA(I) | TextVQA | POPE | MME | MMBench | MMBench-CN | SEED | MM-Vet |
---|---|---|---|---|---|---|---|---|---|---|---|---|
CaMML-7B | Vicuna-7B | 79.4 | 62.7 | 51.2 | 67.9 | 58.0 | 86.4 | 1506.9 | 66.9 | 60.6 | 60.4 | 32.2 |
CaMML-13B | Vicuna-13B | 80.2 | 63.7 | 57.4 | 72.3 | 59.9 | 86.7 | 1588.7 | 70.2 | 63.6 | 62.3 | 36.4 |
Model | LLM | COCO Cap (CIDEr) | Flickr30K Cap (CIDEr) | OKVQA (Acc) | AOKVQA (MC-Acc) | RefCOCO (Acc) | RefCOCO+ (Acc) | RefCOCOg (Acc) |
---|---|---|---|---|---|---|---|---|
CaMML-7B | Vicuna-7B | 111.4 | 82.7 | 64.7 | 81.1 | 66.6 | 60.3 | 57.6 |
CaMML-13B | Vicuna-13B | 116.8 | 84.5 | 66.3 | 82.0 | 70.6 | 65.9 | 60.5 |
CaMML supports evaluation on up to 19 tasks; the corresponding scripts are in scripts/evaluation.
e.g., testing the ScienceQA-finetuned CaMML:
# CaMML-7B, 1 shot
bash scripts/evaluation/sqa_ft_camml.sh camml_7b_sqa_ft 1
# CaMML-13B, 3 shots
bash scripts/evaluation/sqa_ft_camml.sh camml_13b_sqa_ft 3
e.g., testing the instruction-tuned CaMML on VQAv2:
TASK="vqav2" # mme, mmvet, mmbench, etc.
bash scripts/evaluation/${TASK}_camml.sh camml_7b
If you find CaMML useful for your research and applications, please cite using this BibTeX:
@inproceedings{camml,
  title={CaMML: Context-Aware Multimodal Learner for Large Models},
  author={Yixin Chen and Shuai Zhang and Boran Han and Tong He and Bo Li},
  booktitle={The 62nd Annual Meeting of the Association for Computational Linguistics},
  year={2024},
  eprint={2401.03149},
  archivePrefix={arXiv},
  primaryClass={cs.CV}
}
We build this repo upon the following codebases: