PythonOT/POT

DistributedDataParallel

dgm2 opened this issue · 2 comments

dgm2 commented

It seems that a PyTorch DistributedDataParallel (DDP) setup is not supported in POT, specifically for the emd2 computation.
Are there any workaround ideas for making this work?
Or is there any example of a multi-GPU setup for OT?

Ideally, I would like to make OT work with this torch setup:
https://github.com/pytorch/examples/blob/main/distributed/ddp/main.py

Many thanks

Example of the failing DDP run:

  ot.emd2(a, b, dist)
  File "/python3.8/site-packages/ot/lp/__init__.py", line 468, in emd2
    nx = get_backend(M0, a0, b0)
  File "/python3.8/site-packages/ot/backend.py", line 168, in get_backend
    return TorchBackend()
  File "/python3.8/site-packages/ot/backend.py", line 1517, in __init__
    self.__type_list__.append(torch.tensor(1, dtype=torch.float32, device='cuda'))
RuntimeError: CUDA error: all CUDA-capable devices are busy or unavailable

My current workaround is changing
self.__type_list__.append(torch.tensor(1, dtype=torch.float32, device='cuda'))
to
self.__type_list__.append(torch.tensor(1, dtype=torch.float32, device=device_id))
with the device id passed in from the backend, and then rebuilding POT from source.
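
A possible alternative that avoids patching POT is to bind each DDP worker to its own GPU before the first ot.emd2 call, since a bare device='cuda' refers to the process's current CUDA device. The sketch below only works out which device index a worker should use. It assumes the job is launched with torchrun, which sets LOCAL_RANK for each process; the helper names local_device_index and local_device_name are hypothetical and are not part of POT or PyTorch.

```python
import os


def local_device_index(env=None, default=0):
    """Return the GPU index this worker should use.

    Assumes a torchrun-style launcher that exports LOCAL_RANK for each
    process; falls back to `default` when the variable is missing.
    """
    env = os.environ if env is None else env
    index = int(env.get("LOCAL_RANK", default))
    if index < 0:
        raise ValueError(f"LOCAL_RANK must be non-negative, got {index}")
    return index


def local_device_name(env=None, default=0):
    """Return an explicit device string such as 'cuda:2' for this worker."""
    return f"cuda:{local_device_index(env, default)}"
```

In a worker, calling torch.cuda.set_device(local_device_index()) before building the tensors passed to ot.emd2 would make an unqualified 'cuda' resolve to that worker's GPU. Whether this removes the error above has not been verified here.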

Hello @dgm2,

Does this workaround work? Note that the list is there mainly for debugging and tests (so that we can run them on all available devices), so I'm a bit surprised if this is the only bottleneck for running POT with DDP.

We are obviously interested in your contribution if you manage to make it work properly (we do not have multiple GPUs, so it is a bit hard to implement and debug on our side). The device_id should probably be detected automatically when get_backend is called and the backend is created; the backends should not need parameters, so that they remain practical to use.
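
One way the automatic detection suggested above could look is to read the device off the inputs themselves rather than passing it in as a parameter. This is only a sketch under that assumption: infer_device is a hypothetical helper, not an existing POT function, and the stand-in objects below merely mimic the .device attribute that torch tensors expose.

```python
from types import SimpleNamespace


def infer_device(*args):
    """Return the single device shared by all arguments.

    Each argument is expected to expose a `.device` attribute, as
    torch tensors do. Raises ValueError if there are no arguments
    or if they live on different devices.
    """
    devices = {str(arg.device) for arg in args}
    if len(devices) != 1:
        raise ValueError(f"expected one common device, got {sorted(devices)}")
    return devices.pop()


# Stand-ins for tensors so the sketch runs without a GPU.
a = SimpleNamespace(device="cuda:1")
b = SimpleNamespace(device="cuda:1")
M = SimpleNamespace(device="cuda:1")
device = infer_device(a, b, M)
```

With real tensors, the returned string could be used wherever the backend currently hardcodes 'cuda', which would keep the backend constructor free of parameters.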

Hello @dgm2,
Could you provide us with the exact code you used to get this error?
I ran https://github.com/pytorch/examples/blob/main/distributed/ddp/main.py with 4 GPUs and ot.emd2 as the loss function and did not get any error; everything seems to have run smoothly, whether the distribution was handled by torch or Slurm.