tests/tests_pytorch/plugins/precision (1 file changed, +22 −0)
@@ -19,6 +19,11 @@
 from lightning.pytorch.plugins import MixedPrecision
 from lightning.pytorch.utilities import GradClipAlgorithmType
 
+from torch import nn
+import torch
+
+from lightning.pytorch.plugins.precision import MixedPrecision
+
 
 def test_clip_gradients():
     """Test that `.clip_gradients()` is a no-op when clipping is disabled."""
@@ -51,3 +56,20 @@ def test_optimizer_amp_scaling_support_in_step_method():
 
     with pytest.raises(RuntimeError, match="The current optimizer.*does not allow for gradient clipping"):
         precision.clip_gradients(optimizer, clip_val=1.0)
+
+
+@pytest.mark.parametrize("precision", ["16-mixed", "bf16-mixed"])
+def test_amp_with_no_grad(precision: str):
+    layer = nn.Linear(2, 1)
+    x = torch.randn(1, 2)
+    amp = MixedPrecision(precision=precision, device="cpu")
+
+    with amp.autocast_context_manager():
+        with torch.no_grad():
+            _ = layer(x)
+
+    loss = layer(x).mean()
+
+    loss.backward()
+
+    assert loss.grad_fn is not None
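For context, the new test exercises the interaction between autocast and `torch.no_grad()`: running an inference pass inside a `no_grad` block within the plugin's autocast region must not leave autograd disabled for a later grad-enabled forward and backward pass. A minimal standalone sketch of the same pattern, using plain `torch.autocast` instead of Lightning's wrapper (the CPU device and `torch.bfloat16` dtype here are assumptions chosen for portability, not part of the original change):

    import torch
    from torch import nn

    layer = nn.Linear(2, 1)
    x = torch.randn(1, 2)

    # Inference pass: autocast + no_grad, so no autograd graph is recorded.
    with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
        with torch.no_grad():
            _ = layer(x)

    # Training pass afterwards: autograd must be active again.
    loss = layer(x).mean()
    loss.backward()

    assert loss.grad_fn is not None       # a graph was recorded for `loss`
    assert layer.weight.grad is not None  # gradients reached the parameters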