16
16
min_seed_value = np .iinfo (np .uint32 ).min
17
17
from lightning .fabric .utilities .imports import _lightning_xpu_available
18
18
19
- if _lightning_xpu_available ():
20
- from lightning_xpu .fabric import XPUAccelerator
21
-
22
19
23
20
def seed_everything (seed : Optional [int ] = None , workers : bool = False ) -> int :
24
21
r"""Function that sets seed for pseudo-random number generators in: pytorch, numpy, python.random In addition,
@@ -61,8 +58,8 @@ def seed_everything(seed: Optional[int] = None, workers: bool = False) -> int:
61
58
np .random .seed (seed )
62
59
torch .manual_seed (seed )
63
60
torch .cuda .manual_seed_all (seed )
64
- if _lightning_xpu_available () and XPUAccelerator .is_available ():
65
- XPUAccelerator .manual_seed_all (seed )
61
+ if _lightning_xpu_available () and torch . xpu .is_available ():
62
+ torch . xpu .manual_seed_all (seed )
66
63
67
64
os .environ ["PL_SEED_WORKERS" ] = f"{ int (workers )} "
68
65
@@ -121,8 +118,8 @@ def _collect_rng_states(include_cuda: bool = True, include_xpu: bool = True) ->
121
118
}
122
119
if include_cuda :
123
120
states ["torch.cuda" ] = torch .cuda .get_rng_state_all ()
124
- if include_xpu and _lightning_xpu_available () and XPUAccelerator .is_available ():
125
- states ["torch.xpu" ] = XPUAccelerator . _collect_rng_states ()
121
+ if include_xpu and _lightning_xpu_available () and torch . xpu .is_available ():
122
+ states ["torch.xpu" ] = torch . xpu . get_rng_state_all ()
126
123
return states
127
124
128
125
@@ -133,8 +130,8 @@ def _set_rng_states(rng_state_dict: Dict[str, Any]) -> None:
133
130
# torch.cuda rng_state is only included since v1.8.
134
131
if "torch.cuda" in rng_state_dict :
135
132
torch .cuda .set_rng_state_all (rng_state_dict ["torch.cuda" ])
136
- if "torch.xpu" in rng_state_dict and _lightning_xpu_available () and XPUAccelerator .is_available ():
137
- XPUAccelerator . _set_rng_states (rng_state_dict )
133
+ if "torch.xpu" in rng_state_dict and _lightning_xpu_available () and torch . xpu .is_available ():
134
+ torch . xpu . set_rng_states_all (rng_state_dict [ "torch.xpu" ] )
138
135
np .random .set_state (rng_state_dict ["numpy" ])
139
136
version , state , gauss = rng_state_dict ["python" ]
140
137
python_set_rng_state ((version , tuple (state ), gauss ))
0 commit comments