rusty1s/pytorch_cluster

Issue running with multi gpu

PabloVD opened this issue · 11 comments

Hi,
I'm trying to train a GNN using two GPUs. I need to compute a radius graph from the positions pos, as follows:

edge_index = radius_graph(pos, r=self.linkradius, batch=batch, loop=self.loop)

That works well using only one GPU. However, when I try to use two GPUs (via model = torch.nn.DataParallel(model)), I get this error message:

edge_index = radius_graph(pos, r=self.linkradius, batch=batch, loop=self.loop)
RuntimeError: The following operation failed in the TorchScript interpreter.
Traceback of TorchScript (most recent call last):
  File "/homes/pvillanueva/anaconda3/lib/python3.9/site-packages/torch_cluster/radius.py", line 116, in radius_graph

    assert flow in ['source_to_target', 'target_to_source']
    edge_index = radius(x, x, r, batch, batch,
                 ~~~~~~ <--- HERE
                        max_num_neighbors if loop else max_num_neighbors + 1,
                        num_workers)
  File "/homes/pvillanueva/anaconda3/lib/python3.9/site-packages/torch_cluster/radius.py", line 56, in radius
    if batch_x is not None:
        assert x.size(0) == batch_x.numel()
        batch_size = int(batch_x.max()) + 1
                     ~~~ <--- HERE
    if batch_y is not None:
        assert y.size(0) == batch_y.numel()
RuntimeError: CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call,so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.


srun: error: isl-gpu21: task 0: Exited with exit code 1

I checked that both pos and batch are on the same GPU device.
What might be the source of the problem?
Thanks in advance!

Mh, strange. What is the stack trace when running with CUDA_LAUNCH_BLOCKING=1 as suggested?

I tried that, and now I get an error earlier, before calling radius_graph, on these tensors:

orientbatch = orient2d[batch]
RuntimeError: CUDA error: device-side assert triggered

I checked that both orient2d and batch are on the same device.

Before the error message, I get many warnings like these:

/opt/conda/conda-bld/pytorch_1646756402876/work/aten/src/ATen/native/cuda/IndexKernel.cu:91: operator(): block: [0,0,0], thread: [119,0,0] Assertion `index >= -si>
/opt/conda/conda-bld/pytorch_1646756402876/work/aten/src/ATen/native/cuda/IndexKernel.cu:91: operator(): block: [0,0,0], thread: [120,0,0] Assertion `index >= -si>
/opt/conda/conda-bld/pytorch_1646756402876/work/aten/src/ATen/native/cuda/IndexKernel.cu:91: operator(): block: [0,0,0], thread: [121,0,0] Assertion `index >= -si>
/opt/conda/conda-bld/pytorch_1646756402876/work/aten/src/ATen/native/cuda/IndexKernel.cu:91: operator(): block: [0,0,0], thread: [122,0,0] Assertion `index >= -si>
/opt/conda/conda-bld/pytorch_1646756402876/work/aten/src/ATen/native/cuda/IndexKernel.cu:91: operator(): block: [0,0,0], thread: [123,0,0] Assertion `index >= -si>
/opt/conda/conda-bld/pytorch_1646756402876/work/aten/src/ATen/native/cuda/IndexKernel.cu:91: operator(): block: [0,0,0], thread: [124,0,0] Assertion `index >= -si>
/opt/conda/conda-bld/pytorch_1646756402876/work/aten/src/ATen/native/cuda/IndexKernel.cu:91: operator(): block: [0,0,0], thread: [125,0,0] Assertion `index >= -si>
/opt/conda/conda-bld/pytorch_1646756402876/work/aten/src/ATen/native/cuda/IndexKernel.cu:91: operator(): block: [0,0,0], thread: [126,0,0] Assertion `index >= -si>
/opt/conda/conda-bld/pytorch_1646756402876/work/aten/src/ATen/native/cuda/IndexKernel.cu:91: operator(): block: [0,0,0], thread: [127,0,0] Assertion `index >= -si>

Can you confirm that batch.max() is smaller than orient2d.size(0)?
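Something like this quick check (a minimal sketch, reusing the variable names from your snippets) would tell you, once per replica:

# Minimal diagnostic sketch: place inside the forward pass so it runs on each
# GPU replica. `batch` and `orient2d` are the tensors from your snippet.
print(batch.device, int(batch.min()), int(batch.max()), orient2d.size(0))
assert int(batch.max()) < orient2d.size(0), "batch indexes past orient2d"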

I printed batch.min(), batch.max() and orient2d.size(0), getting for the first GPU:
tensor(0, device='cuda:0') tensor(39, device='cuda:0') 40
and for the second one:
tensor(39, device='cuda:1') tensor(79, device='cuda:1') 40
so on the second GPU, batch.max() is larger than orient2d.size(0).

I use a batch size of 80, and it splits it in half across the GPUs, placing the first 40 graphs on one and the next 40 on the other. The point is that, on the second GPU, the batch indices start at 39 and end at 79, which seems to be causing the problem.
Do you know if there is a simple way to fix that?

How do you generate batch? In the end, the fix could be as simple as ensuring that your batch starts from zero: batch = batch - batch.min().
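For instance, something along these lines at the top of the forward pass (a rough sketch; the signature is only a guess based on your snippets):

from torch_cluster import radius_graph

def forward(self, pos, orient2d, batch):
    # Re-zero the batch vector so graph indices start at 0 on every replica.
    batch = batch - batch.min()
    orientbatch = orient2d[batch]
    edge_index = radius_graph(pos, r=self.linkradius, batch=batch, loop=self.loop)
    ...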

I create the batch tensor using torch_geometric.loader.DataLoader, in the usual way:
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=12, pin_memory=True)

I tried doing batch = batch - batch.min() at the beginning of the forward call of the GNN, and I get the same error,
although now the output of batch.min(), batch.max(), orient2d.size(0) is:
tensor(0, device='cuda:1') tensor(80, device='cuda:1') 80

This still looks to me like an issue with how batch is encoded. DataLoader should never create a batch vector that does not start at zero, so I am a bit confused. Can you shed some light on how you use the outputs of DataLoader?

This is what I get when printing batch.min(), batch.max() before passing batch as input to the model, and then within the forward call of the model:
Batch before entering in model: tensor(0, device='cuda:0') tensor(79, device='cuda:0')
Batch within model: model tensor(0, device='cuda:0') tensor(40, device='cuda:1') tensor(40, device='cuda:0') tensor(79, device='cuda:1')

Note that in the second case it prints twice: first the two minima, one per GPU, and then the two maxima.

So a batch vector that starts above zero seems to be caused by DataParallel, which splits the batch in two, one half for each GPU.
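Just to illustrate what I think is happening (a toy sketch, not my actual code):

import torch

# torch.nn.DataParallel scatters each input tensor along dim 0, so the
# node-level batch vector is simply cut in half without re-indexing: the
# second replica keeps graph indices ~39..79 even though its per-graph
# tensors (like orient2d) are also halved to ~40 rows.
num_graphs = 80
nodes_per_graph = torch.randint(1, 5, (num_graphs,))
batch = torch.repeat_interleave(torch.arange(num_graphs), nodes_per_graph)
first_half, second_half = torch.chunk(batch, 2, dim=0)
print(first_half.min().item(), first_half.max().item())    # 0 .. ~39
print(second_half.min().item(), second_half.max().item())  # ~39 .. 79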

Here is how I pass the model to DataParallel:

model = GNN(...)
model = torch.nn.DataParallel(model)
model = model.to(device)

I have no experience with multiple GPUs, so maybe I'm doing something wrong there?

I think it is safe to close this issue here. We can re-open it on the PyG side if needed, but here is my recommendation:

  • You need to use DataListLoader in conjunction with DataParallel (a rough sketch follows below). Here is an example.
  • It is generally not recommended to use DataParallel. The better approach is to use DistributedDataParallel, which comes with a much nicer interface. Here is an example.

Both paths should resolve your observed issue.
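Roughly, the first option looks like this (a minimal sketch; batch_size and the model arguments are placeholders):

from torch_geometric.loader import DataListLoader
from torch_geometric.nn import DataParallel  # PyG's DataParallel, not torch.nn

# DataListLoader yields plain Python lists of Data objects instead of one big
# merged Batch; PyG's DataParallel then builds a correctly indexed Batch on
# each GPU, so the batch vector starts at zero on every replica.
train_loader = DataListLoader(train_dataset, batch_size=80, shuffle=True)

model = GNN(...)             # your model, unchanged
model = DataParallel(model)  # torch_geometric.nn.DataParallel
model = model.to('cuda:0')

for data_list in train_loader:
    # Note: with PyG's DataParallel, the wrapped module's forward receives a
    # single Batch object per replica, from which you can take data.pos and
    # data.batch.
    out = model(data_list)

For the DistributedDataParallel route you instead run one process per GPU, each with its own DataLoader, so every process builds its own batch vector starting at zero and radius_graph works just as in the single-GPU case.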

OK, I'll try both approaches; let's see if either of them works. Many thanks!

This issue had no activity for 6 months. It will be closed in 2 weeks unless there is some new activity. Is this issue already resolved?