Yuexiang Zhai*, Hao Bai†, Zipeng Lin†, Jiayi Pan†, Shengbang Tong†, Yifei Zhou†
Alane Suhr, Saining Xie, Yann LeCun, Yi Ma, Sergey Levine
*Project Lead, †Equal Contribution.
Paper | Project Page | Wandb Report | Data
- [Aug 7, 2024] We have uploaded a .zip file for the gym_cards environment. If you do not have the corresponding fonts, please consider downloading them.
- [June 7, 2024] We have prepared a template text wrapper for using our gym-cards environment in pure text. See examples here; a minimal usage sketch of the environment follows below.
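For reference, here is a minimal usage sketch for a GymCards environment. The environment id and the gymnasium-style API below are assumptions based on standard Gymnasium registration; check the gym_cards package for the exact ids and interface.

```python
# Minimal GymCards usage sketch. The environment id is an assumption;
# see gym_cards' registration code for the actual ids.
import gymnasium as gym
import gym_cards  # noqa: F401  (import registers the card environments)

env = gym.make("gym_cards/NumberLine-v0")
obs, info = env.reset(seed=0)  # obs is the rendered card image
for _ in range(10):
    action = env.action_space.sample()  # replace with a VLM-decoded action
    obs, reward, terminated, truncated, info = env.step(action)
    if terminated or truncated:
        obs, info = env.reset()
env.close()
```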
Our project contains three different codebases:
- A slightly modified version of LLaVA.
  - See our `git diff` from the LLaVA branch here.
- Our original GymCards environment.
- The RL4VLM codebases for both the GymCards and ALFWorld environments.
Our training pipeline consists of two steps:
1. Prepare an SFT checkpoint.
   - Check here to download the instruction-following data we prepared for the initial SFT.
   - We provide a template script (adapted from the official finetune.sh for the 1.6-mistral model) for running LLaVA SFT. Please remember to set `--data_path`, `--image_folder`, and `--output_dir` accordingly; a sanity-check sketch for these paths follows this list.
   - Please follow the instructions for LLaVA fine-tuning here.
   - Our experiments start from the llava-1.6-mistral-7b checkpoint. You are welcome to use any initial model, but we cannot guarantee similar performance.
2. Run RL using the SFT checkpoint.
   - For GymCards, please use these .sh run scripts.
     - Check here for conda environment installation.
     - [important] You may change `num_processes` in config_zero2.yaml to the number of GPUs you have. Please make sure the number of GPUs you set via `CUDA_VISIBLE_DEVICES` in the `.sh` file is >= the `num_processes` in config_zero2.yaml; a GPU-count check sketch follows this list.
     - [important] If you only want to play around with our codebase rather than reproduce our results, you may also skip the SFT in step 1 and directly use the llava-1.6 model `liuhaotian/llava-v1.6-mistral-7b` as your initial model in `--model-path`.
   - For ALFWorld, please use this run file.
     - Check here for conda environment installation.
     - The `num_processes` in config_zero2.yaml and the number of GPUs in the `run_alf.sh` file should follow the same rule as GymCards. We recommend using only 1 GPU to run ALFWorld, because the time for on-policy data collection varies widely across GPUs, which may cause an NCCL timeout when synchronizing threads across multiple GPUs.
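As referenced in step 1, here is a minimal sketch for sanity-checking the SFT paths before launching the fine-tuning script. It assumes the instruction-following data uses LLaVA's conversation-style JSON (a list of records, each with an optional `image` field); the paths are placeholders you should replace with your own.

```python
# Sanity-check sketch for the step-1 SFT data. Assumes LLaVA's
# conversation-style JSON: a list of records with an optional "image"
# field naming a file under --image_folder. Adjust if the data differs.
import json
import os

DATA_PATH = "/path/to/sft_data.json"  # value passed to --data_path
IMAGE_FOLDER = "/path/to/images"      # value passed to --image_folder

with open(DATA_PATH) as f:
    samples = json.load(f)
print(f"Loaded {len(samples)} SFT samples.")

missing = [
    s["image"]
    for s in samples
    if "image" in s and not os.path.isfile(os.path.join(IMAGE_FOLDER, s["image"]))
]
print(f"{len(missing)} referenced images are missing from --image_folder.")
```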
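And here is the GPU-count check referenced in step 2; it applies to both GymCards and ALFWorld. It assumes `num_processes` is a top-level key in config_zero2.yaml (standard for accelerate config files) and requires PyYAML.

```python
# Sketch: verify that CUDA_VISIBLE_DEVICES exposes at least num_processes
# GPUs, as config_zero2.yaml requires. Assumes num_processes is a top-level
# key (standard for accelerate config files).
import os
import yaml

with open("config_zero2.yaml") as f:
    num_processes = yaml.safe_load(f)["num_processes"]

visible = os.environ.get("CUDA_VISIBLE_DEVICES", "")
num_gpus = len([d for d in visible.split(",") if d.strip()])

assert num_gpus >= num_processes, (
    f"CUDA_VISIBLE_DEVICES exposes {num_gpus} GPU(s), but config_zero2.yaml "
    f"sets num_processes={num_processes}."
)
print("GPU configuration looks consistent.")
```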
This project is released under the MIT License.
If you find our codebases useful, please consider citing our paper:
```bibtex
@article{zhai2024fine,
  title={Fine-Tuning Large Vision-Language Models as Decision-Making Agents via Reinforcement Learning},
  author={Zhai, Yuexiang and Bai, Hao and Lin, Zipeng and Pan, Jiayi and Tong, Shengbang and Zhou, Yifei and Suhr, Alane and Xie, Saining and LeCun, Yann and Ma, Yi and Levine, Sergey},
  journal={arXiv preprint arXiv:2405.10292},
  year={2024}
}
```
Our codebase adopts LLaVA as the backbone model and applies PPO from this repo for RL fine-tuning. In principle, one may adapt our pipeline to different VLM / MLLM backbones and different RL algorithms.
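To make that concrete, here is a hypothetical sketch of the interface such an adaptation implies; all names below are illustrative placeholders, not the repo's actual API.

```python
# Hypothetical adaptation interface: any VLM backbone that maps
# (image, prompt) -> text action, plus any RL algorithm that updates it
# from rewards, can in principle be plugged into the same on-policy loop.
from dataclasses import dataclass
from typing import Any, List, Protocol


@dataclass
class Transition:
    image: Any
    prompt: str
    action: str
    reward: float


class VLMPolicy(Protocol):
    def act(self, image: Any, prompt: str) -> str: ...  # decode a text action
    def update(self, trajectory: List[Transition]) -> None: ...  # RL update (e.g., PPO)


def train(env: Any, policy: VLMPolicy, episodes: int = 10) -> None:
    """Generic on-policy loop: collect a trajectory, then update the policy."""
    prompt = "Given the current observation, output the next action."
    for _ in range(episodes):
        obs, _ = env.reset()
        done, trajectory = False, []
        while not done:
            action = policy.act(obs, prompt)
            next_obs, reward, terminated, truncated, _ = env.step(action)
            trajectory.append(Transition(obs, prompt, action, reward))
            obs, done = next_obs, terminated or truncated
        policy.update(trajectory)
```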