diff --git a/src/lightning/pytorch/callbacks/model_checkpoint.py b/src/lightning/pytorch/callbacks/model_checkpoint.py
index 6b7b2831a2e04..117d6bc77020e 100644
--- a/src/lightning/pytorch/callbacks/model_checkpoint.py
+++ b/src/lightning/pytorch/callbacks/model_checkpoint.py
@@ -338,6 +338,13 @@ def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModul
         self._save_topk_checkpoint(trainer, monitor_candidates)
         self._save_last_checkpoint(trainer, monitor_candidates)
 
+    @override
+    def on_train_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
+        """Ensure save_last=True is applied when training ends."""
+        if self.save_last and not self._last_checkpoint_saved:
+            monitor_candidates = self._monitor_candidates(trainer)
+            self._save_last_checkpoint(trainer, monitor_candidates)
+
     @override
     def state_dict(self) -> dict[str, Any]:
         return {
diff --git a/src/lightning/pytorch/plugins/precision/amp.py b/src/lightning/pytorch/plugins/precision/amp.py
index 75e792af46b90..5ea62233e1f69 100644
--- a/src/lightning/pytorch/plugins/precision/amp.py
+++ b/src/lightning/pytorch/plugins/precision/amp.py
@@ -112,7 +112,8 @@ def clip_gradients(
         super().clip_gradients(optimizer=optimizer, clip_val=clip_val, gradient_clip_algorithm=gradient_clip_algorithm)
 
     def autocast_context_manager(self) -> torch.autocast:
-        return torch.autocast(self.device, dtype=(torch.bfloat16 if self.precision == "bf16-mixed" else torch.half))
+        dtype = torch.bfloat16 if self.precision == "bf16-mixed" else torch.half
+        return torch.autocast(self.device, dtype=dtype, cache_enabled=False)
 
     @override
     @contextmanager
diff --git a/tests/tests_pytorch/checkpointing/test_model_checkpoint.py b/tests/tests_pytorch/checkpointing/test_model_checkpoint.py
index 7b17498865889..6c5dc63ed8440 100644
--- a/tests/tests_pytorch/checkpointing/test_model_checkpoint.py
+++ b/tests/tests_pytorch/checkpointing/test_model_checkpoint.py
@@ -1666,3 +1666,30 @@ def val_dataloader(self) -> DataLoader:
     trainer_kwargs["max_epochs"] = 4
     trainer = Trainer(**trainer_kwargs, callbacks=ModelCheckpoint(**mc_kwargs))
     trainer.fit(model, ckpt_path=checkpoint_path)
+
+
+def test_save_last_without_save_on_train_epoch_and_without_val(tmp_path):
+    """Test that save_last=True saves last.ckpt when save_on_train_epoch_end=False and there is no validation."""
+
+    # Remove validation methods to reproduce the bug
+    model = BoringModel()
+    model.validation_step = None
+    model.val_dataloader = None
+
+    checkpoint_callback = ModelCheckpoint(
+        dirpath=tmp_path,
+        save_last=True,
+        save_on_train_epoch_end=False,
+    )
+
+    trainer = Trainer(
+        max_epochs=2,
+        callbacks=[checkpoint_callback],
+        logger=False,
+        enable_progress_bar=False,
+    )
+
+    trainer.fit(model)
+
+    # save_last=True should always save last.ckpt
+    assert (tmp_path / "last.ckpt").exists()
diff --git a/tests/tests_pytorch/plugins/precision/test_amp.py b/tests/tests_pytorch/plugins/precision/test_amp.py
index cb061c540b2be..3894c4256e0b8 100644
--- a/tests/tests_pytorch/plugins/precision/test_amp.py
+++ b/tests/tests_pytorch/plugins/precision/test_amp.py
@@ -14,6 +14,8 @@
 from unittest.mock import Mock
 
 import pytest
+import torch
+from torch import nn
 from torch.optim import Optimizer
 
 from lightning.pytorch.plugins import MixedPrecision
@@ -51,3 +53,19 @@ def test_optimizer_amp_scaling_support_in_step_method():
 
     with pytest.raises(RuntimeError, match="The current optimizer.*does not allow for gradient clipping"):
         precision.clip_gradients(optimizer, clip_val=1.0)
+
+
+def test_amp_with_no_grad():
+    """Test that using a `no_grad` context inside a persistent AMP autocast context does not break gradient
+    tracking."""
+    layer = nn.Linear(2, 1)
+    x = torch.randn(1, 2)
+    amp = MixedPrecision(precision="bf16-mixed", device="cpu")
+
+    with amp.autocast_context_manager():
+        with torch.no_grad():
+            _ = layer(x)
+
+    loss = layer(x).mean()
+    loss.backward()
+    assert loss.grad_fn is not None
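For context on the `on_train_end` addition: with `save_last=True` and `save_on_train_epoch_end=False`, the callback previously only wrote `last.ckpt` from `on_validation_end`, so a run with no validation loop never produced it. The sketch below is an illustrative usage example of the configuration the new hook covers, not code from this diff; it assumes the `BoringModel` demo module shipped with recent Lightning releases.

```python
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.demos.boring_classes import BoringModel

model = BoringModel()
# No validation loop: on_validation_end, the usual save point, never fires.
model.validation_step = None
model.val_dataloader = None

checkpoint_callback = ModelCheckpoint(
    dirpath="checkpoints/",
    save_last=True,                 # always keep a rolling "last" checkpoint
    save_on_train_epoch_end=False,  # defers saving to the (absent) validation hook
)

trainer = Trainer(max_epochs=2, callbacks=[checkpoint_callback], logger=False, enable_progress_bar=False)
trainer.fit(model)
# With the on_train_end fallback from this diff, checkpoints/last.ckpt now exists.
```

The `not self._last_checkpoint_saved` guard in the new hook keeps runs that do have validation from writing `last.ckpt` a second time when training ends.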
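For context on the `cache_enabled=False` change: `torch.autocast` caches low-precision copies of parameters for the lifetime of the context. Because Lightning keeps its autocast context open across user code, a parameter first cast inside a `torch.no_grad()` block can be cached without autograd history and then reused by later grad-enabled forwards, detaching them from the graph. The sketch below is an illustrative reproduction of that failure mode under this assumption, not code from the PR; exact behavior can vary by PyTorch version.

```python
import torch
from torch import nn

layer = nn.Linear(2, 1)
x = torch.randn(1, 2)

# Cache enabled (the previous behavior): the bf16 copy of layer.weight is first
# created inside no_grad, cached, and reused by the later forward, so the output
# may end up with no grad_fn.
with torch.autocast("cpu", dtype=torch.bfloat16, cache_enabled=True):
    with torch.no_grad():
        _ = layer(x)
    out_cached = layer(x)

# Cache disabled (what MixedPrecision.autocast_context_manager now passes): the
# parameters are re-cast on every forward, so autograd tracking is preserved.
with torch.autocast("cpu", dtype=torch.bfloat16, cache_enabled=False):
    with torch.no_grad():
        _ = layer(x)
    out_uncached = layer(x)

print(out_cached.grad_fn)    # may print None, reproducing the bug
print(out_uncached.grad_fn)  # a valid grad_fn; .mean().backward() works from here
```

Disabling the cache trades a small amount of repeated casting for correct autograd tracking, which is what the new `test_amp_with_no_grad` asserts.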