2 files changed: +6 -10 lines

src/lightning/pytorch/plugins/precision

@@ -113,9 +113,7 @@ def clip_gradients(

     def autocast_context_manager(self) -> torch.autocast:
         return torch.autocast(
-            self.device,
-            dtype=(torch.bfloat16 if self.precision == "bf16-mixed" else torch.half),
-            cache_enabled=False
+            self.device, dtype=(torch.bfloat16 if self.precision == "bf16-mixed" else torch.half), cache_enabled=False
         )

     @override
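
The change above only reflows the torch.autocast(...) call onto a single line; its arguments are unchanged. A minimal standalone sketch of the same construction, with hypothetical stand-ins for the plugin's self.device and self.precision attributes:

import torch

# Hypothetical stand-ins for the plugin's self.device / self.precision attributes.
device = "cpu"
precision = "bf16-mixed"

# Same construction as autocast_context_manager(): bfloat16 for "bf16-mixed",
# float16 otherwise, with the autocast weight cache disabled.
ctx = torch.autocast(device, dtype=(torch.bfloat16 if precision == "bf16-mixed" else torch.half), cache_enabled=False)

with ctx:
    out = torch.nn.Linear(4, 2)(torch.randn(1, 4))

print(out.dtype)  # torch.bfloat16 under CPU autocast with "bf16-mixed"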
tests/tests_pytorch/plugins/precision

@@ -14,15 +14,13 @@
 from unittest.mock import Mock

 import pytest
+import torch
+from torch import nn
 from torch.optim import Optimizer

 from lightning.pytorch.plugins import MixedPrecision
-from lightning.pytorch.utilities import GradClipAlgorithmType
-
-from torch import nn
-import torch
-
 from lightning.pytorch.plugins.precision import MixedPrecision
+from lightning.pytorch.utilities import GradClipAlgorithmType


 def test_clip_gradients():
@@ -62,7 +60,7 @@ def test_optimizer_amp_scaling_support_in_step_method():
 def test_amp_with_no_grad(precision: str):
     layer = nn.Linear(2, 1)
     x = torch.randn(1, 2)
-    amp = MixedPrecision(precision=precision, device='cpu')
+    amp = MixedPrecision(precision=precision, device="cpu")

     with amp.autocast_context_manager():
         with torch.no_grad():
@@ -72,4 +70,4 @@ def test_amp_with_no_grad(precision: str):

     loss.backward()

-    assert loss.grad_fn is not None
+    assert loss.grad_fn is not None
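
The body of test_amp_with_no_grad between the hunks shown above is not part of this diff; what follows is a rough standalone sketch of the flow the test exercises, using torch.autocast directly rather than the MixedPrecision plugin (the bfloat16 dtype and tensor shapes are illustrative):

import torch
from torch import nn

layer = nn.Linear(2, 1)
x = torch.randn(1, 2)

# Mirrors what autocast_context_manager() builds for device="cpu" and "bf16-mixed".
with torch.autocast("cpu", dtype=torch.bfloat16, cache_enabled=False):
    with torch.no_grad():
        layer(x)  # no autograd graph is recorded inside no_grad()

    loss = layer(x).sum()  # gradient tracking resumes once no_grad() exits

loss.backward()
assert loss.grad_fn is not None
assert layer.weight.grad is not None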