 max_seed_value = np.iinfo(np.uint32).max
 min_seed_value = np.iinfo(np.uint32).min
 
+from lightning.fabric.utilities.imports import _LIGHTNING_XPU_AVAILABLE
+
+if _LIGHTNING_XPU_AVAILABLE:
+    from lightning_xpu.fabric import XPUAccelerator
 
 
 def seed_everything(seed: Optional[int] = None, workers: bool = False) -> int:
     """Function that sets seed for pseudo-random number generators in: pytorch, numpy, python.random In addition,
@@ -56,6 +59,8 @@ def seed_everything(seed: Optional[int] = None, workers: bool = False) -> int:
     np.random.seed(seed)
     torch.manual_seed(seed)
     torch.cuda.manual_seed_all(seed)
+    if XPUAccelerator.is_available():
+        XPUAccelerator.manual_seed_all(seed)
 
     os.environ["PL_SEED_WORKERS"] = f"{int(workers)}"
 
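With the two added lines, `seed_everything` seeds the XPU generators in the same pass that already covers Python, NumPy, CPU torch, and CUDA, and the availability guard keeps CPU- or CUDA-only environments unaffected. A minimal usage sketch, assuming `lightning_xpu` is installed and provides the `XPUAccelerator` API referenced in the hunk above:

    from lightning.fabric.utilities.seed import seed_everything

    # Seeds python.random, numpy, torch and torch.cuda, and - when an XPU is
    # available - the XPU generators as well; workers=True additionally enables
    # per-worker seeding of dataloader workers.
    seed_everything(42, workers=True)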
@@ -104,7 +109,7 @@ def pl_worker_init_function(worker_id: int, rank: Optional[int] = None) -> None:
     random.seed(stdlib_seed)
 
 
-def _collect_rng_states(include_cuda: bool = True) -> Dict[str, Any]:
+def _collect_rng_states(include_cuda: bool = True, include_xpu: bool = True) -> Dict[str, Any]:
     """Collect the global random state of :mod:`torch`, :mod:`torch.cuda`, :mod:`numpy` and Python."""
     states = {
         "torch": torch.get_rng_state(),
@@ -113,6 +118,9 @@ def _collect_rng_states(include_cuda: bool = True) -> Dict[str, Any]:
     }
     if include_cuda:
         states["torch.cuda"] = torch.cuda.get_rng_state_all()
+    if include_xpu:
+        if XPUAccelerator.is_available():
+            states["torch.xpu"] = XPUAccelerator._collect_rng_states()
     return states
 
 
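`_collect_rng_states` gains an `include_xpu` flag and, when an XPU is available, stores its generator state under a `"torch.xpu"` key next to the existing `"torch.cuda"` entry. A small sketch of inspecting the captured dictionary (CUDA capture disabled so it runs on any machine):

    from lightning.fabric.utilities.seed import _collect_rng_states

    states = _collect_rng_states(include_cuda=False, include_xpu=True)
    # Always present: "torch", "numpy", "python". "torch.cuda" is added when
    # include_cuda=True; "torch.xpu" when include_xpu=True and an XPU is available.
    print(sorted(states))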
@@ -123,6 +131,9 @@ def _set_rng_states(rng_state_dict: Dict[str, Any]) -> None:
     # torch.cuda rng_state is only included since v1.8.
     if "torch.cuda" in rng_state_dict:
         torch.cuda.set_rng_state_all(rng_state_dict["torch.cuda"])
+    if "torch.xpu" in rng_state_dict:
+        if XPUAccelerator.is_available():
+            XPUAccelerator._set_rng_states(rng_state_dict)
     np.random.set_state(rng_state_dict["numpy"])
     version, state, gauss = rng_state_dict["python"]
     python_set_rng_state((version, tuple(state), gauss))
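Together, `_collect_rng_states` and `_set_rng_states` give a save/restore round trip for global RNG state, now covering XPU generators as well. A self-contained sketch with CUDA and XPU capture disabled so it runs anywhere:

    import numpy as np
    import torch

    from lightning.fabric.utilities.seed import _collect_rng_states, _set_rng_states

    # Snapshot the global torch / numpy / python RNG state.
    snapshot = _collect_rng_states(include_cuda=False, include_xpu=False)
    before = (torch.rand(1).item(), np.random.rand())

    # Restoring the snapshot replays exactly the same draws.
    _set_rng_states(snapshot)
    after = (torch.rand(1).item(), np.random.rand())
    assert before == after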