
MIRAGE: Multi-Image Retrieval Augmented GEneralization


Welcome to the official repository for our paper: Visual Haystacks: A Vision-Centric Needle-In-A-Haystack Benchmark. Explore our project page here and the benchmark toolkits here!

Authors: Tsung-Han Wu, Giscard Biamby, Jerome Quenum, Ritwik Gupta, Joseph E. Gonzalez, Trevor Darrell, David M. Chan at UC Berkeley.

Visual Haystacks (VHs) Benchmark Dataset: πŸ€— tsunghanwu/visual_haystacks, πŸ™ Github Repo

Model Checkpoint: πŸ€— tsunghanwu/mirage-llama3.1-8.3B

πŸš€ Introduction

This paper addresses the challenge of answering questions across tens of thousands of images. Through extensive experiments on our Visual Haystacks (VHs) benchmark, we demonstrate that existing Large Multimodal Models (LMMs) struggle with inputs exceeding 100 images due to API limitations, context overflow, or hardware constraints, even on a node with four A100 GPUs. These models also suffer from visual distractions, cross-image reasoning difficulties, and positional biases. To overcome these challenges, we developed MIRAGE (8.3B), a pioneering, open-source visual-RAG baseline built on LMMs that can handle tens of thousands of images. In brief, MIRAGE integrates a compressor module that reduces image tokens by 18x, a dynamic query-aware retriever that filters out irrelevant images, and a custom-trained LMM capable of multi-image reasoning. MIRAGE sets a new standard for open-source performance on the Visual Haystacks (VHs) benchmark and delivers solid results on both single- and multi-image question answering tasks.
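
For intuition, the pipeline can be pictured as retrieve-then-answer: compress each image into a handful of tokens, score the compressed images against the question, and pass only the top-ranked images to the LMM. The toy sketch below illustrates this flow with random features and a cosine-similarity scorer; it is not the actual MIRAGE code, and all shapes, names, and the scoring function are placeholders.

# Toy retrieve-then-answer sketch (NOT the actual MIRAGE implementation).
# Shapes, the pooling "compressor", and the cosine scorer are placeholders.
import torch
import torch.nn.functional as F

def compress(image_feats: torch.Tensor, n_tokens: int = 32) -> torch.Tensor:
    # Stand-in for the Q-Former compressor: 576 patch tokens -> 32 tokens (18x fewer).
    B, N, D = image_feats.shape
    return image_feats.reshape(B, n_tokens, N // n_tokens, D).mean(dim=2)

def retrieve(query_emb: torch.Tensor, image_embs: torch.Tensor, k: int = 5) -> torch.Tensor:
    # Stand-in for the query-aware retriever: keep the k most relevant images.
    scores = F.cosine_similarity(image_embs, query_emb.unsqueeze(0), dim=-1)
    return scores.topk(k).indices

images = torch.randn(1000, 576, 64)          # 1,000 candidate images, 576 patch features each
query = torch.randn(64)                      # question embedding

per_image = compress(images).mean(dim=1)     # one 64-d embedding per image
keep = retrieve(query, per_image, k=5)       # indices of the most relevant images
answer_inputs = compress(images[keep])       # only these compressed images reach the LMM
print(answer_inputs.shape)                   # torch.Size([5, 32, 64])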

πŸ”§ Installation Guide

  1. Clone this repository and navigate to the mirage folder
git clone https://github.com/visual-haystacks/mirage.git
cd mirage
  2. Install Package
conda create -n mirage python=3.10 -y
conda activate mirage
pip install --upgrade pip  # enable PEP 660 support
pip install -e .
  3. Install additional packages for training
pip install -e ".[train]"
pip install flash-attn --no-build-isolation --no-cache-dir
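
A quick, optional import check can confirm the environment is usable (flash-attn is only needed if you installed the training extras):

# Optional sanity check after installation.
import torch
print("torch:", torch.__version__, "| CUDA available:", torch.cuda.is_available())
try:
    import flash_attn  # only required for training
    print("flash-attn:", flash_attn.__version__)
except ImportError:
    print("flash-attn not installed (fine for inference-only use)")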

βš™οΈ Quick Start / Demo

  • Model Checkpoint: πŸ€— tsunghanwu/mirage-llama3.1-8.3B
  • Demo code (single test case): CUDA_VISIBLE_DEVICES=X python3 demo.py --model-path [huggingface model id or local path] --image-folder [local image folder] --prompt [prompt path]
  • Here’s a sample output from MIRAGE using some photos I took on my iPhone. (Feel free to give it a star if you think my cat is adorable! 😺✨)
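
If you would rather pass a local path to --model-path, the checkpoint can be fetched ahead of time with huggingface_hub; the target directory below is arbitrary.

# Optional: download the released checkpoint to a local folder.
# demo.py also accepts the Hugging Face model id directly.
from huggingface_hub import snapshot_download

local_dir = snapshot_download(
    repo_id="tsunghanwu/mirage-llama3.1-8.3B",
    local_dir="checkpoints/mirage-llama3.1-8.3B",   # any local path works
)
print("Checkpoint downloaded to:", local_dir)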

πŸ“ˆ Evaluation

1. Data Preparation

After preparing the individual benchmarks (the comments in the tree below note where each directory comes from), the data structure of playground/data/eval should look like this:

playground/data/eval/
β”œβ”€β”€ gqa
β”‚   β”œβ”€β”€ answers
β”‚   β”œβ”€β”€ data                   # directory
β”‚   β”œβ”€β”€ llava_gqa_testdev_balanced.jsonl
β”œβ”€β”€ mmbench
β”‚   β”œβ”€β”€ answers
β”‚   β”œβ”€β”€ answers_upload
β”‚   └── mmbench_dev_20230712.tsv
β”œβ”€β”€ mmbench_cn
β”‚   β”œβ”€β”€ answers
β”‚   β”œβ”€β”€ answers_upload
β”‚   └── mmbench_dev_cn_20231003.tsv
β”œβ”€β”€ mm-vet
β”‚   β”œβ”€β”€ answers
β”‚   β”œβ”€β”€ images                  # directory
β”‚   β”œβ”€β”€ llava-mm-vet.jsonl
β”‚   └── results
β”œβ”€β”€ pope
β”‚   β”œβ”€β”€ answers
β”‚   β”œβ”€β”€ coco                    # directory (point to COCO2014)
β”‚   └── llava_pope_test.jsonl
β”œβ”€β”€ retvqa
β”‚   β”œβ”€β”€ answers
β”‚   β”œβ”€β”€ vg                     # directory (point to Visual Genome directory)
β”‚   └── retvqa_test_mirage.json
β”œβ”€β”€ textvqa
β”‚   β”œβ”€β”€ answers
β”‚   β”œβ”€β”€ llava_textvqa_val_v051_ocr.jsonl
β”‚   β”œβ”€β”€ TextVQA_0.5.1_val.json
β”‚   └── train_images           # directory (download from their website)
β”œβ”€β”€ visual_haystacks
β”‚   β”œβ”€β”€ coco             # directory (point to COCO2017)
β”‚   └── VHs_qa           # directory (download from VHs' huggingface)
β”œβ”€β”€ vizwiz
β”‚   β”œβ”€β”€ answers
β”‚   β”œβ”€β”€ answers_upload
β”‚   β”œβ”€β”€ llava_test.jsonl
β”‚   └── test                   # directory (download from their website)
└── vqav2
    β”œβ”€β”€ answers
    β”œβ”€β”€ answers_upload
    β”œβ”€β”€ llava_vqav2_mscoco_test2015.jsonl
    β”œβ”€β”€ llava_vqav2_mscoco_test-dev2015.jsonl
    └── test2015               # directory (download from their website)
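
A small helper script (not part of the repo) can verify the layout above before launching the evaluation scripts; the paths mirror the tree and can be trimmed to the benchmarks you actually run.

# Check that the expected evaluation files/directories exist (adjust as needed).
from pathlib import Path

ROOT = Path("playground/data/eval")
EXPECTED = [
    "gqa/llava_gqa_testdev_balanced.jsonl",
    "mmbench/mmbench_dev_20230712.tsv",
    "mmbench_cn/mmbench_dev_cn_20231003.tsv",
    "mm-vet/llava-mm-vet.jsonl",
    "pope/llava_pope_test.jsonl",
    "retvqa/retvqa_test_mirage.json",
    "textvqa/llava_textvqa_val_v051_ocr.jsonl",
    "visual_haystacks/VHs_qa",
    "vizwiz/llava_test.jsonl",
    "vqav2/llava_vqav2_mscoco_test2015.jsonl",
]

missing = [p for p in EXPECTED if not (ROOT / p).exists()]
print("All expected files found." if not missing else f"Missing: {missing}")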
    

2. Run Scripts

# Visual Haystacks
CUDA_VISIBLE_DEVICES=0 bash scripts/eval/{vhs_single,vhs_multi}.sh
# VQAv2, GQA, RetVQA
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 bash scripts/eval/{vqav2,gqa,retvqa}.sh
# Vizwiz, TextVQA, POPE, MMBench, MMBench-CN, MM-Vet
CUDA_VISIBLE_DEVICES=0 bash scripts/eval/{vizwiz,textvqa,pope,mmbench,mmbench_cn,mmvet}.sh

3. Results

Results on the Visual Haystacks (VHs) benchmark are reported in the paper and on the project page. Results on standard single-image benchmarks:

| Checkpoint | VQAv2 | GQA | VizWiz | TextVQA | POPE | MM-Bench | MM-Bench-CN | MM-Vet |
|---|---|---|---|---|---|---|---|---|
| πŸ€— tsunghanwu/mirage-llama3.1-8.3B | 76.56 | 59.13 | 40.52 | 56.24 | 85.40 | 69.24 | 66.92 | 33.4 |

πŸ”₯ Training

1. Data Preparation

  • Please download the curated training annotations from πŸ€— tsunghanwu/MIRAGE-training-set (a download snippet follows this list).
  • For stage-1 pre-training (training Q-Former and MLP projector), download datasets such as CC-12M, LAION-400M, and COCO.
  • For stages 2 and 3 pre-training, which involve training Q-Former/MLP projector with high-quality captions and training the downstream retriever module with augmented LLaVA data, download SAM, VG, COCO, TextVQA, OCR_VQA, and GQA to playground/data.
  • For instruction tuning, download SAM, VG, COCO, TextVQA, OCR_VQA, GQA, slidevqa, and webqa to playground/data.
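
The curated annotation set from the first bullet can be pulled with huggingface_hub as sketched below; the image datasets (COCO, VG, SAM, etc.) still have to be downloaded from their own sources and arranged as shown in the tree that follows.

# Download the curated MIRAGE training annotations (images are downloaded separately).
from huggingface_hub import snapshot_download

snapshot_download(
    repo_id="tsunghanwu/MIRAGE-training-set",
    repo_type="dataset",
    local_dir="playground/data/mirage_annotations",  # illustrative target path
)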

Below is the expected data structure for playground/data:

playground/data/
β”œβ”€β”€ coco
β”‚   β”œβ”€β”€ annotations
β”‚   β”œβ”€β”€ test2017
β”‚   β”œβ”€β”€ train2017
β”‚   └── val2017
β”œβ”€β”€ gqa
β”‚   └── images
β”œβ”€β”€ ocr_vqa
β”‚   └── images
β”œβ”€β”€ sam
β”‚   └── images 
β”œβ”€β”€ share_textvqa
β”‚   └── images
β”œβ”€β”€ slidevqa
β”‚   └── images (download from https://drive.google.com/file/d/11bsX48cPpzCfPBnYJgSesvT7rWc84LpH/view)
β”œβ”€β”€ textvqa
β”‚   └── train_images
β”œβ”€β”€ vg
β”‚   β”œβ”€β”€ VG_100K
β”‚   └── VG_100K_2
└── webqa
    └── webqa_images (download from https://drive.google.com/drive/folders/1ApfD-RzvJ79b-sLeBx1OaiPNUYauZdAZ and convert them to .jpg format)
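
For the WebQA images, the .jpg conversion mentioned above can be done with a short script like the sketch below, assuming the download has already been extracted into ordinary image files; adjust the source folder and formats to match what you actually get from the Drive link.

# Hedged sketch: re-encode extracted WebQA images as .jpg into playground/data/webqa/webqa_images.
from pathlib import Path
from PIL import Image

src = Path("playground/data/webqa/raw_images")        # hypothetical extraction folder
dst = Path("playground/data/webqa/webqa_images")
dst.mkdir(parents=True, exist_ok=True)

for img_path in src.iterdir():
    if img_path.suffix.lower() in {".png", ".jpeg", ".jpg", ".bmp", ".webp"}:
        Image.open(img_path).convert("RGB").save(dst / f"{img_path.stem}.jpg", quality=95)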
    

2. Pretraining/Finetuning

Run the following scripts with minor modifications as needed. Note: during finetuning, we found that freezing the downstream retriever and updating only the Q-Former/LLM leads to better performance with Llama-3.1-8B, whereas unfreezing the retriever yields better results with Vicuna-v1.5-7B (a generic freezing sketch follows the commands below).

# Stage 1-3 Pretraining
bash scripts/pretrain_stage{1,2,3}.sh
# Instruction Finetuning
bash scripts/finetune_qformer_lora.sh
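
The freezing choice mentioned in the note above boils down to the standard PyTorch requires_grad pattern; the sketch below is generic, and the attribute name retriever is illustrative, so check the actual model definition and training scripts for the real module and flag names.

# Generic PyTorch pattern for the freezing note above. `model.retriever` is an
# illustrative attribute name; see the model code / training scripts for the real one.
def set_retriever_trainable(model, trainable: bool) -> None:
    for param in model.retriever.parameters():
        param.requires_grad = trainable

# e.g., keep the retriever frozen when finetuning the Llama-3.1-8B variant:
# set_retriever_trainable(model, trainable=False)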

3. Weight Merging

Please merge the LoRA weights back into the base checkpoint using the following command:

python scripts/merge_lora_weights.py \
    --model-path checkpoints/mirage_qformer_ft \
    --model-base meta-llama/Meta-Llama-3.1-8B-Instruct \
    --save-model-path your_output_path

πŸ™ Acknowledgements

We are grateful for the foundational code provided by LLaVA and LLaVA-More. Utilizing their resources implies agreement to their respective licenses. Our project benefits greatly from these contributions, and we acknowledge their significant impact on our work.

🎯 Citation

If you use our work or the implementation in this repo, or find them helpful, please consider citing our paper:

@article{wu2024visual,
  title={Visual Haystacks: A Vision-Centric Needle-In-A-Haystack Benchmark},
  author={Wu, Tsung-Han and Biamby, Giscard and Quenum, Jerome and Gupta, Ritwik and Gonzalez, Joseph E and Darrell, Trevor and Chan, David M},
  journal={arXiv preprint arXiv:2407.13766},
  year={2024}
}