
Commit 05ab21a

abhilash1910 authored and jingxu10 committed

xpu seeding PR1

1 parent 25ed51d commit 05ab21a

File tree

1 file changed: +12, -1 lines changed
  • src/lightning/fabric/utilities/seed.py


src/lightning/fabric/utilities/seed.py

Lines changed: 12 additions & 1 deletion
@@ -14,7 +14,10 @@
 
 max_seed_value = np.iinfo(np.uint32).max
 min_seed_value = np.iinfo(np.uint32).min
+from lightning.fabric.utilities.imports import _LIGHTNING_XPU_AVAILABLE
 
+if _LIGHTNING_XPU_AVAILABLE:
+    from lightning_xpu.fabric import XPUAccelerator
 
 def seed_everything(seed: Optional[int] = None, workers: bool = False) -> int:
     """Function that sets seed for pseudo-random number generators in: pytorch, numpy, python.random In addition,
@@ -56,6 +59,8 @@ def seed_everything(seed: Optional[int] = None, workers: bool = False) -> int:
     np.random.seed(seed)
     torch.manual_seed(seed)
     torch.cuda.manual_seed_all(seed)
+    if XPUAccelerator.is_available():
+        XPUAccelerator.manual_seed_all(seed)
 
     os.environ["PL_SEED_WORKERS"] = f"{int(workers)}"
 
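With this hunk, one seed_everything call also seeds every visible Intel XPU device whenever lightning_xpu is installed and reports an available XPU. A small usage sketch; the seed value and workers flag are illustrative:

from lightning.fabric.utilities.seed import seed_everything

# Seeds python.random, numpy, torch and torch.cuda, and, after this commit,
# all XPU devices via XPUAccelerator.manual_seed_all when an XPU is available.
seed_everything(42, workers=True)
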
@@ -104,7 +109,7 @@ def pl_worker_init_function(worker_id: int, rank: Optional[int] = None) -> None:
     random.seed(stdlib_seed)
 
 
-def _collect_rng_states(include_cuda: bool = True) -> Dict[str, Any]:
+def _collect_rng_states(include_cuda: bool = True, include_xpu: bool = True) -> Dict[str, Any]:
     """Collect the global random state of :mod:`torch`, :mod:`torch.cuda`, :mod:`numpy` and Python."""
     states = {
         "torch": torch.get_rng_state(),
@@ -113,6 +118,9 @@ def _collect_rng_states(include_cuda: bool = True) -> Dict[str, Any]:
     }
     if include_cuda:
         states["torch.cuda"] = torch.cuda.get_rng_state_all()
+    if include_xpu:
+        if XPUAccelerator.is_available():
+            states["torch.xpu"] = XPUAccelerator._collect_rng_states()
     return states
 
 
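Together with the signature change in the previous hunk, _collect_rng_states can now capture XPU generator state alongside the torch, torch.cuda, numpy and python.random state. A round-trip sketch of how these private helpers are typically used (the counterpart _set_rng_states appears in the next hunk); treat it as an illustration, not a public API:

import torch
from lightning.fabric.utilities.seed import _collect_rng_states, _set_rng_states

# Snapshot every global generator, including torch.xpu when an XPU is present.
snapshot = _collect_rng_states(include_cuda=torch.cuda.is_available())

first_draw = torch.rand(3)  # advances the torch generator

# Restoring the snapshot makes the next draw repeat the one above exactly.
_set_rng_states(snapshot)
assert torch.equal(torch.rand(3), first_draw)
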
@@ -123,6 +131,9 @@ def _set_rng_states(rng_state_dict: Dict[str, Any]) -> None:
     # torch.cuda rng_state is only included since v1.8.
     if "torch.cuda" in rng_state_dict:
         torch.cuda.set_rng_state_all(rng_state_dict["torch.cuda"])
+    if "torch.xpu" in rng_state_dict:
+        if XPUAccelerator.is_available():
+            XPUAccelerator._set_rng_states(rng_state_dict)
     np.random.set_state(rng_state_dict["numpy"])
     version, state, gauss = rng_state_dict["python"]
     python_set_rng_state((version, tuple(state), gauss))
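
The seeding and RNG-state handling for XPU devices is delegated to XPUAccelerator from lightning_xpu, whose implementation is not part of this commit. The following is a speculative sketch of what those three hooks might look like on a PyTorch build that exposes torch.xpu; the class name and structure are assumptions, not the real lightning_xpu code.

from typing import Any, Dict

import torch

class XPUAcceleratorSketch:
    """Illustrative stand-in for lightning_xpu.fabric.XPUAccelerator."""

    @staticmethod
    def is_available() -> bool:
        # Requires a PyTorch build (or extension) that provides torch.xpu.
        return hasattr(torch, "xpu") and torch.xpu.is_available()

    @staticmethod
    def manual_seed_all(seed: int) -> None:
        # Seed the RNG of every visible XPU device, mirroring torch.cuda.manual_seed_all.
        torch.xpu.manual_seed_all(seed)

    @staticmethod
    def _collect_rng_states() -> Any:
        # One generator-state tensor per XPU device.
        return torch.xpu.get_rng_state_all()

    @staticmethod
    def _set_rng_states(rng_state_dict: Dict[str, Any]) -> None:
        # seed.py passes the whole dict, so the helper looks up its own key.
        torch.xpu.set_rng_state_all(rng_state_dict["torch.xpu"])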
