16
16
min_seed_value = np .iinfo (np .uint32 ).min
17
17
from lightning .fabric .utilities .imports import _lightning_xpu_available
18
18
19
- if _lightning_xpu_available ():
20
- from lightning_xpu .fabric import XPUAccelerator
21
-
22
19
23
20
def seed_everything (seed : Optional [int ] = None , workers : bool = False ) -> int :
24
21
r"""Function that sets seed for pseudo-random number generators in: pytorch, numpy, python.random In addition,
@@ -61,8 +58,8 @@ def seed_everything(seed: Optional[int] = None, workers: bool = False) -> int:
61
58
np .random .seed (seed )
62
59
torch .manual_seed (seed )
63
60
torch .cuda .manual_seed_all (seed )
64
- if _lightning_xpu_available () and XPUAccelerator .is_available ():
65
- XPUAccelerator .manual_seed_all (seed )
61
+ if _lightning_xpu_available () and torch . xpu .is_available ():
62
+ torch . xpu .manual_seed_all (seed )
66
63
67
64
os .environ ["PL_SEED_WORKERS" ] = f"{ int (workers )} "
68
65
@@ -122,8 +119,8 @@ def _collect_rng_states(include_cuda: bool = True, include_xpu: bool = True) ->
122
119
}
123
120
if include_cuda :
124
121
states ["torch.cuda" ] = torch .cuda .get_rng_state_all ()
125
- if include_xpu and _lightning_xpu_available () and XPUAccelerator .is_available ():
126
- states ["torch.xpu" ] = XPUAccelerator . _collect_rng_states ()
122
+ if include_xpu and _lightning_xpu_available () and torch . xpu .is_available ():
123
+ states ["torch.xpu" ] = torch . xpu . get_rng_state_all ()
127
124
return states
128
125
129
126
@@ -134,8 +131,8 @@ def _set_rng_states(rng_state_dict: Dict[str, Any]) -> None:
134
131
# torch.cuda rng_state is only included since v1.8.
135
132
if "torch.cuda" in rng_state_dict :
136
133
torch .cuda .set_rng_state_all (rng_state_dict ["torch.cuda" ])
137
- if "torch.xpu" in rng_state_dict and _lightning_xpu_available () and XPUAccelerator .is_available ():
138
- XPUAccelerator . _set_rng_states (rng_state_dict )
134
+ if "torch.xpu" in rng_state_dict and _lightning_xpu_available () and torch . xpu .is_available ():
135
+ torch . xpu . set_rng_states_all (rng_state_dict [ "torch.xpu" ] )
139
136
np .random .set_state (rng_state_dict ["numpy" ])
140
137
version , state , gauss = rng_state_dict ["python" ]
141
138
python_set_rng_state ((version , tuple (state ), gauss ))
0 commit comments