Skip to content

Commit e982940

Browse files
committed
Add independent flag to track checkpoint resumption.
Signed-off-by: sudipto baral <sudiptobaral.me@gmail.com>
1 parent dae0d89 commit e982940

File tree

3 files changed

+10
-7
lines changed

3 files changed

+10
-7
lines changed

src/lightning/pytorch/loops/fit_loop.py

Lines changed: 0 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -204,11 +204,6 @@ def skip(self) -> bool:
204204
# so we cannot use it solely
205205
return self.done or self.trainer.limit_train_batches == 0
206206

207-
@property
208-
def _is_resuming(self) -> bool:
209-
"""Whether we're resuming training from a checkpoint."""
210-
return self._loaded_from_state_dict
211-
212207
def run(self) -> None:
213208
self.setup_data()
214209
if self.skip:

src/lightning/pytorch/loops/loop.py

Lines changed: 8 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -23,13 +23,19 @@ class _Loop:
2323
def __init__(self, trainer: "pl.Trainer") -> None:
2424
self._restarting = False
2525
self._loaded_from_state_dict = False
26+
self._resuming_from_checkpoint = False
2627
self.trainer = trainer
2728

2829
@property
2930
def restarting(self) -> bool:
3031
"""Whether the state of this loop was reloaded and it needs to restart."""
3132
return self._restarting
3233

34+
@property
35+
def is_resuming(self) -> bool:
36+
"""Whether we're resuming training from a checkpoint."""
37+
return self._resuming_from_checkpoint
38+
3339
@restarting.setter
3440
def restarting(self, restarting: bool) -> None:
3541
"""Connects this loop's restarting value and its children."""
@@ -87,6 +93,7 @@ def load_state_dict(
8793
v.load_state_dict(state_dict.copy(), prefix + k + ".")
8894
self.restarting = True
8995
self._loaded_from_state_dict = True
96+
self._resuming_from_checkpoint = True
9097

9198
def _load_from_state_dict(self, state_dict: dict, prefix: str) -> None:
9299
for k, v in self.__dict__.items():
@@ -102,4 +109,5 @@ def _load_from_state_dict(self, state_dict: dict, prefix: str) -> None:
102109
def on_iteration_done(self) -> None:
103110
self._restarting = False
104111
self._loaded_from_state_dict = False
112+
self._resuming_from_checkpoint = False
105113
self.reset_restart_stage()

src/lightning/pytorch/loops/training_epoch_loop.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -239,9 +239,9 @@ def on_run_start(self, data_fetcher: _DataFetcher) -> None:
239239
# `iter()` was called once in `FitLoop.setup_data()` already
240240
# Only call `iter()` if all following cases:
241241
# 1. Not restarting
242-
# 2. Not resuming from checkpoint (not _is_resuming)
242+
# 2. Not resuming from checkpoint (not is_resuming)
243243
# 3. Past first epoch (current_epoch > 0)
244-
if (self.trainer.current_epoch > 0 and not self.trainer.fit_loop._is_resuming) and not self.restarting:
244+
if (self.trainer.current_epoch > 0 and not self.trainer.fit_loop.is_resuming) and not self.restarting:
245245
iter(data_fetcher) # creates the iterator inside the fetcher
246246

247247
# add the previous `fetched` value to properly track `is_last_batch` with no prefetching

0 commit comments

Comments
 (0)