
Commit e03a1ec

1. remove xpu.py from _graveyard
2. correct _lightning_xpu_available() usage
1 parent bf1bd36 commit e03a1ec
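
Why change 2 matters: the old code tested _lightning_xpu_available itself instead of its return value, and a bare function object is always truthy, so the XPU branches ran even on machines without lightning-xpu installed. A minimal standalone sketch of the bug (the stub body below is illustrative only, not Lightning's real helper):

def _lightning_xpu_available() -> bool:
    return False  # pretend the lightning-xpu package is NOT installed

if _lightning_xpu_available:  # bug: the function object is always truthy
    print("XPU branch taken even though lightning-xpu is missing")

if _lightning_xpu_available():  # fix: call the helper and branch on its result
    print("XPU branch taken only when lightning-xpu is importable")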

12 files changed (+34 lines, −57 lines)


src/lightning/fabric/accelerators/__init__.py

Lines changed: 1 addition & 1 deletion

@@ -21,7 +21,7 @@
 _ACCELERATORS_BASE_MODULE = "lightning.fabric.accelerators"
 ACCELERATOR_REGISTRY = _AcceleratorRegistry()
 call_register_accelerators(ACCELERATOR_REGISTRY, _ACCELERATORS_BASE_MODULE)
-if _lightning_xpu_available and "xpu" not in ACCELERATOR_REGISTRY:
+if _lightning_xpu_available() and "xpu" not in ACCELERATOR_REGISTRY:
     from lightning_xpu.fabric import XPUAccelerator

     XPUAccelerator.register_accelerators(ACCELERATOR_REGISTRY)

src/lightning/fabric/cli.py

Lines changed: 1 addition & 1 deletion

@@ -31,7 +31,7 @@
 _CLICK_AVAILABLE = RequirementCache("click")

 _SUPPORTED_ACCELERATORS = ["cpu", "gpu", "cuda", "mps", "tpu"]
-if _lightning_xpu_available:
+if _lightning_xpu_available():
     _SUPPORTED_ACCELERATORS.append("xpu")

src/lightning/fabric/connector.py

Lines changed: 3 additions & 3 deletions

@@ -323,7 +323,7 @@ def _choose_auto_accelerator(self) -> str:
             return "mps"
         if CUDAAccelerator.is_available():
             return "cuda"
-        if _lightning_xpu_available:
+        if _lightning_xpu_available():
             from lightning_xpu.fabric import XPUAccelerator

             if XPUAccelerator.is_available():
@@ -337,7 +337,7 @@ def _choose_gpu_accelerator_backend() -> str:
             return "mps"
         if CUDAAccelerator.is_available():
             return "cuda"
-        if _lightning_xpu_available:
+        if _lightning_xpu_available():
             from lightning_xpu.fabric import XPUAccelerator

             if XPUAccelerator.is_available():
@@ -399,7 +399,7 @@ def _choose_strategy(self) -> Union[Strategy, str]:
         if len(self._parallel_devices) <= 1:
             supported_accelerators = [CUDAAccelerator, MPSAccelerator]
             supported_accelerators_str = ["cuda", "gpu", "mps"]
-            if _lightning_xpu_available:
+            if _lightning_xpu_available():
                 from lightning_xpu.fabric import XPUAccelerator

                 supported_accelerators.append(XPUAccelerator)

src/lightning/fabric/strategies/deepspeed.py

Lines changed: 8 additions & 3 deletions

@@ -38,7 +38,7 @@
 from lightning.fabric.utilities.seed import reset_seed
 from lightning.fabric.utilities.types import _PATH

-if _lightning_xpu_available:
+if _lightning_xpu_available():
     from lightning_xpu.fabric import XPUAccelerator

 _DEEPSPEED_AVAILABLE = RequirementCache("deepspeed")
@@ -495,7 +495,7 @@ def load_checkpoint(

         torch.cuda.empty_cache()
         with suppress(AttributeError):
-            if _lightning_xpu_available:
+            if _lightning_xpu_available():
                 XPUAccelerator.teardown()

         _, client_state = engine.load_checkpoint(
@@ -596,7 +596,12 @@ def _initialize_engine(
         return deepspeed_engine, deepspeed_optimizer

     def _setup_distributed(self) -> None:
-        if not isinstance(self.accelerator, CUDAAccelerator) and not isinstance(self.accelerator, XPUAccelerator):
+        ds_support = False
+        if isinstance(self.accelerator, CUDAAccelerator):
+            ds_support = True
+        if _lightning_xpu_available() and isinstance(self.accelerator, XPUAccelerator):
+            ds_support = True
+        if not ds_support:
             raise RuntimeError(
                 "The DeepSpeed strategy is only supported on CUDA/Intel(R) GPUs but"
                 " `{self.accelerator.__class__.__name__}` is used."

src/lightning/fabric/strategies/launchers/multiprocessing.py

Lines changed: 2 additions & 2 deletions

@@ -28,7 +28,7 @@
 if TYPE_CHECKING:
     from lightning.fabric.strategies import ParallelStrategy

-if _lightning_xpu_available:
+if _lightning_xpu_available():
     from lightning_xpu.fabric import XPUAccelerator

@@ -90,7 +90,7 @@ def launch(self, function: Callable, *args: Any, **kwargs: Any) -> Any:
         """
         if self._start_method in ("fork", "forkserver"):
             _check_bad_cuda_fork()
-            if XPUAccelerator.is_available():
+            if _lightning_xpu_available() and XPUAccelerator.is_available():
                 _check_bad_xpu_fork()

         # The default cluster environment in Lightning chooses a random free port number
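
Note on the added guard: XPUAccelerator is only imported when _lightning_xpu_available() returns True, so any later reference to that name must repeat the check; the added "and" short-circuits before the possibly unimported name is touched. A minimal sketch with a hypothetical helper name (xpu_available stands in for _lightning_xpu_available):

def xpu_available() -> bool:
    return False  # assume lightning-xpu is not installed

if xpu_available():
    from lightning_xpu.fabric import XPUAccelerator  # never runs in this case

# if XPUAccelerator.is_available(): ...                 # would raise NameError
if xpu_available() and XPUAccelerator.is_available():   # safe: short-circuits
    pass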

src/lightning/fabric/utilities/device_parser.py

Lines changed: 1 addition & 1 deletion

@@ -177,7 +177,7 @@ def _get_all_available_gpus(
     cuda_gpus = accelerators.cuda._get_all_visible_cuda_devices() if include_cuda else []
     mps_gpus = accelerators.mps._get_all_available_mps_gpus() if include_mps else []
     xpu_gpus = []
-    if _lightning_xpu_available:
+    if _lightning_xpu_available():
         import lightning_xpu.fabric as accelerator_xpu

         xpu_gpus += accelerator_xpu._get_all_visible_xpu_devices() if include_xpu else []

src/lightning/fabric/utilities/seed.py

Lines changed: 4 additions & 4 deletions

@@ -16,7 +16,7 @@
 min_seed_value = np.iinfo(np.uint32).min
 from lightning.fabric.utilities.imports import _lightning_xpu_available

-if _lightning_xpu_available:
+if _lightning_xpu_available():
     from lightning_xpu.fabric import XPUAccelerator

@@ -61,7 +61,7 @@ def seed_everything(seed: Optional[int] = None, workers: bool = False) -> int:
     np.random.seed(seed)
     torch.manual_seed(seed)
     torch.cuda.manual_seed_all(seed)
-    if XPUAccelerator.is_available():
+    if _lightning_xpu_available() and XPUAccelerator.is_available():
         XPUAccelerator.manual_seed_all(seed)

     os.environ["PL_SEED_WORKERS"] = f"{int(workers)}"
@@ -121,7 +121,7 @@ def _collect_rng_states(include_cuda: bool = True, include_xpu: bool = True) ->
     }
     if include_cuda:
         states["torch.cuda"] = torch.cuda.get_rng_state_all()
-    if include_xpu and XPUAccelerator.is_available():
+    if include_xpu and _lightning_xpu_available() and XPUAccelerator.is_available():
         states["torch.xpu"] = XPUAccelerator._collect_rng_states()
     return states

@@ -133,7 +133,7 @@ def _set_rng_states(rng_state_dict: Dict[str, Any]) -> None:
     # torch.cuda rng_state is only included since v1.8.
     if "torch.cuda" in rng_state_dict:
         torch.cuda.set_rng_state_all(rng_state_dict["torch.cuda"])
-    if "torch.xpu" in rng_state_dict and XPUAccelerator.is_available():
+    if "torch.xpu" in rng_state_dict and _lightning_xpu_available() and XPUAccelerator.is_available():
         XPUAccelerator._set_rng_states(rng_state_dict)
     np.random.set_state(rng_state_dict["numpy"])
     version, state, gauss = rng_state_dict["python"]

src/lightning/pytorch/_graveyard/xpu.py

Lines changed: 0 additions & 34 deletions
This file was deleted.

src/lightning/pytorch/strategies/launchers/multiprocessing.py

Lines changed: 1 addition & 1 deletion

@@ -37,7 +37,7 @@
 from lightning.pytorch.utilities.imports import _lightning_xpu_available
 from lightning.pytorch.utilities.rank_zero import rank_zero_debug

-if _lightning_xpu_available:
+if _lightning_xpu_available():
     from lightning_xpu.pytorch import XPUAccelerator

 log = logging.getLogger(__name__)

src/lightning/pytorch/trainer/connectors/accelerator_connector.py

Lines changed: 4 additions & 4 deletions

@@ -351,7 +351,7 @@ def _choose_auto_accelerator(self) -> str:

         if HPUAccelerator.is_available():
             return "hpu"
-        if _lightning_xpu_available:
+        if _lightning_xpu_available():
             from lightning_xpu.pytorch import XPUAccelerator

             if XPUAccelerator.is_available():
@@ -368,7 +368,7 @@ def _choose_gpu_accelerator_backend() -> str:
             return "mps"
         if CUDAAccelerator.is_available():
             return "cuda"
-        if _lightning_xpu_available:
+        if _lightning_xpu_available():
             from lightning_xpu.pytorch import XPUAccelerator

             if XPUAccelerator.is_available():
@@ -448,7 +448,7 @@ def _choose_strategy(self) -> Union[Strategy, str]:
             from lightning_habana import SingleHPUStrategy

             return SingleHPUStrategy(device=torch.device("hpu"))
-        if self._accelerator_flag == "xpu" and not _lightning_xpu_available:
+        if self._accelerator_flag == "xpu" and not _lightning_xpu_available():
             raise ImportError(
                 "You have asked for XPU but you miss install related integration."
                 " Please run `pip install lightning-xpu` or see for further instructions"
@@ -722,7 +722,7 @@ def _register_external_accelerators_and_strategies() -> None:
     if "hpu_single" not in StrategyRegistry:
         SingleHPUStrategy.register_strategies(StrategyRegistry)

-    if _lightning_xpu_available:
+    if _lightning_xpu_available():
         from lightning_xpu.pytorch import XPUAccelerator

         # TODO: Prevent registering multiple times
