Does `DDPStrategy` support XLA?
laserkelvin opened this issue
Bug description
When configuring a `DDPStrategy` with multiple devices that do not use the `torch.cuda` API, we trigger the following exception:
```
  File "/home/hpclee1/rds/hpc-work/.conda/envs/matsciml/lib/python3.10/site-packages/torch/cuda/_utils.py", line 46, in err_fn
    raise RuntimeError(
RuntimeError: Tried to instantiate dummy base class Stream
```
The `_setup_model` method of `DDPStrategy` triggers this exception because the `torch.cuda.stream(torch.cuda.Stream())` context is hardcoded whenever `device_ids` are passed. I've reproduced the snippet below, but here is a permalink.
```python
@override
def _setup_model(self, model: Module) -> DistributedDataParallel:
    """Wraps the model into a :class:`~torch.nn.parallel.distributed.DistributedDataParallel` module."""
    device_ids = self.determine_ddp_device_ids()
    log.debug(f"setting up DDP model with device ids: {device_ids}, kwargs: {self._ddp_kwargs}")
    # https://pytorch.org/docs/stable/notes/cuda.html#id5
    ctx = torch.cuda.stream(torch.cuda.Stream()) if device_ids is not None else nullcontext()
    with ctx:
        return DistributedDataParallel(module=model, device_ids=device_ids, **self._ddp_kwargs)
```
A potential solution would be to check the target device, or simply to add `torch.cuda.is_available()` to the condition. Otherwise, removing the `torch.cuda.Stream()` call and just using `nullcontext()` works perfectly fine.
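A minimal sketch of the guarded condition described above, assuming the strategy can pass in its target device type as a string. `ddp_setup_context` is a hypothetical helper, not an existing Lightning function: it only creates the side CUDA stream when device IDs are given, the target device is CUDA, and CUDA is actually available, and falls back to `nullcontext()` in every other case (including `"xpu"`).

```python
from contextlib import AbstractContextManager, nullcontext
from typing import List, Optional


def ddp_setup_context(
    device_ids: Optional[List[int]], device_type: str
) -> AbstractContextManager:
    """Pick the context manager to wrap DDP construction in (hypothetical helper).

    A side CUDA stream is only created when device IDs were passed AND the
    target device is CUDA AND CUDA is actually available. Every other case
    (CPU, "xpu", other non-CUDA accelerators) falls back to nullcontext().
    """
    if device_ids is None or device_type != "cuda":
        # Never touch torch.cuda for non-CUDA targets, so no dummy Stream is built.
        return nullcontext()

    # Imported here only so the non-CUDA path of this sketch runs without PyTorch.
    import torch

    if not torch.cuda.is_available():
        return nullcontext()
    # Same behaviour as the current _setup_model code for real CUDA devices.
    return torch.cuda.stream(torch.cuda.Stream())
```

Inside `_setup_model`, the `ctx = ...` line could then call `ddp_setup_context(device_ids, self.root_device.type)`, assuming `self.root_device` is the strategy's `torch.device`. Checking the device type rather than only `torch.cuda.is_available()` also covers nodes where CUDA happens to be available but the strategy targets another accelerator.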
The reproduction snippet below relies on an `XPUAccelerator` registered here, but I assume this would also trigger for other non-CUDA accelerators.
What version are you seeing the problem on?
v2.1, v2.2
How to reproduce the bug
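```python
import pytorch_lightning as pl

# num_devices, num_nodes, task (a LightningModule) and dm (a LightningDataModule)
# are defined elsewhere; accelerator="xpu" relies on the separately registered XPUAccelerator.
env = pl.plugins.environments.SLURMEnvironment()
ddp = pl.strategies.DDPStrategy(
    accelerator="xpu",
    cluster_environment=env,
    process_group_backend="ccl",
    find_unused_parameters=True,
)
trainer = pl.Trainer(
    strategy=ddp, devices=num_devices, fast_dev_run=100, num_nodes=num_nodes
)
trainer.fit(task, datamodule=dm)
```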
Error messages and logs
```
  File "/home/hpclee1/rds/hpc-work/.conda/envs/matsciml/lib/python3.10/site-packages/torch/cuda/_utils.py", line 46, in err_fn
    raise RuntimeError(
RuntimeError: Tried to instantiate dummy base class Stream
```
Environment
Current environment
```
#- Lightning Component (e.g. Trainer, LightningModule, LightningApp, LightningWork, LightningFlow):
#- PyTorch Lightning Version (e.g., 1.5.0): 2.2.1
#- Lightning App Version (e.g., 0.5.2):
#- PyTorch Version (e.g., 2.0): 2.0.1
#- Python version (e.g., 3.9): 3.10
#- OS (e.g., Linux): Linux
#- CUDA/cuDNN version: N/A
#- GPU models and configuration: Intel 1550 Data Center GPU Max
#- How you installed Lightning(`conda`, `pip`, source): pip
#- Running environment of LightningApp (e.g. local, cloud): Managed Slurm cluster
```