From 79f6824120df1fdef11ca93ac2497a8649b9aa77 Mon Sep 17 00:00:00 2001 From: baskrahmer Date: Wed, 2 Jul 2025 23:38:54 +0200 Subject: [PATCH 1/5] Implement test to reproduce the issue --- .../checkpointing/test_model_checkpoint.py | 27 +++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/tests/tests_pytorch/checkpointing/test_model_checkpoint.py b/tests/tests_pytorch/checkpointing/test_model_checkpoint.py index c911885117e29..08a0becb6a6a3 100644 --- a/tests/tests_pytorch/checkpointing/test_model_checkpoint.py +++ b/tests/tests_pytorch/checkpointing/test_model_checkpoint.py @@ -1601,3 +1601,30 @@ def test_expand_home(): # it is possible to have a folder with the name `~` checkpoint = ModelCheckpoint(dirpath="./~/checkpoints") assert checkpoint.dirpath == str(Path.cwd() / "~" / "checkpoints") + + +def test_save_last_without_save_on_train_epoch_and_without_val(tmp_path): + """Test that save_last=True when save_on_train_epoch_end=False""" + + # Remove validation methods to reproduce the bug + model = BoringModel() + model.validation_step = None + model.val_dataloader = None + + checkpoint_callback = ModelCheckpoint( + dirpath=tmp_path, + save_last=True, + save_on_train_epoch_end=False, + ) + + trainer = Trainer( + max_epochs=2, + callbacks=[checkpoint_callback], + logger=False, + enable_progress_bar=False, + ) + + trainer.fit(model) + + # save_last=True should always save last.ckpt + assert (tmp_path / "last.ckpt").exists() From 4d706d19a78fa0a0d57268bca53786c5a008dabd Mon Sep 17 00:00:00 2001 From: baskrahmer Date: Wed, 2 Jul 2025 23:39:04 +0200 Subject: [PATCH 2/5] Implement fix --- src/lightning/pytorch/callbacks/model_checkpoint.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/src/lightning/pytorch/callbacks/model_checkpoint.py b/src/lightning/pytorch/callbacks/model_checkpoint.py index 6c5dd01df15c7..2a21f86794911 100644 --- a/src/lightning/pytorch/callbacks/model_checkpoint.py +++ b/src/lightning/pytorch/callbacks/model_checkpoint.py @@ -332,6 +332,13 @@ def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModul self._save_topk_checkpoint(trainer, monitor_candidates) self._save_last_checkpoint(trainer, monitor_candidates) + @override + def on_train_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + """Ensure save_last=True is applied when training ends.""" + if self.save_last and not self._last_checkpoint_saved: + monitor_candidates = self._monitor_candidates(trainer) + self._save_last_checkpoint(trainer, monitor_candidates) + @override def state_dict(self) -> Dict[str, Any]: return { From 4b780131868b25ed88736294cb00a3cb154bebc6 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 2 Jul 2025 21:46:52 +0000 Subject: [PATCH 3/5] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../checkpointing/test_model_checkpoint.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/tests/tests_pytorch/checkpointing/test_model_checkpoint.py b/tests/tests_pytorch/checkpointing/test_model_checkpoint.py index 1cdd0ffad1dd3..f9e5d19af75e3 100644 --- a/tests/tests_pytorch/checkpointing/test_model_checkpoint.py +++ b/tests/tests_pytorch/checkpointing/test_model_checkpoint.py @@ -1667,29 +1667,29 @@ def val_dataloader(self) -> DataLoader: trainer = Trainer(**trainer_kwargs, callbacks=ModelCheckpoint(**mc_kwargs)) trainer.fit(model, ckpt_path=checkpoint_path) - + def test_save_last_without_save_on_train_epoch_and_without_val(tmp_path): - """Test that save_last=True when save_on_train_epoch_end=False""" + """Test that save_last=True when save_on_train_epoch_end=False.""" # Remove validation methods to reproduce the bug model = BoringModel() model.validation_step = None model.val_dataloader = None - + checkpoint_callback = ModelCheckpoint( dirpath=tmp_path, save_last=True, save_on_train_epoch_end=False, ) - + trainer = Trainer( max_epochs=2, callbacks=[checkpoint_callback], logger=False, enable_progress_bar=False, ) - + trainer.fit(model) - + # save_last=True should always save last.ckpt assert (tmp_path / "last.ckpt").exists() From 86cf068bb6bf039138b497bb089fe9b50ae6ffd3 Mon Sep 17 00:00:00 2001 From: baskrahmer Date: Wed, 2 Jul 2025 23:47:58 +0200 Subject: [PATCH 4/5] Post-merge fix --- tests/tests_pytorch/checkpointing/test_model_checkpoint.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/tests_pytorch/checkpointing/test_model_checkpoint.py b/tests/tests_pytorch/checkpointing/test_model_checkpoint.py index f9e5d19af75e3..6c5dc63ed8440 100644 --- a/tests/tests_pytorch/checkpointing/test_model_checkpoint.py +++ b/tests/tests_pytorch/checkpointing/test_model_checkpoint.py @@ -1668,7 +1668,7 @@ def val_dataloader(self) -> DataLoader: trainer.fit(model, ckpt_path=checkpoint_path) - def test_save_last_without_save_on_train_epoch_and_without_val(tmp_path): +def test_save_last_without_save_on_train_epoch_and_without_val(tmp_path): """Test that save_last=True when save_on_train_epoch_end=False.""" # Remove validation methods to reproduce the bug From 5e37e07cd7ca507c9728bb0d7c7db1c426ee84fe Mon Sep 17 00:00:00 2001 From: baskrahmer Date: Thu, 3 Jul 2025 00:01:48 +0200 Subject: [PATCH 5/5] Rephrasee comments --- tests/tests_pytorch/checkpointing/test_model_checkpoint.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/tests_pytorch/checkpointing/test_model_checkpoint.py b/tests/tests_pytorch/checkpointing/test_model_checkpoint.py index 6c5dc63ed8440..5f9781cc4a1b2 100644 --- a/tests/tests_pytorch/checkpointing/test_model_checkpoint.py +++ b/tests/tests_pytorch/checkpointing/test_model_checkpoint.py @@ -1669,9 +1669,9 @@ def val_dataloader(self) -> DataLoader: def test_save_last_without_save_on_train_epoch_and_without_val(tmp_path): - """Test that save_last=True when save_on_train_epoch_end=False.""" + """Test that save_last=True works correctly when save_on_train_epoch_end=False in a model without validation.""" - # Remove validation methods to reproduce the bug + # Remove validation methods to test the edge case model = BoringModel() model.validation_step = None model.val_dataloader = None