Commit d2abb7f

fix circle import issue

Parent: 2c62b3f

2 files changed (+6, -12 lines)


src/lightning/fabric/accelerators/__init__.py (0 additions, 3 deletions)

@@ -25,9 +25,6 @@
 
 from lightning.fabric.utilities.imports import _lightning_xpu_available
 
-_ACCELERATORS_BASE_MODULE = "lightning.fabric.accelerators"
-ACCELERATOR_REGISTRY = _AcceleratorRegistry()
-call_register_accelerators(ACCELERATOR_REGISTRY, _ACCELERATORS_BASE_MODULE)
 if _lightning_xpu_available() and "xpu" not in ACCELERATOR_REGISTRY:
     from lightning_xpu.fabric import XPUAccelerator
 
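For context, a minimal sketch of the kind of cycle this deletion breaks: building ACCELERATOR_REGISTRY inside lightning.fabric.accelerators.__init__ can trigger an import of lightning_xpu.fabric, which imports back into the still-initializing package. The helper below is an illustrative assumption, not Lightning's actual code:

# Hypothetical shape of the cycle: pkg/__init__.py builds a registry at
# import time, and the registry build imports a plugin that does
# "from pkg import ..." while pkg is still half-initialized -> ImportError.
# The safe pattern is to defer the plugin import until it is needed:

def register_xpu(registry: dict) -> None:
    """Illustrative helper: import the plugin lazily, at call time."""
    try:
        from lightning_xpu.fabric import XPUAccelerator  # deferred import
    except ImportError:
        return  # plugin not installed; nothing to register
    registry.setdefault("xpu", XPUAccelerator)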

src/lightning/fabric/utilities/seed.py (6 additions, 9 deletions)

@@ -16,9 +16,6 @@
 min_seed_value = np.iinfo(np.uint32).min
 from lightning.fabric.utilities.imports import _lightning_xpu_available
 
-if _lightning_xpu_available():
-    from lightning_xpu.fabric import XPUAccelerator
-
 
 def seed_everything(seed: Optional[int] = None, workers: bool = False) -> int:
     r"""Function that sets seed for pseudo-random number generators in: pytorch, numpy, python.random In addition,
@@ -61,8 +58,8 @@ def seed_everything(seed: Optional[int] = None, workers: bool = False) -> int:
     np.random.seed(seed)
     torch.manual_seed(seed)
     torch.cuda.manual_seed_all(seed)
-    if _lightning_xpu_available() and XPUAccelerator.is_available():
-        XPUAccelerator.manual_seed_all(seed)
+    if _lightning_xpu_available() and torch.xpu.is_available():
+        torch.xpu.manual_seed_all(seed)
 
     os.environ["PL_SEED_WORKERS"] = f"{int(workers)}"
 
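Here the plugin-class call is replaced by the torch-native torch.xpu API, so seed.py no longer has to import XPUAccelerator at module load. A self-contained sketch of the same guarded seeding; hasattr(torch, "xpu") is an assumption standing in for _lightning_xpu_available(), since torch.xpu exists only on XPU-enabled builds:

import random

import numpy as np
import torch


def seed_all(seed: int) -> None:
    """Seed every RNG a training run touches, skipping absent backends."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # safe even without CUDA devices
    if hasattr(torch, "xpu") and torch.xpu.is_available():
        torch.xpu.manual_seed_all(seed)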

@@ -121,8 +118,8 @@ def _collect_rng_states(include_cuda: bool = True, include_xpu: bool = True) ->
     }
     if include_cuda:
         states["torch.cuda"] = torch.cuda.get_rng_state_all()
-    if include_xpu and _lightning_xpu_available() and XPUAccelerator.is_available():
-        states["torch.xpu"] = XPUAccelerator._collect_rng_states()
+    if include_xpu and _lightning_xpu_available() and torch.xpu.is_available():
+        states["torch.xpu"] = torch.xpu.get_rng_state_all()
     return states
 
 
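The same substitution on the capture side: per-device XPU states come from torch.xpu.get_rng_state_all, mirroring the CUDA call one line up. A standalone sketch under the same hasattr assumption:

from typing import Any, Dict
import random

import numpy as np
import torch


def collect_rng_states(include_cuda: bool = True, include_xpu: bool = True) -> Dict[str, Any]:
    """Snapshot the global, CUDA, XPU, NumPy, and Python RNG states."""
    states: Dict[str, Any] = {
        "torch": torch.get_rng_state(),
        "numpy": np.random.get_state(),
        "python": random.getstate(),
    }
    if include_cuda and torch.cuda.is_available():
        states["torch.cuda"] = torch.cuda.get_rng_state_all()
    if include_xpu and hasattr(torch, "xpu") and torch.xpu.is_available():
        states["torch.xpu"] = torch.xpu.get_rng_state_all()
    return states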

@@ -133,8 +130,8 @@ def _set_rng_states(rng_state_dict: Dict[str, Any]) -> None:
     # torch.cuda rng_state is only included since v1.8.
     if "torch.cuda" in rng_state_dict:
         torch.cuda.set_rng_state_all(rng_state_dict["torch.cuda"])
-    if "torch.xpu" in rng_state_dict and _lightning_xpu_available() and XPUAccelerator.is_available():
-        XPUAccelerator._set_rng_states(rng_state_dict)
+    if "torch.xpu" in rng_state_dict and _lightning_xpu_available() and torch.xpu.is_available():
+        torch.xpu.set_rng_states_all(rng_state_dict["torch.xpu"])
     np.random.set_state(rng_state_dict["numpy"])
     version, state, gauss = rng_state_dict["python"]
     python_set_rng_state((version, tuple(state), gauss))
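And the restore side, completing the round trip with _collect_rng_states. The sketch assumes the XPU setter mirrors torch.cuda.set_rng_state_all; as before, the availability guards are stand-ins for Lightning's own checks:

from typing import Any, Dict
import random

import numpy as np
import torch


def set_rng_states(rng_state_dict: Dict[str, Any]) -> None:
    """Restore the states captured by collect_rng_states above."""
    torch.set_rng_state(rng_state_dict["torch"])
    if "torch.cuda" in rng_state_dict and torch.cuda.is_available():
        torch.cuda.set_rng_state_all(rng_state_dict["torch.cuda"])
    if "torch.xpu" in rng_state_dict and hasattr(torch, "xpu") and torch.xpu.is_available():
        torch.xpu.set_rng_state_all(rng_state_dict["torch.xpu"])
    np.random.set_state(rng_state_dict["numpy"])
    version, state, gauss = rng_state_dict["python"]
    random.setstate((version, tuple(state), gauss))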
