Skip to content

Commit 9525a9a

Browse files
committed
Add independent flag to track checkpoint resumption.
Signed-off-by: sudipto baral <sudiptobaral.me@gmail.com>
1 parent 49b2f4c commit 9525a9a

File tree

2 files changed

+4
-1
lines changed

2 files changed

+4
-1
lines changed

src/lightning/pytorch/loops/fit_loop.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -207,7 +207,7 @@ def skip(self) -> bool:
207207
@property
208208
def _is_resuming(self) -> bool:
209209
"""Whether we're resuming training from a checkpoint."""
210-
return self._loaded_from_state_dict
210+
return self._resuming_from_checkpoint
211211

212212
def run(self) -> None:
213213
self.setup_data()

src/lightning/pytorch/loops/loop.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ class _Loop:
2323
def __init__(self, trainer: "pl.Trainer") -> None:
2424
self._restarting = False
2525
self._loaded_from_state_dict = False
26+
self._resuming_from_checkpoint = False
2627
self.trainer = trainer
2728

2829
@property
@@ -87,6 +88,7 @@ def load_state_dict(
8788
v.load_state_dict(state_dict.copy(), prefix + k + ".")
8889
self.restarting = True
8990
self._loaded_from_state_dict = True
91+
self._resuming_from_checkpoint = True
9092

9193
def _load_from_state_dict(self, state_dict: dict, prefix: str) -> None:
9294
for k, v in self.__dict__.items():
@@ -102,4 +104,5 @@ def _load_from_state_dict(self, state_dict: dict, prefix: str) -> None:
102104
def on_iteration_done(self) -> None:
103105
self._restarting = False
104106
self._loaded_from_state_dict = False
107+
self._resuming_from_checkpoint = False
105108
self.reset_restart_stage()

0 commit comments

Comments
 (0)