Got stuck when training with multiple GPUs using dist_train.sh
xiazhongyv opened this issue · 9 comments
All child processes get stuck when training with multiple GPUs using dist_train.sh,
with CUDA == 11.3, PyTorch == 1.10.
After some debugging, I found it was stuck at https://github.com/open-mmlab/OpenPCDet/blob/master/pcdet/utils/common_utils.py#L166-L171
I modified the code from

```python
dist.init_process_group(
    backend=backend,
    init_method='tcp://127.0.0.1:%d' % tcp_port,
    rank=local_rank,
    world_size=num_gpus
)
```

to

```python
dist.init_process_group(
    backend=backend
)
```

and it worked.
I'm curious why this is so; if anyone else is having the same problem, you can try the same change.
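In case it explains things: with no `init_method`, PyTorch falls back to the default `env://` initialization and reads the rendezvous info from the environment variables that `torch.distributed.launch` (the launcher behind dist_train.sh) exports for each worker. A minimal sketch of that default path (the print is just for illustration):

```python
# Minimal sketch of the env:// default (assumes the process was started by
# torch.distributed.launch / torchrun, which export MASTER_ADDR, MASTER_PORT,
# RANK and WORLD_SIZE for every worker):
import torch.distributed as dist

dist.init_process_group(backend='nccl')  # init_method defaults to 'env://'
print(dist.get_rank(), '/', dist.get_world_size())
```

So the hard-coded TCP address and the explicitly passed rank/world size are simply no longer needed when the launcher already provides them.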
Thanks. I had the same problem, and I solved it using your method.
@sshaoshuai After fixing the bug this way, tcp_port is not actually used any more.
Can you fix it in a cleaner way?
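One possible shape of such a fix (just a sketch of the idea, not the actual patch; `init_process_group_flexible` is a hypothetical name) would be to prefer the launcher's `env://` rendezvous when its variables are present and otherwise fall back to the explicit TCP address, so that `tcp_port` still matters for manual launches:

```python
# A sketch of one possible cleaner fix (hypothetical, not the actual patch):
import os
import torch.distributed as dist

def init_process_group_flexible(tcp_port, local_rank, num_gpus, backend='nccl'):
    if 'MASTER_ADDR' in os.environ and 'RANK' in os.environ:
        # Started by torch.distributed.launch / torchrun: use the default
        # env:// rendezvous, which these launchers configure for us.
        dist.init_process_group(backend=backend)
    else:
        # Manual launch: fall back to the explicit TCP address, so tcp_port
        # is still honoured.
        dist.init_process_group(
            backend=backend,
            init_method='tcp://127.0.0.1:%d' % tcp_port,
            rank=local_rank,
            world_size=num_gpus,
        )
```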
Thank you for the bug report. It has been fixed in #784.
Can you help double-check whether it works now?
@sshaoshuai Thanks for your work. It's ok now.
For single-machine multi-GPU training, I also had to change `local_rank` to `rank` in `torch.cuda.set_device()` to train properly. Otherwise it throws this error: `Duplicate GPU detected : rank 0 and rank 1 both on CUDA device a000`.
Modified:

```python
import torch
import torch.distributed as dist
import torch.multiprocessing as mp

def init_dist_pytorch(tcp_port, local_rank, backend='nccl'):
    if mp.get_start_method(allow_none=True) is None:
        mp.set_start_method('spawn')
    num_gpus = torch.cuda.device_count()
    # env:// initialization: rank/world size come from the launcher's env vars
    dist.init_process_group(
        backend=backend,
    )
    # use the global rank (not the passed-in local_rank) to pick the GPU
    rank = dist.get_rank()
    torch.cuda.set_device(rank % num_gpus)
    return num_gpus, rank
```
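To illustrate why this matters (a sketch, not OpenPCDet code): with the `env://` rendezvous, `dist.get_rank()` is unique per process, so each worker lands on its own GPU; if every process instead called `torch.cuda.set_device(0)`, NCCL would raise exactly that duplicate-GPU error.

```python
# Hypothetical standalone script, e.g. check_ranks.py, launched with
#   torchrun --nproc_per_node=2 check_ranks.py
import torch
import torch.distributed as dist

dist.init_process_group(backend='nccl')  # env:// rendezvous set up by torchrun
rank = dist.get_rank()                   # unique per process: 0 and 1 here
torch.cuda.set_device(rank % torch.cuda.device_count())
print(f'rank {rank} -> cuda:{torch.cuda.current_device()}')
dist.destroy_process_group()
```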
@sshaoshuai torch=1.9.0, cuda=11.1. Got stuck at `dist.init_process_group`, and the code is the latest. In other distributed training projects with the same `init_process_group` code, it ran successfully...
After I uncommented the lines mentioned in #784 (comment), it works.
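In case it helps anyone still debugging a similar hang, a small illustrative check of the rendezvous setup (not OpenPCDet code; the env-var printing and the timeout value are just examples):

```python
# Illustrative debugging aid, not OpenPCDet code:
import os
from datetime import timedelta
import torch.distributed as dist

# Print what the launcher actually exported for this worker; a missing or
# inconsistent value here is a common cause of a silent hang.
for key in ('MASTER_ADDR', 'MASTER_PORT', 'RANK', 'WORLD_SIZE', 'LOCAL_RANK'):
    print(key, '=', os.environ.get(key))

# A finite timeout turns an indefinite hang into an explicit error.
dist.init_process_group(backend='nccl', timeout=timedelta(minutes=2))
```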
I have submitted a new PR to solve this issue in #815.
Please pull the latest master branch if you still get blocked when training with dist_train.sh.
So what is the cause of this hang? I also encountered it and will try your fix...