 import lightning.fabric.accelerators as accelerators  # avoid circular dependency
 from lightning.fabric.plugins.environments.torchelastic import TorchElasticEnvironment
 from lightning.fabric.utilities.exceptions import MisconfigurationException
-from lightning.fabric.utilities.types import _DEVICE
 from lightning.fabric.utilities.imports import _LIGHTNING_XPU_AVAILABLE
+from lightning.fabric.utilities.types import _DEVICE


 def _determine_root_gpu_device(gpus: List[_DEVICE]) -> Optional[_DEVICE]:
@@ -87,14 +87,17 @@ def _parse_gpu_ids(
     # We know the user requested GPUs therefore if some of the
     # requested GPUs are not available an exception is thrown.
     gpus = _normalize_parse_gpu_string_input(gpus)
-    gpus = _normalize_parse_gpu_input_to_list(gpus, include_cuda=include_cuda, include_mps=include_mps, include_xpu=include_xpu)
+    gpus = _normalize_parse_gpu_input_to_list(
+        gpus, include_cuda=include_cuda, include_mps=include_mps, include_xpu=include_xpu
+    )
     if not gpus:
         raise MisconfigurationException("GPUs requested but none are available.")

     if (
         TorchElasticEnvironment.detect()
         and len(gpus) != 1
-        and len(_get_all_available_gpus(include_cuda=include_cuda, include_mps=include_mps, include_xpu=include_xpu)) == 1
+        and len(_get_all_available_gpus(include_cuda=include_cuda, include_mps=include_mps, include_xpu=include_xpu))
+        == 1
     ):
         # Omit sanity check on torchelastic because by default it shows one visible GPU per process
         return gpus
@@ -115,7 +118,9 @@ def _normalize_parse_gpu_string_input(s: Union[int, str, List[int]]) -> Union[in
     return int(s.strip())


-def _sanitize_gpu_ids(gpus: List[int], include_cuda: bool = False, include_mps: bool = False, include_xpu: bool = False) -> List[int]:
+def _sanitize_gpu_ids(
+    gpus: List[int], include_cuda: bool = False, include_mps: bool = False, include_xpu: bool = False
+) -> List[int]:
     """Checks that each of the GPUs in the list is actually available. Raises a MisconfigurationException if any of
     the GPUs is not available.

@@ -131,7 +136,9 @@ def _sanitize_gpu_ids(gpus: List[int], include_cuda: bool = False, include_mps:
     """
     if sum((include_cuda, include_mps, include_xpu)) == 0:
         raise ValueError("At least one gpu type should be specified!")
-    all_available_gpus = _get_all_available_gpus(include_cuda=include_cuda, include_mps=include_mps, include_xpu=include_xpu)
+    all_available_gpus = _get_all_available_gpus(
+        include_cuda=include_cuda, include_mps=include_mps, include_xpu=include_xpu
+    )
     for gpu in gpus:
         if gpu not in all_available_gpus:
             raise MisconfigurationException(
@@ -141,7 +148,10 @@ def _sanitize_gpu_ids(gpus: List[int], include_cuda: bool = False, include_mps:


 def _normalize_parse_gpu_input_to_list(
-    gpus: Union[int, List[int], Tuple[int, ...]], include_cuda: bool, include_mps: bool, include_xpu: bool,
+    gpus: Union[int, List[int], Tuple[int, ...]],
+    include_cuda: bool,
+    include_mps: bool,
+    include_xpu: bool,
 ) -> Optional[List[int]]:
     assert gpus is not None
     if isinstance(gpus, (MutableSequence, tuple)):
@@ -156,7 +166,9 @@ def _normalize_parse_gpu_input_to_list(
         return list(range(gpus))


-def _get_all_available_gpus(include_cuda: bool = False, include_mps: bool = False, include_xpu: bool = False) -> List[int]:
+def _get_all_available_gpus(
+    include_cuda: bool = False, include_mps: bool = False, include_xpu: bool = False
+) -> List[int]:
     """
     Returns:
         A list of all available GPUs
@@ -166,6 +178,7 @@ def _get_all_available_gpus(include_cuda: bool = False, include_mps: bool = Fals
     xpu_gpus = []
     if _LIGHTNING_XPU_AVAILABLE:
         import lightning_xpu.fabric as accelerator_xpu
+
         xpu_gpus += accelerator_xpu._get_all_visible_xpu_devices() if include_xpu else []
     return cuda_gpus + mps_gpus + xpu_gpus
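
For context, a minimal sketch of how the `include_*` flags touched by this change might be exercised. The module path and call pattern are assumptions inferred from the diff (`_parse_gpu_ids` in Lightning Fabric's device parser, with `include_xpu` only available once this patch is applied); it is not part of the commit itself.

```python
# Illustrative sketch only -- assumes the patched signature shown in this diff,
# i.e. _parse_gpu_ids(..., include_cuda=..., include_mps=..., include_xpu=...).
from lightning.fabric.utilities.device_parser import _parse_gpu_ids

# Ask for two GPUs of any enabled kind; each include_* flag widens the pool of
# device indices that _get_all_available_gpus reports (CUDA + MPS + XPU).
device_ids = _parse_gpu_ids(2, include_cuda=True, include_mps=False, include_xpu=True)
print(device_ids)  # e.g. [0, 1]; raises MisconfigurationException if no GPU is visible
```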