Commit 5ceb86f

add torch.xpu.stream to ddp
1 parent 6768bdb commit 5ceb86f

File tree

2 files changed: +12 -4 lines changed

  • src/lightning/fabric/strategies/ddp.py
  • src/lightning/pytorch/strategies/ddp.py


src/lightning/fabric/strategies/ddp.py

Lines changed: 6 additions & 2 deletions
@@ -115,13 +115,17 @@ def setup_environment(self) -> None:
     def setup_module(self, module: Module) -> DistributedDataParallel:
         """Wraps the model into a :class:`~torch.nn.parallel.distributed.DistributedDataParallel` module."""
         device_ids = self._determine_ddp_device_ids()
+        ctx = None
         if self.root_device.type == "cuda":
             # https://pytorch.org/docs/stable/notes/cuda.html#id5
             ctx = torch.cuda.stream(torch.cuda.Stream()) if device_ids is not None else nullcontext()
+        if self.root_device.type == "xpu":
+            ctx = torch.xpu.stream(torch.xpu.Stream()) if device_ids is not None else nullcontext()
+        if ctx is None:
+            return DistributedDataParallel(module=module, device_ids=device_ids, **self._ddp_kwargs)
+        else:
             with ctx:
                 return DistributedDataParallel(module=module, device_ids=device_ids, **self._ddp_kwargs)
-        else:
-            return DistributedDataParallel(module=module, device_ids=device_ids, **self._ddp_kwargs)
 
     def module_to_device(self, module: Module) -> None:
         module.to(self.root_device)
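
For illustration only (not part of the commit): the new branching in setup_module amounts to a dispatch from device type to a stream context. The sketch below factors that dispatch into a hypothetical free function, _ddp_init_ctx; it assumes a PyTorch build that exposes torch.xpu whenever an XPU device is actually used.

# Minimal sketch of the device-type -> stream-context dispatch added above.
from contextlib import nullcontext
from typing import ContextManager, Optional

import torch


def _ddp_init_ctx(root_device: torch.device, device_ids: Optional[list]) -> ContextManager:
    """Pick the context used while constructing DistributedDataParallel.

    CUDA and XPU devices get a fresh side stream (mirroring the diff); any
    other device type, or missing device_ids, falls back to a no-op context.
    """
    if device_ids is None:
        return nullcontext()
    if root_device.type == "cuda":
        # https://pytorch.org/docs/stable/notes/cuda.html#id5
        return torch.cuda.stream(torch.cuda.Stream())
    if root_device.type == "xpu":
        # Assumes a PyTorch build with XPU (Intel GPU) support.
        return torch.xpu.stream(torch.xpu.Stream())
    return nullcontext()


# On a CPU-only device the helper degrades to a no-op context:
with _ddp_init_ctx(torch.device("cpu"), device_ids=None):
    pass  # DistributedDataParallel(module, device_ids=device_ids, ...) would be built here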

src/lightning/pytorch/strategies/ddp.py

Lines changed: 6 additions & 2 deletions
@@ -183,13 +183,17 @@ def _setup_model(self, model: Module) -> DistributedDataParallel:
         """Wraps the model into a :class:`~torch.nn.parallel.distributed.DistributedDataParallel` module."""
         device_ids = self.determine_ddp_device_ids()
         log.debug(f"setting up DDP model with device ids: {device_ids}, kwargs: {self._ddp_kwargs}")
+        ctx = None
         if self.root_device.type == "cuda":
             # https://pytorch.org/docs/stable/notes/cuda.html#id5
             ctx = torch.cuda.stream(torch.cuda.Stream()) if device_ids is not None else nullcontext()
+        if self.root_device.type == "xpu":
+            ctx = torch.xpu.stream(torch.xpu.Stream()) if device_ids is not None else nullcontext()
+        if ctx is None:
+            return DistributedDataParallel(module=model, device_ids=device_ids, **self._ddp_kwargs)
+        else:
             with ctx:
                 return DistributedDataParallel(module=model, device_ids=device_ids, **self._ddp_kwargs)
-        else:
-            return DistributedDataParallel(module=model, device_ids=device_ids, **self._ddp_kwargs)
 
     def setup_distributed(self) -> None:
         log.debug(f"{self.__class__.__name__}: setting up distributed...")
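
The PyTorch-strategy change mirrors the Fabric one. As a rough, self-contained illustration of the side-stream pattern referenced by the linked CUDA note (assumptions not taken from this commit: a single-rank NCCL process group and at least one CUDA device), wrapping a model by hand would look like:

# Illustrative only: construct DDP inside a fresh side stream, as the patched
# strategies do when device ids are known. Assumes rank 0, world size 1, NCCL.
import os

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel

os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("nccl", rank=0, world_size=1)

model = torch.nn.Linear(8, 8).to("cuda:0")

with torch.cuda.stream(torch.cuda.Stream()):
    ddp_model = DistributedDataParallel(model, device_ids=[0])

dist.destroy_process_group()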
