
Commit f7952d4

[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent 2538a56 commit f7952d4

File tree: 6 files changed (+41, -21 lines)

src/lightning/fabric/accelerators/__init__.py

Lines changed: 5 additions & 4 deletions
@@ -17,10 +17,11 @@
 from lightning.fabric.accelerators.registry import _AcceleratorRegistry, call_register_accelerators
 from lightning.fabric.accelerators.xla import XLAAccelerator  # noqa: F401
 from lightning.fabric.utilities.imports import _LIGHTNING_XPU_AVAILABLE
+
 _ACCELERATORS_BASE_MODULE = "lightning.fabric.accelerators"
 ACCELERATOR_REGISTRY = _AcceleratorRegistry()
 call_register_accelerators(ACCELERATOR_REGISTRY, _ACCELERATORS_BASE_MODULE)
-if _LIGHTNING_XPU_AVAILABLE:
-    if "xpu" not in ACCELERATOR_REGISTRY:
-        from lightning_xpu.fabric import XPUAccelerator
-        XPUAccelerator.register_accelerators(ACCELERATOR_REGISTRY)
+if _LIGHTNING_XPU_AVAILABLE and "xpu" not in ACCELERATOR_REGISTRY:
+    from lightning_xpu.fabric import XPUAccelerator
+
+    XPUAccelerator.register_accelerators(ACCELERATOR_REGISTRY)
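
Behavior is unchanged here: the nested check is folded into one short-circuiting condition, so the registry lookup and the lightning_xpu import still only happen when the integration is installed. A minimal standalone sketch of the pattern, with the registry and availability flag replaced by illustrative stand-ins:

# Stand-ins for _AcceleratorRegistry and _LIGHTNING_XPU_AVAILABLE (illustrative only).
ACCELERATOR_REGISTRY: dict = {}
_LIGHTNING_XPU_AVAILABLE = False  # pretend the optional integration is absent


def register_xpu(registry: dict) -> None:
    # Stand-in for XPUAccelerator.register_accelerators(registry)
    registry["xpu"] = "XPUAccelerator"


# Short-circuit: when the flag is False, the "xpu" membership test and the
# registration never run, just as with the nested form it replaces.
if _LIGHTNING_XPU_AVAILABLE and "xpu" not in ACCELERATOR_REGISTRY:
    register_xpu(ACCELERATOR_REGISTRY)

print(ACCELERATOR_REGISTRY)  # {} while the flag is False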

src/lightning/fabric/cli.py

Lines changed: 1 addition & 0 deletions
@@ -158,6 +158,7 @@ def _get_num_processes(accelerator: str, devices: str) -> int:
         raise ValueError("Launching processes for TPU through the CLI is not supported.")
     elif accelerator == "xpu":
         from lightning_xpu.fabric import XPUAccelerator
+
         parsed_devices = XPUAccelerator.parse_devices(devices)
     else:
         return CPUAccelerator.parse_devices(devices)

src/lightning/fabric/connector.py

Lines changed: 4 additions & 2 deletions
@@ -23,7 +23,6 @@
 from lightning.fabric.accelerators.cuda import CUDAAccelerator
 from lightning.fabric.accelerators.mps import MPSAccelerator
 from lightning.fabric.accelerators.xla import XLAAccelerator
-from lightning.fabric.utilities.imports import _LIGHTNING_XPU_AVAILABLE
 from lightning.fabric.plugins import (
     CheckpointIO,
     DeepSpeedPrecision,
@@ -64,7 +63,7 @@
 from lightning.fabric.strategies.fsdp import _FSDP_ALIASES, FSDPStrategy
 from lightning.fabric.utilities import rank_zero_info, rank_zero_warn
 from lightning.fabric.utilities.device_parser import _determine_root_gpu_device
-from lightning.fabric.utilities.imports import _IS_INTERACTIVE
+from lightning.fabric.utilities.imports import _IS_INTERACTIVE, _LIGHTNING_XPU_AVAILABLE
 
 _PLUGIN = Union[Precision, ClusterEnvironment, CheckpointIO]
 _PLUGIN_INPUT = Union[_PLUGIN, str]
@@ -323,6 +322,7 @@ def _choose_auto_accelerator(self) -> str:
                 return "cuda"
             if _LIGHTNING_XPU_AVAILABLE:
                 from lightning_xpu.fabric import XPUAccelerator
+
                 if XPUAccelerator.is_available():
                     return "xpu"
 
@@ -336,6 +336,7 @@ def _choose_gpu_accelerator_backend() -> str:
             return "cuda"
         if _LIGHTNING_XPU_AVAILABLE:
            from lightning_xpu.fabric import XPUAccelerator
+
            if XPUAccelerator.is_available():
                return "xpu"
        raise RuntimeError("No supported gpu backend found!")
@@ -399,6 +400,7 @@ def _choose_strategy(self) -> Union[Strategy, str]:
            supported_accelerators_str = ["cuda", "gpu", "mps"]
            if _LIGHTNING_XPU_AVAILABLE:
                from lightning_xpu.fabric import XPUAccelerator
+
                supported_accelerators.append(XPUAccelerator)
                supported_accelerators_str.append("xpu")
            if isinstance(self._accelerator_flag, tuple(supported_accelerators)) or (
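
The connector hunks only move one import and add blank lines after local imports; the accelerator fallback order is untouched. A rough standalone sketch of that fallback, with the availability probes as stand-ins rather than the real lightning APIs:

# Illustrative fallback, mirroring the shape of _choose_gpu_accelerator_backend.
def choose_gpu_backend(cuda_ok: bool, xpu_importable: bool, xpu_ok: bool) -> str:
    if cuda_ok:
        return "cuda"
    if xpu_importable:  # mirrors the _LIGHTNING_XPU_AVAILABLE guard before the local import
        if xpu_ok:      # mirrors XPUAccelerator.is_available()
            return "xpu"
    raise RuntimeError("No supported gpu backend found!")


print(choose_gpu_backend(cuda_ok=False, xpu_importable=True, xpu_ok=True))  # -> xpu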

src/lightning/fabric/utilities/device_parser.py

Lines changed: 20 additions & 7 deletions
@@ -16,8 +16,8 @@
 import lightning.fabric.accelerators as accelerators  # avoid circular dependency
 from lightning.fabric.plugins.environments.torchelastic import TorchElasticEnvironment
 from lightning.fabric.utilities.exceptions import MisconfigurationException
-from lightning.fabric.utilities.types import _DEVICE
 from lightning.fabric.utilities.imports import _LIGHTNING_XPU_AVAILABLE
+from lightning.fabric.utilities.types import _DEVICE
 
 
 def _determine_root_gpu_device(gpus: List[_DEVICE]) -> Optional[_DEVICE]:
@@ -87,14 +87,17 @@ def _parse_gpu_ids(
     # We know the user requested GPUs therefore if some of the
     # requested GPUs are not available an exception is thrown.
     gpus = _normalize_parse_gpu_string_input(gpus)
-    gpus = _normalize_parse_gpu_input_to_list(gpus, include_cuda=include_cuda, include_mps=include_mps, include_xpu=include_xpu)
+    gpus = _normalize_parse_gpu_input_to_list(
+        gpus, include_cuda=include_cuda, include_mps=include_mps, include_xpu=include_xpu
+    )
     if not gpus:
         raise MisconfigurationException("GPUs requested but none are available.")
 
     if (
         TorchElasticEnvironment.detect()
         and len(gpus) != 1
-        and len(_get_all_available_gpus(include_cuda=include_cuda, include_mps=include_mps, include_xpu=include_xpu)) == 1
+        and len(_get_all_available_gpus(include_cuda=include_cuda, include_mps=include_mps, include_xpu=include_xpu))
+        == 1
     ):
         # Omit sanity check on torchelastic because by default it shows one visible GPU per process
         return gpus
@@ -115,7 +118,9 @@ def _normalize_parse_gpu_string_input(s: Union[int, str, List[int]]) -> Union[in
         return int(s.strip())
 
 
-def _sanitize_gpu_ids(gpus: List[int], include_cuda: bool = False, include_mps: bool = False, include_xpu: bool = False) -> List[int]:
+def _sanitize_gpu_ids(
+    gpus: List[int], include_cuda: bool = False, include_mps: bool = False, include_xpu: bool = False
+) -> List[int]:
     """Checks that each of the GPUs in the list is actually available. Raises a MisconfigurationException if any of
     the GPUs is not available.
 
@@ -131,7 +136,9 @@ def _sanitize_gpu_ids(gpus: List[int], include_cuda: bool = False, include_mps:
     """
     if sum((include_cuda, include_mps, include_xpu)) == 0:
         raise ValueError("At least one gpu type should be specified!")
-    all_available_gpus = _get_all_available_gpus(include_cuda=include_cuda, include_mps=include_mps, include_xpu=include_xpu)
+    all_available_gpus = _get_all_available_gpus(
+        include_cuda=include_cuda, include_mps=include_mps, include_xpu=include_xpu
+    )
     for gpu in gpus:
         if gpu not in all_available_gpus:
             raise MisconfigurationException(
@@ -141,7 +148,10 @@ def _sanitize_gpu_ids(gpus: List[int], include_cuda: bool = False, include_mps:
 
 
 def _normalize_parse_gpu_input_to_list(
-    gpus: Union[int, List[int], Tuple[int, ...]], include_cuda: bool, include_mps: bool, include_xpu: bool,
+    gpus: Union[int, List[int], Tuple[int, ...]],
+    include_cuda: bool,
+    include_mps: bool,
+    include_xpu: bool,
 ) -> Optional[List[int]]:
     assert gpus is not None
     if isinstance(gpus, (MutableSequence, tuple)):
@@ -156,7 +166,9 @@ def _normalize_parse_gpu_input_to_list(
         return list(range(gpus))
 
 
-def _get_all_available_gpus(include_cuda: bool = False, include_mps: bool = False, include_xpu: bool = False) -> List[int]:
+def _get_all_available_gpus(
+    include_cuda: bool = False, include_mps: bool = False, include_xpu: bool = False
+) -> List[int]:
     """
     Returns:
         A list of all available GPUs
@@ -166,6 +178,7 @@ def _get_all_available_gpus(include_cuda: bool = False, include_mps: bool = Fals
     xpu_gpus = []
     if _LIGHTNING_XPU_AVAILABLE:
         import lightning_xpu.fabric as accelerator_xpu
+
         xpu_gpus += accelerator_xpu._get_all_visible_xpu_devices() if include_xpu else []
     return cuda_gpus + mps_gpus + xpu_gpus
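
Every device_parser change above is a line-length reflow from the formatter; the signatures and behavior are identical. As a rough illustration of the pattern these helpers share, each include_* flag simply gates one backend's contribution to the combined device list (the hard-coded device lists below are placeholders, not the real availability probes):

from typing import List


def get_all_available_gpus_sketch(
    include_cuda: bool = False, include_mps: bool = False, include_xpu: bool = False
) -> List[int]:
    # Guard mirrored from _sanitize_gpu_ids: at least one backend must be requested.
    if sum((include_cuda, include_mps, include_xpu)) == 0:
        raise ValueError("At least one gpu type should be specified!")
    cuda_gpus = [0, 1] if include_cuda else []  # placeholder for visible CUDA devices
    mps_gpus = [0] if include_mps else []       # placeholder for the single MPS device
    xpu_gpus = [0] if include_xpu else []       # placeholder for visible XPU devices
    return cuda_gpus + mps_gpus + xpu_gpus


print(get_all_available_gpus_sketch(include_cuda=True, include_xpu=True))  # e.g. [0, 1, 0]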

src/lightning/pytorch/trainer/connectors/accelerator_connector.py

Lines changed: 6 additions & 7 deletions
@@ -447,13 +447,12 @@ def _choose_strategy(self) -> Union[Strategy, str]:
             from lightning_habana import SingleHPUStrategy
 
             return SingleHPUStrategy(device=torch.device("hpu"))
-        if self._accelerator_flag == "xpu":
-            if not _LIGHTNING_XPU_AVAILABLE:
-                raise ImportError(
-                    "You have asked for XPU but you miss install related integration."
-                    " Please run `pip install lightning-xpu` or see for further instructions"
-                    " in https://github.com/Lightning-AI/lightning-XPU/."
-                )
+        if self._accelerator_flag == "xpu" and not _LIGHTNING_XPU_AVAILABLE:
+            raise ImportError(
+                "You have asked for XPU but you miss install related integration."
+                " Please run `pip install lightning-xpu` or see for further instructions"
+                " in https://github.com/Lightning-AI/lightning-XPU/."
+            )
         if self._accelerator_flag == "tpu" or isinstance(self._accelerator_flag, XLAAccelerator):
             if self._parallel_devices and len(self._parallel_devices) > 1:
                 return XLAStrategy.strategy_name
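
Flattening the nested guard does not change when the ImportError is raised: it still fires only when the accelerator flag is "xpu" and the integration is missing, and execution falls through to the remaining checks otherwise. A quick equivalence check of the two shapes (the flag values are hypothetical stand-ins for the connector's state):

def nested(flag: str, xpu_available: bool) -> bool:
    # Original shape: would raise ImportError where this returns True.
    if flag == "xpu":
        if not xpu_available:
            return True
    return False


def flattened(flag: str, xpu_available: bool) -> bool:
    # New shape: identical truth table.
    return flag == "xpu" and not xpu_available


for flag in ("xpu", "cuda"):
    for available in (True, False):
        assert nested(flag, available) == flattened(flag, available)
print("equivalent")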

src/lightning/pytorch/trainer/setup.py

Lines changed: 5 additions & 1 deletion
@@ -28,7 +28,11 @@
     XLAProfiler,
 )
 from lightning.pytorch.utilities.exceptions import MisconfigurationException
-from lightning.pytorch.utilities.imports import _LIGHTNING_GRAPHCORE_AVAILABLE, _LIGHTNING_HABANA_AVAILABLE, _LIGHTNING_XPU_AVAILABLE
+from lightning.pytorch.utilities.imports import (
+    _LIGHTNING_GRAPHCORE_AVAILABLE,
+    _LIGHTNING_HABANA_AVAILABLE,
+    _LIGHTNING_XPU_AVAILABLE,
+)
 from lightning.pytorch.utilities.rank_zero import rank_zero_info, rank_zero_warn
 
 