2 files changed: +12 −4 lines changed

@@ -115,13 +115,17 @@ def setup_environment(self) -> None:
     def setup_module(self, module: Module) -> DistributedDataParallel:
         """Wraps the model into a :class:`~torch.nn.parallel.distributed.DistributedDataParallel` module."""
         device_ids = self._determine_ddp_device_ids()
+        ctx = None
         if self.root_device.type == "cuda":
             # https://pytorch.org/docs/stable/notes/cuda.html#id5
             ctx = torch.cuda.stream(torch.cuda.Stream()) if device_ids is not None else nullcontext()
+        if self.root_device.type == "xpu":
+            ctx = torch.xpu.stream(torch.xpu.Stream()) if device_ids is not None else nullcontext()
+        if ctx is None:
+            return DistributedDataParallel(module=module, device_ids=device_ids, **self._ddp_kwargs)
+        else:
             with ctx:
                 return DistributedDataParallel(module=module, device_ids=device_ids, **self._ddp_kwargs)
-        else:
-            return DistributedDataParallel(module=module, device_ids=device_ids, **self._ddp_kwargs)
 
     def module_to_device(self, module: Module) -> None:
         module.to(self.root_device)
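
Read as straight Python, the patched control flow of setup_module (and of the identical _setup_model hunk below) looks like the following. This is only a sketch: wrap_ddp is a hypothetical free-function rendering added here for self-containment, and the XPU branch assumes a PyTorch build that exposes torch.xpu.Stream / torch.xpu.stream with the same semantics as the torch.cuda stream API.

    from contextlib import nullcontext
    from typing import Optional, Sequence

    import torch
    from torch.nn import Module
    from torch.nn.parallel import DistributedDataParallel


    def wrap_ddp(
        module: Module,
        root_device: torch.device,
        device_ids: Optional[Sequence[int]],
        **ddp_kwargs,
    ) -> DistributedDataParallel:
        # Hypothetical helper mirroring the patched setup_module/_setup_model logic.
        ctx = None
        if root_device.type == "cuda":
            # Construct DDP on a side stream: https://pytorch.org/docs/stable/notes/cuda.html#id5
            ctx = torch.cuda.stream(torch.cuda.Stream()) if device_ids is not None else nullcontext()
        if root_device.type == "xpu":
            # Same side-stream pattern for Intel XPU (assumes torch.xpu stream support).
            ctx = torch.xpu.stream(torch.xpu.Stream()) if device_ids is not None else nullcontext()
        if ctx is None:
            # CPU and any other device type: no stream context is needed.
            return DistributedDataParallel(module=module, device_ids=device_ids, **ddp_kwargs)
        with ctx:
            return DistributedDataParallel(module=module, device_ids=device_ids, **ddp_kwargs)

Note that when device_ids is None, both the CUDA and XPU branches fall back to nullcontext(), so entering the with block is a no-op and matches the plain construction taken when ctx is None.
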
@@ -183,13 +183,17 @@ def _setup_model(self, model: Module) -> DistributedDataParallel:
         """Wraps the model into a :class:`~torch.nn.parallel.distributed.DistributedDataParallel` module."""
         device_ids = self.determine_ddp_device_ids()
         log.debug(f"setting up DDP model with device ids: {device_ids}, kwargs: {self._ddp_kwargs}")
+        ctx = None
         if self.root_device.type == "cuda":
             # https://pytorch.org/docs/stable/notes/cuda.html#id5
             ctx = torch.cuda.stream(torch.cuda.Stream()) if device_ids is not None else nullcontext()
+        if self.root_device.type == "xpu":
+            ctx = torch.xpu.stream(torch.xpu.Stream()) if device_ids is not None else nullcontext()
+        if ctx is None:
+            return DistributedDataParallel(module=model, device_ids=device_ids, **self._ddp_kwargs)
+        else:
             with ctx:
                 return DistributedDataParallel(module=model, device_ids=device_ids, **self._ddp_kwargs)
-        else:
-            return DistributedDataParallel(module=model, device_ids=device_ids, **self._ddp_kwargs)
 
     def setup_distributed(self) -> None:
         log.debug(f"{self.__class__.__name__}: setting up distributed...")
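
Since both changed files now carry the same device-type-to-stream dispatch, one possible follow-up (not part of this patch; the helper name _ddp_setup_ctx is hypothetical) would be to factor that dispatch into a small shared helper and always construct DDP under a context manager, which is behaviourally equivalent because nullcontext() is a no-op:

    from contextlib import nullcontext
    from typing import ContextManager, Optional, Sequence

    import torch


    def _ddp_setup_ctx(root_device: torch.device, device_ids: Optional[Sequence[int]]) -> ContextManager:
        # Hypothetical shared helper: pick the stream context to construct DDP under.
        if device_ids is None:
            return nullcontext()
        if root_device.type == "cuda":
            # https://pytorch.org/docs/stable/notes/cuda.html#id5
            return torch.cuda.stream(torch.cuda.Stream())
        if root_device.type == "xpu":
            # Assumes a PyTorch build with torch.xpu stream support.
            return torch.xpu.stream(torch.xpu.Stream())
        return nullcontext()

Each call site would then reduce to a single with _ddp_setup_ctx(self.root_device, device_ids): block around the DistributedDataParallel construction.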