RL4VLM

Official Repo for Fine-Tuning Large Vision-Language Models as Decision-Making Agents via Reinforcement Learning


Fine-Tuning Large Vision-Language Models as Decision-Making Agents via Reinforcement Learning

Yuexiang Zhai*, Hao Bai†, Zipeng Lin†, Jiayi Pan†, Shengbang Tong†, Yifei Zhou

Alane Suhr, Saining Xie, Yann LeCun, Yi Ma, Sergey Levine

*Project Lead, †Equal Contribution.


Release:

  • [Aug 7, 2024] We have uploaded a .zip file for the gym_cards environment. If you do not have the corresponding fonts, please consider downloading them.
  • [June 7, 2024] We have prepared a template text wrapper for using our gym-cards environment in a text-only setting. See examples here.

Contents:

  1. Codebase Structure
  2. Getting Started
  3. License
  4. Citation
  5. Acknowledgement

Codebase Structure

Our project contains three different codebases:

  1. A slightly modified version of LLaVA.
    • See our git diff from the LLaVA branch here.
  2. Our original GymCards environment.
  3. The RL4VLM codebases for both the GymCards and ALFWorld environments.
    • Check here for instructions on running our algorithm on GymCards.
    • Check here for instructions on running our algorithm on ALFWorld.
    • We provide separate conda environments for GymCards and ALFWorld because of package conflicts between them.

Getting Started

Our training pipeline consists of two steps:

  1. Prepare an SFT checkpoint.
    • Check here to download the instruction-following data we prepared for running the initial SFT.
    • We provide a template script (adapted from the official finetune.sh for the 1.6-mistral model) for running LLaVA SFT. Please remember to set --data_path, --image_folder, and --output_dir accordingly.
    • Please follow the instructions for LLaVA fine-tuning here.
    • Our experiments start from the llava-1.6-mistral-7b checkpoint. You are welcome to use other initial models, but we cannot guarantee similar performance.
  2. Run RL from the SFT checkpoint.
    • For GymCards, please use these .sh run scripts.

      • Check here for conda environment installation.
      • [important] You may set num_processes in config_zero2.yaml to the number of GPUs you have. Make sure the number of GPUs listed in CUDA_VISIBLE_DEVICES in the .sh file is greater than or equal to num_processes in config_zero2.yaml.
      • [important] If you only want to experiment with our codebase rather than reproduce our results, you may skip the SFT in step 1 and directly use the LLaVA-1.6 model liuhaotian/llava-v1.6-mistral-7b as your initial model in --model-path.
    • For ALFWorld, please use this run file.

      • Check here for conda environment installation.
      • The num_processes in config_zero2.yaml and the number of GPUs in the run_alf.sh file should follow the same rule as for GymCards. We recommend using only 1 GPU for ALFWorld: the time needed for on-policy data collection varies widely across GPUs, which may cause NCCL timeouts when the processes on multiple GPUs synchronize.
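The GPU rule above (CUDA_VISIBLE_DEVICES must list at least num_processes GPUs) can be checked before launching a run. The following is a minimal sketch, not part of this codebase. It assumes config_zero2.yaml is an accelerate-style YAML file containing an unindented `num_processes: <int>` line, and it reads that value with a regular expression rather than a full YAML parser. The function names and the example config fragment are hypothetical.

```python
import re
from typing import Optional


def read_num_processes(config_text: str) -> int:
    """Extract the top-level `num_processes` value from an accelerate-style YAML config.

    Assumption: the key appears unindented as `num_processes: <int>`; a regex is
    used instead of a YAML parser to keep this sketch dependency-free.
    """
    match = re.search(r"^num_processes:[ \t]*(\d+)[ \t]*$", config_text, re.MULTILINE)
    if match is None:
        raise ValueError("no top-level num_processes entry found")
    return int(match.group(1))


def count_visible_gpus(cuda_visible_devices: Optional[str]) -> Optional[int]:
    """Count the comma-separated device IDs in a CUDA_VISIBLE_DEVICES string.

    Returns None when the variable is unset, because the GPU count then depends
    on the machine and cannot be read from the string alone.
    """
    if cuda_visible_devices is None:
        return None
    return len([d for d in cuda_visible_devices.split(",") if d.strip()])


def check_gpu_budget(config_text: str, cuda_visible_devices: Optional[str]) -> None:
    """Raise if fewer GPUs are listed than the number of processes to launch."""
    num_processes = read_num_processes(config_text)
    num_gpus = count_visible_gpus(cuda_visible_devices)
    if num_gpus is not None and num_gpus < num_processes:
        raise RuntimeError(
            f"CUDA_VISIBLE_DEVICES lists {num_gpus} GPU(s), "
            f"but num_processes is {num_processes}"
        )


if __name__ == "__main__":
    # Hypothetical config fragment requesting 2 processes.
    example_config = "distributed_type: DEEPSPEED\nnum_processes: 2\n"
    check_gpu_budget(example_config, "0,1")  # passes: 2 GPUs >= 2 processes
    print("GPU budget OK")
```

For a real run, one would pass the text of config_zero2.yaml together with os.environ.get("CUDA_VISIBLE_DEVICES"). When the variable is unset, the sketch skips the check rather than guessing how many GPUs the machine exposes.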

License

This project is under the MIT License.

Citation

If you find our codebases useful, please consider citing our paper:

@article{zhai2024fine,
  title={Fine-Tuning Large Vision-Language Models as Decision-Making Agents via Reinforcement Learning},
  author={Zhai, Yuexiang and Bai, Hao and Lin, Zipeng and Pan, Jiayi and Tong, Shengbang and Zhou, Yifei and Suhr, Alane and Xie, Saining and LeCun, Yann and Ma, Yi and Levine, Sergey},
  journal={arXiv preprint arXiv:2405.10292},
  year={2024}
}

Acknowledgement

Our codebases adopt LLaVA as a backbone model and apply PPO from this repo for RL fine-tuning. In principle, one may try to adapt our pipeline to different VLM / MLLM backbones and different RL algorithms.
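For readers unfamiliar with PPO, the following is a minimal sketch of its standard clipped surrogate loss, written in plain Python over per-sample log-probabilities and advantages. It is not the implementation from the PPO repository used here, which also involves other components such as a value loss. How per-action log-probabilities are obtained from the VLM's generated text is specific to this codebase and is not shown. The function name ppo_clip_loss and the default clip_eps=0.2 are illustrative choices.

```python
import math
from typing import Sequence


def ppo_clip_loss(
    new_log_probs: Sequence[float],
    old_log_probs: Sequence[float],
    advantages: Sequence[float],
    clip_eps: float = 0.2,
) -> float:
    """Negated PPO clipped surrogate objective, averaged over samples (lower is better)."""
    if not (len(new_log_probs) == len(old_log_probs) == len(advantages)):
        raise ValueError("inputs must have equal length")
    if not advantages:
        raise ValueError("inputs must be non-empty")
    total = 0.0
    for new_lp, old_lp, adv in zip(new_log_probs, old_log_probs, advantages):
        ratio = math.exp(new_lp - old_lp)  # pi_new(a|s) / pi_old(a|s)
        clipped_ratio = min(max(ratio, 1.0 - clip_eps), 1.0 + clip_eps)
        # Pessimistic bound: keep the smaller of the unclipped and clipped terms.
        total += min(ratio * adv, clipped_ratio * adv)
    return -total / len(advantages)


if __name__ == "__main__":
    # Toy batch: the second sample's ratio (about 2.0) exceeds 1 + clip_eps.
    loss = ppo_clip_loss([0.0, math.log(2.0)], [0.0, 0.0], [1.0, 1.0])
    print(f"loss = {loss:.3f}")  # prints "loss = -1.100"
```

Taking the minimum of the clipped and unclipped terms means the objective gains nothing from pushing the probability ratio beyond [1 - clip_eps, 1 + clip_eps] in the direction the advantage favors. This limits how far each update moves the policy away from the one that collected the data.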