19
19
if _LIGHTNING_XPU_AVAILABLE :
20
20
from lightning_xpu .fabric import XPUAccelerator
21
21
22
+
22
23
def seed_everything (seed : Optional [int ] = None , workers : bool = False ) -> int :
23
24
"""Function that sets seed for pseudo-random number generators in: pytorch, numpy, python.random In addition,
24
25
sets the following environment variables:
@@ -118,9 +119,8 @@ def _collect_rng_states(include_cuda: bool = True, include_xpu: bool = True) ->
118
119
}
119
120
if include_cuda :
120
121
states ["torch.cuda" ] = torch .cuda .get_rng_state_all ()
121
- if include_xpu :
122
- if XPUAccelerator .is_available ():
123
- states ["torch.xpu" ] = XPUAccelerator ._collect_rng_states ()
122
+ if include_xpu and XPUAccelerator .is_available ():
123
+ states ["torch.xpu" ] = XPUAccelerator ._collect_rng_states ()
124
124
return states
125
125
126
126
@@ -131,9 +131,8 @@ def _set_rng_states(rng_state_dict: Dict[str, Any]) -> None:
131
131
# torch.cuda rng_state is only included since v1.8.
132
132
if "torch.cuda" in rng_state_dict :
133
133
torch .cuda .set_rng_state_all (rng_state_dict ["torch.cuda" ])
134
- if "torch.xpu" in rng_state_dict :
135
- if XPUAccelerator .is_available ():
136
- XPUAccelerator ._set_rng_states (rng_state_dict )
134
+ if "torch.xpu" in rng_state_dict and XPUAccelerator .is_available ():
135
+ XPUAccelerator ._set_rng_states (rng_state_dict )
137
136
np .random .set_state (rng_state_dict ["numpy" ])
138
137
version , state , gauss = rng_state_dict ["python" ]
139
138
python_set_rng_state ((version , tuple (state ), gauss ))
0 commit comments