Lightning-AI/pytorch-lightning

Does `DDPStrategy` support XLA?

laserkelvin opened this issue · 1 comments

Bug description

When configuring a `DDPStrategy` with multiple devices that do not use the `torch.cuda` API, the following exception is raised:

  File "/home/hpclee1/rds/hpc-work/.conda/envs/matsciml/lib/python3.10/site-packages/torch/cuda/_utils.py", line 46, in err_fn
    raise RuntimeError(
RuntimeError: Tried to instantiate dummy base class Stream

The `_setup_model` method of `DDPStrategy` triggers this exception because it enters `torch.cuda.stream(torch.cuda.Stream())` whenever `device_ids` is not `None`, regardless of the device type. The relevant code is reproduced below (the original post also links a permalink).

    @override
    def _setup_model(self, model: Module) -> DistributedDataParallel:
        """Wraps the model into a :class:`~torch.nn.parallel.distributed.DistributedDataParallel` module."""
        device_ids = self.determine_ddp_device_ids()
        log.debug(f"setting up DDP model with device ids: {device_ids}, kwargs: {self._ddp_kwargs}")
        # https://pytorch.org/docs/stable/notes/cuda.html#id5
        ctx = torch.cuda.stream(torch.cuda.Stream()) if device_ids is not None else nullcontext()
        with ctx:
            return DistributedDataParallel(module=model, device_ids=device_ids, **self._ddp_kwargs)

A potential solution could be to check the target device type, or even just to check `torch.cuda.is_available()` in the condition. With the `torch.cuda.Stream()` call removed and `nullcontext()` used instead, everything otherwise works fine.
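To make the suggestion concrete, here is a minimal sketch of the condition I have in mind. `ddp_setup_context` is a hypothetical helper name, not part of Lightning's API, and passing the device type as a plain string is an assumption made to keep the example self-contained.

```python
from contextlib import nullcontext
from typing import ContextManager, Optional, Sequence


def ddp_setup_context(
    device_ids: Optional[Sequence[int]], device_type: str
) -> ContextManager:
    """Pick the context in which DistributedDataParallel is constructed.

    Hypothetical helper for illustration only; not part of Lightning's API.
    """
    # Non-CUDA devices (e.g. "xpu", "cpu") never touch torch.cuda.
    if device_ids is None or device_type != "cuda":
        return nullcontext()

    # Imported lazily so the non-CUDA path has no dependency on torch.cuda.
    import torch

    if not torch.cuda.is_available():
        return nullcontext()

    # Same side-stream setup that _setup_model uses today.
    return torch.cuda.stream(torch.cuda.Stream())
```

Inside `_setup_model`, the `ctx = ...` line could then become something like `ctx = ddp_setup_context(device_ids, self.root_device.type)`, assuming `self.root_device` is the `torch.device` the strategy is targeting. The CUDA path keeps its current behavior, and every other accelerator falls through to `nullcontext()`.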

The reproduction snippet below relies on a custom `XPUAccelerator` (registered in code linked from the original post), but I would expect the same failure with other non-CUDA accelerators as well.

What version are you seeing the problem on?

v2.1, v2.2

How to reproduce the bug

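    # Assumed import; the original snippet only shows the `pl` alias.
    import pytorch_lightning as pl

    # `task`, `dm`, `num_devices` and `num_nodes` are defined elsewhere.
    # accelerator="xpu" requires the custom XPUAccelerator to be registered.
    env = pl.plugins.environments.SLURMEnvironment()
    ddp = pl.strategies.DDPStrategy(
        accelerator="xpu",
        cluster_environment=env,
        process_group_backend="ccl",
        find_unused_parameters=True,
    )

    trainer = pl.Trainer(
        strategy=ddp, devices=num_devices, fast_dev_run=100, num_nodes=num_nodes
    )
    trainer.fit(task, datamodule=dm)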

Error messages and logs

  File "/home/hpclee1/rds/hpc-work/.conda/envs/matsciml/lib/python3.10/site-packages/torch/cuda/_utils.py", line 46, in err_fn
    raise RuntimeError(
RuntimeError: Tried to instantiate dummy base class Stream

Environment

Current environment
#- Lightning Component (e.g. Trainer, LightningModule, LightningApp, LightningWork, LightningFlow):
#- PyTorch Lightning Version (e.g., 1.5.0): 2.2.1
#- Lightning App Version (e.g., 0.5.2):
#- PyTorch Version (e.g., 2.0): 2.0.1
#- Python version (e.g., 3.9): 3.10
#- OS (e.g., Linux): Linux
#- CUDA/cuDNN version: N/A
#- GPU models and configuration: Intel Data Center GPU Max 1550
#- How you installed Lightning(`conda`, `pip`, source): pip
#- Running environment of LightningApp (e.g. local, cloud): Managed Slurm cluster
