
Commit d22dbd1

[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent 6c8572b commit d22dbd1

File tree

2 files changed: +6 -10 lines
  • src/lightning/pytorch/plugins/precision/amp.py
  • tests/tests_pytorch/plugins/precision/test_amp.py


src/lightning/pytorch/plugins/precision/amp.py

Lines changed: 1 addition & 3 deletions
@@ -113,9 +113,7 @@ def clip_gradients(
 
     def autocast_context_manager(self) -> torch.autocast:
         return torch.autocast(
-            self.device,
-            dtype=(torch.bfloat16 if self.precision == "bf16-mixed" else torch.half),
-            cache_enabled=False
+            self.device, dtype=(torch.bfloat16 if self.precision == "bf16-mixed" else torch.half), cache_enabled=False
         )
 
     @override
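
The hook only collapses the call onto one line; behavior is unchanged. As a standalone sketch for context (not part of the commit; `autocast_for` is a hypothetical helper), the dtype selection mirrored by the one-liner works roughly like this:

import torch

# Sketch only: pick the autocast dtype from a Lightning-style precision string,
# mirroring the expression in the diff above.
def autocast_for(device: str, precision: str) -> torch.autocast:
    # "bf16-mixed" -> bfloat16; anything else (e.g. "16-mixed") -> float16
    return torch.autocast(
        device, dtype=(torch.bfloat16 if precision == "bf16-mixed" else torch.half), cache_enabled=False
    )

# Example usage: ops inside the context run in reduced precision where supported.
with autocast_for("cpu", "bf16-mixed"):
    out = torch.nn.Linear(4, 2)(torch.randn(1, 4))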

tests/tests_pytorch/plugins/precision/test_amp.py

Lines changed: 5 additions & 7 deletions
@@ -14,15 +14,13 @@
 from unittest.mock import Mock
 
 import pytest
+import torch
+from torch import nn
 from torch.optim import Optimizer
 
 from lightning.pytorch.plugins import MixedPrecision
-from lightning.pytorch.utilities import GradClipAlgorithmType
-
-from torch import nn
-import torch
-
 from lightning.pytorch.plugins.precision import MixedPrecision
+from lightning.pytorch.utilities import GradClipAlgorithmType
 
 
 def test_clip_gradients():
@@ -62,7 +60,7 @@ def test_optimizer_amp_scaling_support_in_step_method():
 def test_amp_with_no_grad(precision: str):
     layer = nn.Linear(2, 1)
     x = torch.randn(1, 2)
-    amp = MixedPrecision(precision=precision, device='cpu')
+    amp = MixedPrecision(precision=precision, device="cpu")
 
     with amp.autocast_context_manager():
         with torch.no_grad():
@@ -72,4 +70,4 @@ def test_amp_with_no_grad(precision: str):
 
     loss.backward()
 
-    assert loss.grad_fn is not None
+    assert loss.grad_fn is not None
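
For context only (not part of the commit), the behavior this test exercises — gradient tracking pausing under torch.no_grad() inside an autocast region and resuming afterwards — can be sketched without the plugin roughly as below; the shapes and the sum-loss are illustrative assumptions:

import torch
from torch import nn

layer = nn.Linear(2, 1)
x = torch.randn(1, 2)

with torch.autocast("cpu", dtype=torch.bfloat16):
    with torch.no_grad():
        layer(x)           # no autograd graph is recorded here
    loss = layer(x).sum()  # graph recording resumes once no_grad exits

loss.backward()
assert loss.grad_fn is not None  # the loss built outside no_grad is differentiable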
