Skip to content

Commit 9487579

Browse files
pre-commit-ci[bot]jingxu10
authored andcommitted
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent 15855ee commit 9487579

File tree

1 file changed

+5
-6
lines changed
  • src/lightning/fabric/utilities

1 file changed

+5
-6
lines changed

src/lightning/fabric/utilities/seed.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
if _LIGHTNING_XPU_AVAILABLE:
2020
from lightning_xpu.fabric import XPUAccelerator
2121

22+
2223
def seed_everything(seed: Optional[int] = None, workers: bool = False) -> int:
2324
"""Function that sets seed for pseudo-random number generators in: pytorch, numpy, python.random In addition,
2425
sets the following environment variables:
@@ -118,9 +119,8 @@ def _collect_rng_states(include_cuda: bool = True, include_xpu: bool = True) ->
118119
}
119120
if include_cuda:
120121
states["torch.cuda"] = torch.cuda.get_rng_state_all()
121-
if include_xpu:
122-
if XPUAccelerator.is_available():
123-
states["torch.xpu"] = XPUAccelerator._collect_rng_states()
122+
if include_xpu and XPUAccelerator.is_available():
123+
states["torch.xpu"] = XPUAccelerator._collect_rng_states()
124124
return states
125125

126126

@@ -131,9 +131,8 @@ def _set_rng_states(rng_state_dict: Dict[str, Any]) -> None:
131131
# torch.cuda rng_state is only included since v1.8.
132132
if "torch.cuda" in rng_state_dict:
133133
torch.cuda.set_rng_state_all(rng_state_dict["torch.cuda"])
134-
if "torch.xpu" in rng_state_dict:
135-
if XPUAccelerator.is_available():
136-
XPUAccelerator._set_rng_states(rng_state_dict)
134+
if "torch.xpu" in rng_state_dict and XPUAccelerator.is_available():
135+
XPUAccelerator._set_rng_states(rng_state_dict)
137136
np.random.set_state(rng_state_dict["numpy"])
138137
version, state, gauss = rng_state_dict["python"]
139138
python_set_rng_state((version, tuple(state), gauss))

0 commit comments

Comments
 (0)