
Commit 808c0ff

update _LIGHTNING_XPU_AVAILABLE to _lightning_xpu_available

1 parent a53fcab · commit 808c0ff
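Besides the change of case, the new name reflects a change of kind: as the diff to src/lightning/fabric/utilities/imports.py below shows, the module-level RequirementCache constant is replaced by a small cached helper function, so the availability check is deferred until it is first needed. A minimal sketch of the two patterns (an illustration, not the exact Lightning code):

import functools

from lightning_utilities.core.imports import RequirementCache

# Before: a module-level constant, evaluated when the utilities module is imported.
_LIGHTNING_XPU_AVAILABLE = RequirementCache("lightning-xpu")


# After: a zero-argument helper; the check runs on the first call and the
# result is cached for the rest of the process.
@functools.lru_cache(maxsize=1)
def _lightning_xpu_available() -> bool:
    return bool(RequirementCache("lightning-xpu"))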

File tree

12 files changed: +34 −29 lines

src/lightning/fabric/accelerators/__init__.py
Lines changed: 2 additions & 2 deletions

@@ -16,12 +16,12 @@
 from lightning.fabric.accelerators.mps import MPSAccelerator  # noqa: F401
 from lightning.fabric.accelerators.registry import _AcceleratorRegistry, call_register_accelerators
 from lightning.fabric.accelerators.xla import XLAAccelerator  # noqa: F401
-from lightning.fabric.utilities.imports import _LIGHTNING_XPU_AVAILABLE
+from lightning.fabric.utilities.imports import _lightning_xpu_available

 _ACCELERATORS_BASE_MODULE = "lightning.fabric.accelerators"
 ACCELERATOR_REGISTRY = _AcceleratorRegistry()
 call_register_accelerators(ACCELERATOR_REGISTRY, _ACCELERATORS_BASE_MODULE)
-if _LIGHTNING_XPU_AVAILABLE and "xpu" not in ACCELERATOR_REGISTRY:
+if _lightning_xpu_available and "xpu" not in ACCELERATOR_REGISTRY:
     from lightning_xpu.fabric import XPUAccelerator

     XPUAccelerator.register_accelerators(ACCELERATOR_REGISTRY)

src/lightning/fabric/cli.py
Lines changed: 2 additions & 2 deletions

@@ -24,14 +24,14 @@
 from lightning.fabric.plugins.precision.precision import _PRECISION_INPUT_STR, _PRECISION_INPUT_STR_ALIAS
 from lightning.fabric.strategies import STRATEGY_REGISTRY
 from lightning.fabric.utilities.device_parser import _parse_gpu_ids
-from lightning.fabric.utilities.imports import _LIGHTNING_XPU_AVAILABLE
+from lightning.fabric.utilities.imports import _lightning_xpu_available

 _log = logging.getLogger(__name__)

 _CLICK_AVAILABLE = RequirementCache("click")

 _SUPPORTED_ACCELERATORS = ["cpu", "gpu", "cuda", "mps", "tpu"]
-if _LIGHTNING_XPU_AVAILABLE:
+if _lightning_xpu_available:
     _SUPPORTED_ACCELERATORS.append("xpu")

src/lightning/fabric/connector.py
Lines changed: 4 additions & 4 deletions

@@ -64,7 +64,7 @@
 from lightning.fabric.strategies.fsdp import _FSDP_ALIASES, FSDPStrategy
 from lightning.fabric.utilities import rank_zero_info, rank_zero_warn
 from lightning.fabric.utilities.device_parser import _determine_root_gpu_device
-from lightning.fabric.utilities.imports import _IS_INTERACTIVE, _LIGHTNING_XPU_AVAILABLE
+from lightning.fabric.utilities.imports import _IS_INTERACTIVE, _lightning_xpu_available

 _PLUGIN = Union[Precision, ClusterEnvironment, CheckpointIO]
 _PLUGIN_INPUT = Union[_PLUGIN, str]
@@ -323,7 +323,7 @@ def _choose_auto_accelerator(self) -> str:
             return "mps"
         if CUDAAccelerator.is_available():
             return "cuda"
-        if _LIGHTNING_XPU_AVAILABLE:
+        if _lightning_xpu_available:
             from lightning_xpu.fabric import XPUAccelerator

             if XPUAccelerator.is_available():
@@ -337,7 +337,7 @@ def _choose_gpu_accelerator_backend() -> str:
             return "mps"
         if CUDAAccelerator.is_available():
             return "cuda"
-        if _LIGHTNING_XPU_AVAILABLE:
+        if _lightning_xpu_available:
             from lightning_xpu.fabric import XPUAccelerator

             if XPUAccelerator.is_available():
@@ -399,7 +399,7 @@ def _choose_strategy(self) -> Union[Strategy, str]:
         if len(self._parallel_devices) <= 1:
             supported_accelerators = [CUDAAccelerator, MPSAccelerator]
             supported_accelerators_str = ["cuda", "gpu", "mps"]
-            if _LIGHTNING_XPU_AVAILABLE:
+            if _lightning_xpu_available:
                 from lightning_xpu.fabric import XPUAccelerator

                 supported_accelerators.append(XPUAccelerator)
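These connector hunks keep the existing fallthrough order and only consult XPU last, and only when the optional lightning_xpu integration is present. A simplified, self-contained sketch of that ordering (hypothetical function and parameter names; the real methods also consider TPU and interactive environments):

def choose_gpu_backend(mps_ok: bool, cuda_ok: bool, xpu_integration: bool, xpu_ok: bool) -> str:
    # Mirrors the order in the diff: MPS first, then CUDA, then XPU as a final
    # candidate guarded by the optional-integration check.
    if mps_ok:
        return "mps"
    if cuda_ok:
        return "cuda"
    if xpu_integration and xpu_ok:
        return "xpu"
    raise RuntimeError("No supported GPU backend found!")


print(choose_gpu_backend(mps_ok=False, cuda_ok=False, xpu_integration=True, xpu_ok=True))  # xpu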

src/lightning/fabric/strategies/deepspeed.py
Lines changed: 3 additions & 3 deletions

@@ -33,12 +33,12 @@
 from lightning.fabric.strategies.registry import _StrategyRegistry
 from lightning.fabric.strategies.strategy import _Sharded
 from lightning.fabric.utilities.distributed import log
-from lightning.fabric.utilities.imports import _LIGHTNING_XPU_AVAILABLE
+from lightning.fabric.utilities.imports import _lightning_xpu_available
 from lightning.fabric.utilities.rank_zero import rank_zero_info, rank_zero_warn
 from lightning.fabric.utilities.seed import reset_seed
 from lightning.fabric.utilities.types import _PATH

-if _LIGHTNING_XPU_AVAILABLE:
+if _lightning_xpu_available:
     from lightning_xpu.fabric import XPUAccelerator

 _DEEPSPEED_AVAILABLE = RequirementCache("deepspeed")
@@ -495,7 +495,7 @@ def load_checkpoint(

         torch.cuda.empty_cache()
         with suppress(AttributeError):
-            if _LIGHTNING_XPU_AVAILABLE:
+            if _lightning_xpu_available:
                 XPUAccelerator.teardown()

         _, client_state = engine.load_checkpoint(

src/lightning/fabric/strategies/launchers/multiprocessing.py
Lines changed: 2 additions & 2 deletions

@@ -22,13 +22,13 @@

 from lightning.fabric.strategies.launchers.launcher import _Launcher
 from lightning.fabric.utilities.apply_func import move_data_to_device
-from lightning.fabric.utilities.imports import _IS_INTERACTIVE, _LIGHTNING_XPU_AVAILABLE
+from lightning.fabric.utilities.imports import _IS_INTERACTIVE, _lightning_xpu_available
 from lightning.fabric.utilities.seed import _collect_rng_states, _set_rng_states

 if TYPE_CHECKING:
     from lightning.fabric.strategies import ParallelStrategy

-if _LIGHTNING_XPU_AVAILABLE:
+if _lightning_xpu_available:
     from lightning_xpu.fabric import XPUAccelerator

src/lightning/fabric/utilities/device_parser.py
Lines changed: 2 additions & 2 deletions

@@ -16,7 +16,7 @@
 import lightning.fabric.accelerators as accelerators  # avoid circular dependency
 from lightning.fabric.plugins.environments.torchelastic import TorchElasticEnvironment
 from lightning.fabric.utilities.exceptions import MisconfigurationException
-from lightning.fabric.utilities.imports import _LIGHTNING_XPU_AVAILABLE
+from lightning.fabric.utilities.imports import _lightning_xpu_available
 from lightning.fabric.utilities.types import _DEVICE


@@ -177,7 +177,7 @@ def _get_all_available_gpus(
     cuda_gpus = accelerators.cuda._get_all_visible_cuda_devices() if include_cuda else []
     mps_gpus = accelerators.mps._get_all_available_mps_gpus() if include_mps else []
     xpu_gpus = []
-    if _LIGHTNING_XPU_AVAILABLE:
+    if _lightning_xpu_available:
         import lightning_xpu.fabric as accelerator_xpu

         xpu_gpus += accelerator_xpu._get_all_visible_xpu_devices() if include_xpu else []

src/lightning/fabric/utilities/imports.py
Lines changed: 6 additions & 1 deletion

@@ -34,4 +34,9 @@
 _PYTHON_GREATER_EQUAL_3_8_0 = (sys.version_info.major, sys.version_info.minor) >= (3, 8)
 _PYTHON_GREATER_EQUAL_3_10_0 = (sys.version_info.major, sys.version_info.minor) >= (3, 10)

-_LIGHTNING_XPU_AVAILABLE = RequirementCache("lightning-xpu")
+
+@functools.lru_cache(maxsize=1)
+def _lightning_xpu_available() -> bool:
+    # This is defined as a function instead of a constant to avoid circular imports, because `lightning_xpu`
+    # also imports Lightning
+    return bool(RequirementCache("lightning-xpu")) and _try_import_module("lightning_xpu")
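Because the helper is wrapped in functools.lru_cache(maxsize=1), the requirement probe runs at most once per process and later calls return the cached boolean. A tiny, self-contained illustration of that behavior (a hypothetical probe function, not Lightning code):

import functools

@functools.lru_cache(maxsize=1)
def _probe() -> bool:
    print("probing environment...")  # visible only on the first call
    return True

_probe()  # prints "probing environment..." and returns True
_probe()  # returns the cached True without re-running the body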

src/lightning/fabric/utilities/seed.py
Lines changed: 2 additions & 2 deletions

@@ -14,9 +14,9 @@

 max_seed_value = np.iinfo(np.uint32).max
 min_seed_value = np.iinfo(np.uint32).min
-from lightning.fabric.utilities.imports import _LIGHTNING_XPU_AVAILABLE
+from lightning.fabric.utilities.imports import _lightning_xpu_available

-if _LIGHTNING_XPU_AVAILABLE:
+if _lightning_xpu_available:
     from lightning_xpu.fabric import XPUAccelerator

src/lightning/pytorch/strategies/launchers/multiprocessing.py
Lines changed: 2 additions & 2 deletions

@@ -34,10 +34,10 @@
 from lightning.pytorch.strategies.launchers.launcher import _Launcher
 from lightning.pytorch.trainer.connectors.signal_connector import _SIGNUM
 from lightning.pytorch.trainer.states import TrainerFn, TrainerState
-from lightning.pytorch.utilities.imports import _LIGHTNING_XPU_AVAILABLE
+from lightning.pytorch.utilities.imports import _lightning_xpu_available
 from lightning.pytorch.utilities.rank_zero import rank_zero_debug

-if _LIGHTNING_XPU_AVAILABLE:
+if _lightning_xpu_available:
     from lightning_xpu.pytorch import XPUAccelerator

 log = logging.getLogger(__name__)

src/lightning/pytorch/trainer/connectors/accelerator_connector.py
Lines changed: 5 additions & 5 deletions

@@ -66,7 +66,7 @@
     _LIGHTNING_COLOSSALAI_AVAILABLE,
     _lightning_graphcore_available,
     _lightning_habana_available,
-    _LIGHTNING_XPU_AVAILABLE,
+    _lightning_xpu_available,
 )
 from lightning.pytorch.utilities.rank_zero import rank_zero_info, rank_zero_warn

@@ -351,7 +351,7 @@ def _choose_auto_accelerator(self) -> str:

         if HPUAccelerator.is_available():
             return "hpu"
-        if _LIGHTNING_XPU_AVAILABLE:
+        if _lightning_xpu_available:
             from lightning_xpu.pytorch import XPUAccelerator

             if XPUAccelerator.is_available():
@@ -368,7 +368,7 @@ def _choose_gpu_accelerator_backend() -> str:
             return "mps"
         if CUDAAccelerator.is_available():
             return "cuda"
-        if _LIGHTNING_XPU_AVAILABLE:
+        if _lightning_xpu_available:
             from lightning_xpu.pytorch import XPUAccelerator

             if XPUAccelerator.is_available():
@@ -448,7 +448,7 @@ def _choose_strategy(self) -> Union[Strategy, str]:
             from lightning_habana import SingleHPUStrategy

             return SingleHPUStrategy(device=torch.device("hpu"))
-        if self._accelerator_flag == "xpu" and not _LIGHTNING_XPU_AVAILABLE:
+        if self._accelerator_flag == "xpu" and not _lightning_xpu_available:
             raise ImportError(
                 "You have asked for XPU but you miss install related integration."
                 " Please run `pip install lightning-xpu` or see for further instructions"
@@ -722,7 +722,7 @@ def _register_external_accelerators_and_strategies() -> None:
     if "hpu_single" not in StrategyRegistry:
         SingleHPUStrategy.register_strategies(StrategyRegistry)

-    if _LIGHTNING_XPU_AVAILABLE:
+    if _lightning_xpu_available:
        from lightning_xpu.pytorch import XPUAccelerator

        # TODO: Prevent registering multiple times
