diff --git a/src/lightning/pytorch/callbacks/model_checkpoint.py b/src/lightning/pytorch/callbacks/model_checkpoint.py index 6b7b2831a2e04..117d6bc77020e 100644 --- a/src/lightning/pytorch/callbacks/model_checkpoint.py +++ b/src/lightning/pytorch/callbacks/model_checkpoint.py @@ -338,6 +338,13 @@ def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModul self._save_topk_checkpoint(trainer, monitor_candidates) self._save_last_checkpoint(trainer, monitor_candidates) + @override + def on_train_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + """Ensure save_last=True is applied when training ends.""" + if self.save_last and not self._last_checkpoint_saved: + monitor_candidates = self._monitor_candidates(trainer) + self._save_last_checkpoint(trainer, monitor_candidates) + @override def state_dict(self) -> dict[str, Any]: return { diff --git a/tests/tests_pytorch/checkpointing/test_model_checkpoint.py b/tests/tests_pytorch/checkpointing/test_model_checkpoint.py index 7b17498865889..5f9781cc4a1b2 100644 --- a/tests/tests_pytorch/checkpointing/test_model_checkpoint.py +++ b/tests/tests_pytorch/checkpointing/test_model_checkpoint.py @@ -1666,3 +1666,30 @@ def val_dataloader(self) -> DataLoader: trainer_kwargs["max_epochs"] = 4 trainer = Trainer(**trainer_kwargs, callbacks=ModelCheckpoint(**mc_kwargs)) trainer.fit(model, ckpt_path=checkpoint_path) + + +def test_save_last_without_save_on_train_epoch_and_without_val(tmp_path): + """Test that save_last=True works correctly when save_on_train_epoch_end=False in a model without validation.""" + + # Remove validation methods to test the edge case + model = BoringModel() + model.validation_step = None + model.val_dataloader = None + + checkpoint_callback = ModelCheckpoint( + dirpath=tmp_path, + save_last=True, + save_on_train_epoch_end=False, + ) + + trainer = Trainer( + max_epochs=2, + callbacks=[checkpoint_callback], + logger=False, + enable_progress_bar=False, + ) + + trainer.fit(model) + + # save_last=True should always save last.ckpt + assert (tmp_path / "last.ckpt").exists()