Commit da55ed0

fix circular import issue

1 parent 39b7d83

2 files changed (+6, -12 lines)

src/lightning/fabric/accelerators/__init__.py
Lines changed: 0 additions & 3 deletions

@@ -25,9 +25,6 @@
 
 from lightning.fabric.utilities.imports import _lightning_xpu_available
 
-_ACCELERATORS_BASE_MODULE = "lightning.fabric.accelerators"
-ACCELERATOR_REGISTRY = _AcceleratorRegistry()
-call_register_accelerators(ACCELERATOR_REGISTRY, _ACCELERATORS_BASE_MODULE)
 if _lightning_xpu_available() and "xpu" not in ACCELERATOR_REGISTRY:
     from lightning_xpu.fabric import XPUAccelerator
 
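For context, a minimal sketch of how eager, import-time registration can create the kind of cycle this deletion breaks: the package builds and populates a registry while it is still being imported, and a registered plugin imports the package back. All names below (get_registry, _populate, plugin_pkg) are illustrative, not Lightning's actual code; the commit itself only drops the eager registration from this __init__.py, and presumably ACCELERATOR_REGISTRY is created elsewhere.

# Sketch (illustrative names): an import-time registry can cause a
# circular import when a registered plugin imports this package back.
#
#   pkg/__init__.py          -> builds registry, imports plugin_pkg
#   plugin_pkg/__init__.py   -> "from pkg import ..."  (pkg half-initialized)
#
# Deferring registry construction to first use breaks the cycle:

_registry = None


def get_registry() -> dict:
    """Build and populate the registry lazily, on first access."""
    global _registry
    if _registry is None:
        _registry = {}  # stand-in for _AcceleratorRegistry()
        _populate(_registry)  # safe: this package is fully imported by now
    return _registry


def _populate(registry: dict) -> None:
    registry["cpu"] = object()  # placeholder registration

The key property is that nothing outside this module is imported until get_registry() is first called, by which point the package's own import has already completed.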

src/lightning/fabric/utilities/seed.py
Lines changed: 6 additions & 9 deletions

@@ -16,9 +16,6 @@
 min_seed_value = np.iinfo(np.uint32).min
 from lightning.fabric.utilities.imports import _lightning_xpu_available
 
-if _lightning_xpu_available():
-    from lightning_xpu.fabric import XPUAccelerator
-
 
 def seed_everything(seed: Optional[int] = None, workers: bool = False) -> int:
     r"""Function that sets seed for pseudo-random number generators in: pytorch, numpy, python.random In addition,
@@ -61,8 +58,8 @@ def seed_everything(seed: Optional[int] = None, workers: bool = False) -> int:
     np.random.seed(seed)
     torch.manual_seed(seed)
     torch.cuda.manual_seed_all(seed)
-    if _lightning_xpu_available() and XPUAccelerator.is_available():
-        XPUAccelerator.manual_seed_all(seed)
+    if _lightning_xpu_available() and torch.xpu.is_available():
+        torch.xpu.manual_seed_all(seed)
 
     os.environ["PL_SEED_WORKERS"] = f"{int(workers)}"
 
@@ -122,8 +119,8 @@ def _collect_rng_states(include_cuda: bool = True, include_xpu: bool = True) ->
     }
     if include_cuda:
         states["torch.cuda"] = torch.cuda.get_rng_state_all()
-    if include_xpu and _lightning_xpu_available() and XPUAccelerator.is_available():
-        states["torch.xpu"] = XPUAccelerator._collect_rng_states()
+    if include_xpu and _lightning_xpu_available() and torch.xpu.is_available():
+        states["torch.xpu"] = torch.xpu.get_rng_state_all()
     return states
 
 
@@ -134,8 +131,8 @@ def _set_rng_states(rng_state_dict: Dict[str, Any]) -> None:
     # torch.cuda rng_state is only included since v1.8.
     if "torch.cuda" in rng_state_dict:
         torch.cuda.set_rng_state_all(rng_state_dict["torch.cuda"])
-    if "torch.xpu" in rng_state_dict and _lightning_xpu_available() and XPUAccelerator.is_available():
-        XPUAccelerator._set_rng_states(rng_state_dict)
+    if "torch.xpu" in rng_state_dict and _lightning_xpu_available() and torch.xpu.is_available():
+        torch.xpu.set_rng_states_all(rng_state_dict["torch.xpu"])
     np.random.set_state(rng_state_dict["numpy"])
     version, state, gauss = rng_state_dict["python"]
     python_set_rng_state((version, tuple(state), gauss))
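These changes replace the plugin class with direct calls into the torch backend module, so lightning_xpu no longer has to be importable when seed.py loads. A minimal sketch of that device-agnostic seeding pattern, assuming a PyTorch build where torch.xpu may or may not be present (the getattr guard keeps it safe on builds without XPU support; this is an illustration, not the module's actual code):

import random

import numpy as np
import torch


def seed_all(seed: int) -> None:
    """Seed python, numpy, and every available torch backend."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    xpu = getattr(torch, "xpu", None)  # absent on non-XPU builds
    if xpu is not None and xpu.is_available():
        xpu.manual_seed_all(seed)


seed_all(42)

Going through torch.xpu directly also means the availability check no longer depends on the plugin package being imported at module scope, which is what broke the cycle.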
