Skip to content

Commit 7d45eff

Browse files
committed
Disable cache for torch.autocast in amp
1 parent c85660c commit 7d45eff

File tree

1 file changed

+5
-1
lines changed
  • src/lightning/pytorch/plugins/precision

1 file changed

+5
-1
lines changed

src/lightning/pytorch/plugins/precision/amp.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,11 @@ def clip_gradients(
112112
super().clip_gradients(optimizer=optimizer, clip_val=clip_val, gradient_clip_algorithm=gradient_clip_algorithm)
113113

114114
def autocast_context_manager(self) -> torch.autocast:
    """Return the ``torch.autocast`` context manager used for mixed-precision autocasting.

    The dtype is ``torch.bfloat16`` when this plugin's precision is
    ``"bf16-mixed"`` and ``torch.half`` otherwise. The autocast weight
    cache is disabled (``cache_enabled=False``), per this commit's intent
    ("Disable cache for torch.autocast in amp").
    """
    if self.precision == "bf16-mixed":
        autocast_dtype = torch.bfloat16
    else:
        autocast_dtype = torch.half
    return torch.autocast(self.device, dtype=autocast_dtype, cache_enabled=False)
116120

117121
@override
118122
@contextmanager

0 commit comments

Comments (0)