`GRPOTrainer` with `top_entropy_quantile < 1` causes hang with multi-GPU training
Reproduction
```python
# grpo.py
from datasets import load_dataset
from trl import GRPOConfig, GRPOTrainer

dataset = load_dataset("trl-lib/tldr", split="train")

# Define the reward function, which rewards completions that are close to 20 characters
def reward_len(completions, **kwargs):
    return [-abs(20 - len(completion)) for completion in completions]

training_args = GRPOConfig(output_dir="Qwen2-0.5B-GRPO", top_entropy_quantile=0.2)
trainer = GRPOTrainer(
    model="Qwen/Qwen2-0.5B-Instruct",
    reward_funcs=reward_len,
    args=training_args,
    train_dataset=dataset,
)
trainer.train()
```
with the command:
```
accelerate launch --num_machines 1 --machine_rank 0 --num_processes 2 --main_process_ip=cccxc577 --main_process_port=14624 --mixed_precision=bf16 --multi_gpu grpo.py
```
the script hangs and never starts training.
From what I understand, this is caused by the `gather` call in `get_high_entropy_mask`: `non_pad_entropies` has a different length on each process, which can make `gather` hang, similar to what is described here: https://discuss.pytorch.org/t/dist-all-gather-stuck/156037.
To confirm that this is indeed the problem, I changed the line to `all_non_pad_entropies = self.accelerator.gather(non_pad_entropies[:50])` (which is of course not a viable fix); with this change training runs.
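For illustration, here is a minimal, self-contained sketch (not the TRL code; the script name and tensor sizes are made up) of the ragged-gather problem and one common workaround: padding every rank's tensor to the same length with Accelerate's `pad_across_processes` before gathering.
```python
# sketch.py -- hypothetical standalone example, not the TRL implementation
# Run with: accelerate launch --num_processes 2 sketch.py
import torch
from accelerate import Accelerator

accelerator = Accelerator()

# Each rank ends up with a different number of "non-padding" entropies,
# e.g. rank 0 keeps 3 values and rank 1 keeps 5.
local_len = 3 + 2 * accelerator.process_index
non_pad_entropies = torch.rand(local_len, device=accelerator.device)

# The problematic pattern: gather() on tensors whose first dimension
# differs across ranks can block forever.
# all_entropies = accelerator.gather(non_pad_entropies)  # may hang

# One workaround: pad every rank's tensor to the same length, gather,
# then drop the padding values again.
padded = accelerator.pad_across_processes(non_pad_entropies, dim=0, pad_index=float("nan"))
gathered = accelerator.gather(padded)
all_entropies = gathered[~torch.isnan(gathered)]

accelerator.print(f"gathered {all_entropies.numel()} entropy values across ranks")
```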
System Info
- Platform: Linux-5.14.0-503.21.1.el9_5.x86_64-x86_64-with-glibc2.34
- Python version: 3.11.12
- TRL version: 0.22.0.dev0+48d7ecc
- PyTorch version: 2.7.1
- accelerator(s): NVIDIA A100-SXM4-80GB, NVIDIA A100-SXM4-80GB
- Transformers version: 4.52.4
- Accelerate version: 1.7.0
- Accelerate config: not found
- Datasets version: 3.6.0
- HF Hub version: 0.32.4
- bitsandbytes version: not installed
- DeepSpeed version: not installed
- Diffusers version: not installed
- Liger-Kernel version: not installed
- LLM-Blender version: not installed
- OpenAI version: not installed
- PEFT version: not installed
- vLLM version: not installed
Checklist
- I have checked that my issue isn't already filed (see open issues)
- I have included my system information
- Any code provided is minimal, complete, and reproducible (more on MREs)
- Any code provided is properly formatted in code blocks, (no screenshot, more on code blocks)
- Any traceback provided is complete
this seems like a virus/bot
Thanks for reporting this! It's an interesting bug, I'll take a look.
In the current implementation there’s a potential hang:
```python
non_pad_entropies = entropies[mask.bool()].float()
if non_pad_entropies.numel() == 0:
    return torch.zeros_like(entropies, dtype=torch.bool)
```
If any rank hits `non_pad_entropies.numel() == 0` and returns early while other ranks proceed to `self.accelerator.gather(non_pad_entropies)`, the collective call will block forever.
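To make that concrete, here is a hedged sketch of one way the early return could be made safe. The function name, signature, and thresholding details are assumptions for illustration, not the actual TRL fix: the key idea is that every rank executes the same collectives and the "nothing to gather" decision is made globally.
```python
# Hypothetical sketch, not the actual TRL implementation.
import torch

def get_high_entropy_mask_sketch(accelerator, entropies, mask, threshold_quantile):
    non_pad_entropies = entropies[mask.bool()].float()

    # Every rank always runs this gather, even when its local tensor is empty,
    # so no rank can return early while the others are blocked in a collective.
    local_count = torch.tensor([non_pad_entropies.numel()], device=entropies.device)
    total_count = accelerator.gather(local_count).sum()
    if total_count == 0:
        return torch.zeros_like(entropies, dtype=torch.bool)

    # Pad to a common length before gathering so ragged per-rank tensors
    # cannot stall the collective, then strip the padding again.
    padded = accelerator.pad_across_processes(non_pad_entropies, dim=0, pad_index=float("nan"))
    gathered = accelerator.gather(padded)
    all_non_pad_entropies = gathered[~torch.isnan(gathered)]

    # Keep tokens whose entropy is above the chosen quantile (the real
    # thresholding logic in TRL may differ).
    entropy_threshold = torch.quantile(all_non_pad_entropies, threshold_quantile)
    return (entropies >= entropy_threshold) & mask.bool()
```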
Good point, are you willing to open a PR to fix it?