tatsu-lab/alpaca_farm

Running PPO with fewer GPUs

shunzh opened this issue · 3 comments

First, thanks for the implementation! The README says "PPO training currently requires at least 8 80GB GPUs." I was wondering if it's possible to run the PPO algorithm with 4 A100 80GB GPUs. I have tried enabling gradient checkpointing for Llama, which doesn't seem to help. I'm also trying to use peft with deepspeed.
I would just like to check if it's possible to run PPO with fewer GPUs at all, and if possible, what changes I should make.
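For reference, this is roughly how I enabled gradient checkpointing (a sketch using the standard transformers API; the checkpoint path is a placeholder):

```python
import torch
from transformers import AutoModelForCausalLM

# Load the policy in bf16 and turn on activation checkpointing.
model = AutoModelForCausalLM.from_pretrained(
    "path/to/llama-7b",            # placeholder checkpoint path
    torch_dtype=torch.bfloat16,
)
model.gradient_checkpointing_enable()  # trade extra compute for activation memory
model.config.use_cache = False         # KV cache is incompatible with checkpointing
```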

Thanks for your interest! We haven't tested the current implementation on fewer than 8 GPUs for this particular instruction-following setup. But there are several things you could try:

  • low precision training, e.g., with bnb
  • quantized training
  • LoRA or other adapters

There are also recent approaches that combine several of the paradigms above, such as QLoRA.
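For instance, a QLoRA-style setup might look roughly like the following (a sketch assuming the standard transformers/peft/bitsandbytes APIs; we haven't tested this with the PPO trainer in this repo, and the checkpoint path and LoRA hyperparameters are placeholders):

```python
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
from peft import LoraConfig, get_peft_model

# Load the base model in 4-bit (bitsandbytes) to cut weight memory.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)
model = AutoModelForCausalLM.from_pretrained(
    "path/to/llama-7b",             # placeholder checkpoint path
    quantization_config=bnb_config,
    device_map="auto",
)

# Train only small LoRA adapters on top of the frozen quantized weights.
lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    target_modules=["q_proj", "v_proj"],  # typical choice for Llama attention
    task_type="CAUSAL_LM",
)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()
```

Keep in mind that PPO keeps several models in memory (policy, reference, reward, and value), so each would likely need similar treatment to fit on fewer GPUs.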

We're working on some code that would make this training much easier, but we don't have a timeline for that yet.

May I know if there are any updates? Thanks a lot.

rtaori commented

Hi,
We don't have an updated timeline for this, sorry about that - it's not currently on our near-term roadmap. We'd be happy to accept any PRs that implement this, but until then, I'm closing this issue as it's getting stale.