Got stuck when debugging a multi-process program in PyTorch
zhouzaida opened this issue · 0 comments
Environment data
- debugpy version: 1.8.1 (run import debugpy; print(debugpy.__version__) if uncertain)
- OS and version: CentOS
- Python version (& distribution if applicable, e.g. Anaconda): 3.10.13
- Using VS Code or Visual Studio: VS Code 1.90.2
Actual behavior
When I start debugging and press F10 (Step Over), the program hangs at the DDP line (ddp_model = DDP(model, device_ids=[device_id])).
- launch.json
{
    "version": "0.2.0",
    "configurations": [
        {
            "name": "Attach Test",
            "type": "debugpy",
            "request": "attach",
            "listen": {
                "port": 5678,
            },
            "justMyCode": false,
        },
    ],
}
- example.py
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.optim as optim
from torch.nn.parallel import DistributedDataParallel as DDP


class ToyModel(nn.Module):
    def __init__(self):
        super(ToyModel, self).__init__()
        self.net1 = nn.Linear(10, 10)
        self.relu = nn.ReLU()
        self.net2 = nn.Linear(10, 5)

    def forward(self, x):
        return self.net2(self.relu(self.net1(x)))


def demo_basic():
    dist.init_process_group("nccl")
    rank = dist.get_rank()
    print(f"Start running basic DDP example on rank {rank}.")

    # create model and move it to GPU with id rank
    device_id = rank % torch.cuda.device_count()
    model = ToyModel().to(device_id)

    # every rank connects to the VS Code debug server listening on port 5678
    import debugpy
    debugpy.connect(5678)
    debugpy.wait_for_client()
    debugpy.breakpoint()

    ddp_model = DDP(model, device_ids=[device_id])  # stepping over this line hangs
    loss_fn = nn.MSELoss()
    optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)

    optimizer.zero_grad()
    outputs = ddp_model(torch.randn(20, 10))
    labels = torch.randn(20, 5).to(device_id)
    loss_fn(outputs, labels).backward()
    optimizer.step()
    dist.destroy_process_group()


if __name__ == "__main__":
    demo_basic()
- launch command:
torchrun --nproc_per_node=2 example.py
Note: The configuration above used to work, but it recently stopped working. I can still debug normally by configuring 'connect' in launch.json and calling 'listen' in the code, but then every process needs its own port to avoid conflicts, which is not very convenient.
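For reference, the 'listen'-in-code workaround mentioned in the note could look roughly like the sketch below. It assumes each worker derives its port from the LOCAL_RANK environment variable that torchrun sets; the helper names debug_port and listen_for_debugger and the base port 5678 are illustrative choices, not part of debugpy or of the original script. It only shows how the port conflict can be avoided and is not a fix for the hang itself.

```python
import os

BASE_PORT = 5678  # assumed base port; same number as in the report


def debug_port(base_port: int = BASE_PORT) -> int:
    """Return a per-process port: base_port + LOCAL_RANK (0 if unset)."""
    # torchrun sets LOCAL_RANK for every worker process on a node.
    return base_port + int(os.environ.get("LOCAL_RANK", "0"))


def listen_for_debugger(base_port: int = BASE_PORT) -> int:
    """Start a debugpy server on this process's port and wait for VS Code."""
    import debugpy  # imported lazily so debug_port() works without debugpy

    port = debug_port(base_port)
    debugpy.listen(port)       # binds to 127.0.0.1 by default
    debugpy.wait_for_client()  # blocks until a client attaches
    return port
```

With --nproc_per_node=2 this gives ports 5678 and 5679, so launch.json would need one 'connect' configuration per port, which is the inconvenience described above.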
Expected behavior
Stepping over the DDP line should complete without the program getting stuck.