From 7d45eff9f941054126d8f57aa8c9449f5e983cf4 Mon Sep 17 00:00:00 2001
From: baskrahmer
Date: Fri, 20 Jun 2025 12:24:03 +0200
Subject: [PATCH 01/11] Disable cache for torch.autocast in amp

---
 src/lightning/pytorch/plugins/precision/amp.py | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/src/lightning/pytorch/plugins/precision/amp.py b/src/lightning/pytorch/plugins/precision/amp.py
index 75e792af46b90..10a5fe470718d 100644
--- a/src/lightning/pytorch/plugins/precision/amp.py
+++ b/src/lightning/pytorch/plugins/precision/amp.py
@@ -112,7 +112,11 @@ def clip_gradients(
         super().clip_gradients(optimizer=optimizer, clip_val=clip_val, gradient_clip_algorithm=gradient_clip_algorithm)

     def autocast_context_manager(self) -> torch.autocast:
-        return torch.autocast(self.device, dtype=(torch.bfloat16 if self.precision == "bf16-mixed" else torch.half))
+        return torch.autocast(
+            self.device,
+            dtype=(torch.bfloat16 if self.precision == "bf16-mixed" else torch.half),
+            cache_enabled=False
+        )

     @override
     @contextmanager

From 6c8572b291e8a7cae64dd9da6d38cdf2597fc517 Mon Sep 17 00:00:00 2001
From: baskrahmer
Date: Fri, 20 Jun 2025 15:36:44 +0200
Subject: [PATCH 02/11] Add a test

---
 .../plugins/precision/test_amp.py | 22 +++++++++++++++++++
 1 file changed, 22 insertions(+)

diff --git a/tests/tests_pytorch/plugins/precision/test_amp.py b/tests/tests_pytorch/plugins/precision/test_amp.py
index cb061c540b2be..a94c02e8642b6 100644
--- a/tests/tests_pytorch/plugins/precision/test_amp.py
+++ b/tests/tests_pytorch/plugins/precision/test_amp.py
@@ -19,6 +19,11 @@
 from lightning.pytorch.plugins import MixedPrecision
 from lightning.pytorch.utilities import GradClipAlgorithmType

+from torch import nn
+import torch
+
+from lightning.pytorch.plugins.precision import MixedPrecision
+

 def test_clip_gradients():
     """Test that `.clip_gradients()` is a no-op when clipping is disabled."""
@@ -51,3 +56,20 @@ def test_optimizer_amp_scaling_support_in_step_method():

     with pytest.raises(RuntimeError, match="The current optimizer.*does not allow for gradient clipping"):
         precision.clip_gradients(optimizer, clip_val=1.0)
+
+
+@pytest.mark.parametrize("precision", ["16-mixed", "bf16-mixed"])
+def test_amp_with_no_grad(precision: str):
+    layer = nn.Linear(2, 1)
+    x = torch.randn(1, 2)
+    amp = MixedPrecision(precision=precision, device='cpu')
+
+    with amp.autocast_context_manager():
+        with torch.no_grad():
+            _ = layer(x)
+
+        loss = layer(x).mean()
+
+    loss.backward()
+
+    assert loss.grad_fn is not None
\ No newline at end of file

From d18fb086275723561309a8b01efb899cf369accd Mon Sep 17 00:00:00 2001
From: baskrahmer
Date: Fri, 20 Jun 2025 15:44:24 +0200
Subject: [PATCH 03/11] pre-commit

---
 tests/tests_pytorch/plugins/precision/test_amp.py | 13 +++++--------
 1 file changed, 5 insertions(+), 8 deletions(-)

diff --git a/tests/tests_pytorch/plugins/precision/test_amp.py b/tests/tests_pytorch/plugins/precision/test_amp.py
index a94c02e8642b6..b297e5ea9e735 100644
--- a/tests/tests_pytorch/plugins/precision/test_amp.py
+++ b/tests/tests_pytorch/plugins/precision/test_amp.py
@@ -14,15 +14,12 @@
 from unittest.mock import Mock

 import pytest
-from torch.optim import Optimizer
-
-from lightning.pytorch.plugins import MixedPrecision
-from lightning.pytorch.utilities import GradClipAlgorithmType
-
-from torch import nn
 import torch
+from torch import nn
+from torch.optim import Optimizer

 from lightning.pytorch.plugins.precision import MixedPrecision
+from lightning.pytorch.utilities import GradClipAlgorithmType


 def test_clip_gradients():
@@ -62,7 +59,7 @@ def test_optimizer_amp_scaling_support_in_step_method():
 def test_amp_with_no_grad(precision: str):
     layer = nn.Linear(2, 1)
     x = torch.randn(1, 2)
-    amp = MixedPrecision(precision=precision, device='cpu')
+    amp = MixedPrecision(precision=precision, device="cpu")

     with amp.autocast_context_manager():
         with torch.no_grad():
@@ -72,4 +69,4 @@ def test_amp_with_no_grad(precision: str):

     loss.backward()

-    assert loss.grad_fn is not None
\ No newline at end of file
+    assert loss.grad_fn is not None

From 600280bd4fe1417c6e6700992472e3cf25b9982e Mon Sep 17 00:00:00 2001
From: baskrahmer
Date: Fri, 20 Jun 2025 15:47:56 +0200
Subject: [PATCH 04/11] Revert import change

---
 tests/tests_pytorch/plugins/precision/test_amp.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/tests_pytorch/plugins/precision/test_amp.py b/tests/tests_pytorch/plugins/precision/test_amp.py
index b297e5ea9e735..508f8a5d56c3b 100644
--- a/tests/tests_pytorch/plugins/precision/test_amp.py
+++ b/tests/tests_pytorch/plugins/precision/test_amp.py
@@ -18,7 +18,7 @@
 from torch import nn
 from torch.optim import Optimizer

-from lightning.pytorch.plugins.precision import MixedPrecision
+from lightning.pytorch.plugins import MixedPrecision
 from lightning.pytorch.utilities import GradClipAlgorithmType


From 14ae4d881d68ce1f5b7e7944f1c521072885cdee Mon Sep 17 00:00:00 2001
From: baskrahmer
Date: Fri, 20 Jun 2025 15:49:52 +0200
Subject: [PATCH 05/11] Format test

---
 tests/tests_pytorch/plugins/precision/test_amp.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/tests/tests_pytorch/plugins/precision/test_amp.py b/tests/tests_pytorch/plugins/precision/test_amp.py
index 508f8a5d56c3b..447b18969e543 100644
--- a/tests/tests_pytorch/plugins/precision/test_amp.py
+++ b/tests/tests_pytorch/plugins/precision/test_amp.py
@@ -57,6 +57,8 @@ def test_optimizer_amp_scaling_support_in_step_method():

 @pytest.mark.parametrize("precision", ["16-mixed", "bf16-mixed"])
 def test_amp_with_no_grad(precision: str):
+    """Test that asserts using `no_grad` context wrapper with a persistent AMP context wrapper does not break gradient
+    tracking."""
     layer = nn.Linear(2, 1)
     x = torch.randn(1, 2)
     amp = MixedPrecision(precision=precision, device="cpu")
@@ -66,7 +68,5 @@ def test_amp_with_no_grad(precision: str):
             _ = layer(x)

         loss = layer(x).mean()
-
     loss.backward()
-
     assert loss.grad_fn is not None

From 064caf79d23660832820aa54c58d21acaa134734 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Fri, 20 Jun 2025 13:50:13 +0000
Subject: [PATCH 06/11] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 src/lightning/pytorch/plugins/precision/amp.py | 4 +---
 1 file changed, 1 insertion(+), 3 deletions(-)

diff --git a/src/lightning/pytorch/plugins/precision/amp.py b/src/lightning/pytorch/plugins/precision/amp.py
index 10a5fe470718d..bd34ca83d5e2b 100644
--- a/src/lightning/pytorch/plugins/precision/amp.py
+++ b/src/lightning/pytorch/plugins/precision/amp.py
@@ -113,9 +113,7 @@ def clip_gradients(

     def autocast_context_manager(self) -> torch.autocast:
         return torch.autocast(
-            self.device,
-            dtype=(torch.bfloat16 if self.precision == "bf16-mixed" else torch.half),
-            cache_enabled=False
+            self.device, dtype=(torch.bfloat16 if self.precision == "bf16-mixed" else torch.half), cache_enabled=False
         )

     @override

From 70023a1f5de2b5f4866d830a6d1a7de824292f7e Mon Sep 17 00:00:00 2001
From: baskrahmer
Date: Fri, 20 Jun 2025 18:22:30 +0200
Subject: [PATCH 07/11] Only test for bf16-mixed

---
 tests/tests_pytorch/plugins/precision/test_amp.py | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/tests/tests_pytorch/plugins/precision/test_amp.py b/tests/tests_pytorch/plugins/precision/test_amp.py
index 447b18969e543..3894c4256e0b8 100644
--- a/tests/tests_pytorch/plugins/precision/test_amp.py
+++ b/tests/tests_pytorch/plugins/precision/test_amp.py
@@ -55,13 +55,12 @@ def test_optimizer_amp_scaling_support_in_step_method():
         precision.clip_gradients(optimizer, clip_val=1.0)


-@pytest.mark.parametrize("precision", ["16-mixed", "bf16-mixed"])
-def test_amp_with_no_grad(precision: str):
+def test_amp_with_no_grad():
     """Test that asserts using `no_grad` context wrapper with a persistent AMP context wrapper does not break gradient
     tracking."""
     layer = nn.Linear(2, 1)
     x = torch.randn(1, 2)
-    amp = MixedPrecision(precision=precision, device="cpu")
+    amp = MixedPrecision(precision="bf16-mixed", device="cpu")

     with amp.autocast_context_manager():
         with torch.no_grad():

From b4c44015f7468e366748a8db05b7b76b6d924800 Mon Sep 17 00:00:00 2001
From: baskrahmer
Date: Wed, 2 Jul 2025 22:28:53 +0200
Subject: [PATCH 08/11] Implement test to reproduce the issue

---
 .../checkpointing/test_model_checkpoint.py | 27 +++++++++++++++++++
 1 file changed, 27 insertions(+)

diff --git a/tests/tests_pytorch/checkpointing/test_model_checkpoint.py b/tests/tests_pytorch/checkpointing/test_model_checkpoint.py
index 7b17498865889..0e62c566c9573 100644
--- a/tests/tests_pytorch/checkpointing/test_model_checkpoint.py
+++ b/tests/tests_pytorch/checkpointing/test_model_checkpoint.py
@@ -1666,3 +1666,30 @@ def val_dataloader(self) -> DataLoader:
     trainer_kwargs["max_epochs"] = 4
     trainer = Trainer(**trainer_kwargs, callbacks=ModelCheckpoint(**mc_kwargs))
     trainer.fit(model, ckpt_path=checkpoint_path)
+
+
+def test_save_last_without_save_on_train_epoch_and_without_val(tmp_path):
+    """Test that save_last=True when save_on_train_epoch_end=False"""
+
+    # Remove validation methods to reproduce the bug
+    model = BoringModel()
+    model.validation_step = None
+    model.val_dataloader = None
+
+    checkpoint_callback = ModelCheckpoint(
+        dirpath=tmp_path,
+        save_last=True,
+        save_on_train_epoch_end=False,
+    )
+
+    trainer = Trainer(
+        max_epochs=2,
+        callbacks=[checkpoint_callback],
+        logger=False,
+        enable_progress_bar=False,
+    )
+
+    trainer.fit(model)
+
+    # save_last=True should always save last.ckpt
+    assert (tmp_path / "last.ckpt").exists()

From 2ad5b6b53aebcbc08bbc32b5de7ec56328ff414e Mon Sep 17 00:00:00 2001
From: baskrahmer
Date: Wed, 2 Jul 2025 23:13:48 +0200
Subject: [PATCH 09/11] Implement fix

---
 src/lightning/pytorch/callbacks/model_checkpoint.py | 7 +++++++
 1 file changed, 7 insertions(+)

diff --git a/src/lightning/pytorch/callbacks/model_checkpoint.py b/src/lightning/pytorch/callbacks/model_checkpoint.py
index 6b7b2831a2e04..117d6bc77020e 100644
--- a/src/lightning/pytorch/callbacks/model_checkpoint.py
+++ b/src/lightning/pytorch/callbacks/model_checkpoint.py
@@ -338,6 +338,13 @@ def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModul
         self._save_topk_checkpoint(trainer, monitor_candidates)
         self._save_last_checkpoint(trainer, monitor_candidates)

+    @override
+    def on_train_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
+        """Ensure save_last=True is applied when training ends."""
+        if self.save_last and not self._last_checkpoint_saved:
+            monitor_candidates = self._monitor_candidates(trainer)
+            self._save_last_checkpoint(trainer, monitor_candidates)
+
     @override
     def state_dict(self) -> dict[str, Any]:
         return {

From 2c7f75d8b0772dc788e9988182d2c858676d560d Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Wed, 2 Jul 2025 21:14:39 +0000
Subject: [PATCH 10/11] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 .../checkpointing/test_model_checkpoint.py | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/tests/tests_pytorch/checkpointing/test_model_checkpoint.py b/tests/tests_pytorch/checkpointing/test_model_checkpoint.py
index 0e62c566c9573..6c5dc63ed8440 100644
--- a/tests/tests_pytorch/checkpointing/test_model_checkpoint.py
+++ b/tests/tests_pytorch/checkpointing/test_model_checkpoint.py
@@ -1669,27 +1669,27 @@ def val_dataloader(self) -> DataLoader:


 def test_save_last_without_save_on_train_epoch_and_without_val(tmp_path):
-    """Test that save_last=True when save_on_train_epoch_end=False"""
+    """Test that save_last=True when save_on_train_epoch_end=False."""

     # Remove validation methods to reproduce the bug
     model = BoringModel()
     model.validation_step = None
     model.val_dataloader = None
-    
+
     checkpoint_callback = ModelCheckpoint(
         dirpath=tmp_path,
         save_last=True,
         save_on_train_epoch_end=False,
     )
-    
+
     trainer = Trainer(
         max_epochs=2,
         callbacks=[checkpoint_callback],
         logger=False,
         enable_progress_bar=False,
     )
-    
+
     trainer.fit(model)
-    
+
     # save_last=True should always save last.ckpt
     assert (tmp_path / "last.ckpt").exists()

From 4fa3f749149177d7c3ded8a4cb88ad324237e0c2 Mon Sep 17 00:00:00 2001
From: Bas Krahmer
Date: Mon, 7 Jul 2025 09:21:54 +0200
Subject: [PATCH 11/11] Update src/lightning/pytorch/plugins/precision/amp.py

Co-authored-by: Jirka Borovec <6035284+Borda@users.noreply.github.com>
---
 src/lightning/pytorch/plugins/precision/amp.py | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/src/lightning/pytorch/plugins/precision/amp.py b/src/lightning/pytorch/plugins/precision/amp.py
index bd34ca83d5e2b..5ea62233e1f69 100644
--- a/src/lightning/pytorch/plugins/precision/amp.py
+++ b/src/lightning/pytorch/plugins/precision/amp.py
@@ -112,9 +112,8 @@ def clip_gradients(
         super().clip_gradients(optimizer=optimizer, clip_val=clip_val, gradient_clip_algorithm=gradient_clip_algorithm)

     def autocast_context_manager(self) -> torch.autocast:
-        return torch.autocast(
-            self.device, dtype=(torch.bfloat16 if self.precision == "bf16-mixed" else torch.half), cache_enabled=False
-        )
+        dtype = torch.bfloat16 if self.precision == "bf16-mixed" else torch.half
+        return torch.autocast(self.device, dtype=dtype, cache_enabled=False)

     @override
     @contextmanager
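
A minimal sketch of the behavior the autocast change above targets, written in plain PyTorch rather than through Lightning; the CPU device, bf16 dtype, and toy module mirror the new test and are illustrative assumptions, not part of the patches:

    import torch
    from torch import nn

    layer = nn.Linear(2, 1)
    x = torch.randn(1, 2)

    # With cache_enabled=False, a forward pass run under no_grad() inside a
    # long-lived autocast region cannot leave casted weights behind in the
    # autocast weight cache, so the later grad-enabled forward still records
    # an autograd graph.
    with torch.autocast("cpu", dtype=torch.bfloat16, cache_enabled=False):
        with torch.no_grad():
            _ = layer(x)
        loss = layer(x).mean()

    assert loss.grad_fn is not None
    loss.backward()

This is the same sequence that `test_amp_with_no_grad` from the patches exercises via `MixedPrecision.autocast_context_manager()`.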