Commit 6c8572b

Add a test

committed
1 parent 7d45eff commit 6c8572b

File tree: 1 file changed (+22 -0 lines)

tests/tests_pytorch/plugins/precision/test_amp.py

Lines changed: 22 additions & 0 deletions
@@ -19,6 +19,11 @@
 from lightning.pytorch.plugins import MixedPrecision
 from lightning.pytorch.utilities import GradClipAlgorithmType
 
+from torch import nn
+import torch
+
+from lightning.pytorch.plugins.precision import MixedPrecision
+
 
 def test_clip_gradients():
     """Test that `.clip_gradients()` is a no-op when clipping is disabled."""
@@ -51,3 +56,20 @@ def test_optimizer_amp_scaling_support_in_step_method():
 
     with pytest.raises(RuntimeError, match="The current optimizer.*does not allow for gradient clipping"):
         precision.clip_gradients(optimizer, clip_val=1.0)
+
+
+@pytest.mark.parametrize("precision", ["16-mixed", "bf16-mixed"])
+def test_amp_with_no_grad(precision: str):
+    layer = nn.Linear(2, 1)
+    x = torch.randn(1, 2)
+    amp = MixedPrecision(precision=precision, device='cpu')
+
+    with amp.autocast_context_manager():
+        with torch.no_grad():
+            _ = layer(x)
+
+        loss = layer(x).mean()
+
+    loss.backward()
+
+    assert loss.grad_fn is not None
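
For context, the scenario this test covers can be reproduced in plain PyTorch without the Lightning plugin. The sketch below is an illustration, not part of the commit: it assumes `MixedPrecision.autocast_context_manager()` behaves like a `torch.autocast` context (bf16 on CPU here), and checks that exiting a nested `torch.no_grad()` block restores gradient tracking so a later forward pass is still differentiable.

```python
# Minimal plain-PyTorch sketch (assumption: autocast_context_manager() ~ torch.autocast)
import torch
from torch import nn

layer = nn.Linear(2, 1)
x = torch.randn(1, 2)

with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    with torch.no_grad():
        _ = layer(x)  # no autograd graph is recorded here

    loss = layer(x).mean()  # grad tracking is active again after no_grad exits

loss.backward()
assert loss.grad_fn is not None
assert layer.weight.grad is not None
```

To run just the new test locally, something like `pytest tests/tests_pytorch/plugins/precision/test_amp.py -k test_amp_with_no_grad` should select it.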
