Commit ae24e73 (1 parent: a651975)

Fix double iteration bug when resumed from a checkpoint.
3 files changed (+91, -1 lines)
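For context: when a run resumed from a checkpoint, `iter()` ended up being called on the data fetcher twice, once in `FitLoop.setup_data()` and again in `TrainingEpochLoop.on_run_start()`. For a map-style dataset this is harmless, but for a one-shot iterable such as a queue-backed `IterableDataset`, every extra `iter()` consumes data. A minimal sketch of that failure mode (illustrative only, not from the commit; `QueueBackedDataset` is a made-up name):

# Sketch: why a second iter() call is destructive for a one-shot iterable.
import queue

from torch.utils.data import IterableDataset


class QueueBackedDataset(IterableDataset):
    """Yields a fixed number of items pulled from a shared queue."""

    def __init__(self, q: queue.Queue, items_per_epoch: int = 5) -> None:
        super().__init__()
        self.q = q
        self.items_per_epoch = items_per_epoch

    def __iter__(self):
        for _ in range(self.items_per_epoch):
            yield self.q.get(timeout=5)


q = queue.Queue()
for i in range(10):
    q.put(i)

ds = QueueBackedDataset(q)
print(list(iter(ds)))  # first iter(): consumes items 0..4
print(list(iter(ds)))  # second iter(): consumes 5..9 -- items that the
                       # resumed epoch expected to see are silently gone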

src/lightning/pytorch/loops/fit_loop.py

Lines changed: 5 additions & 0 deletions
@@ -204,6 +204,11 @@ def skip(self) -> bool:
         # so we cannot use it solely
         return self.done or self.trainer.limit_train_batches == 0
 
+    @property
+    def _is_resuming(self) -> bool:
+        """Whether we're resuming training from a checkpoint."""
+        return self._loaded_from_state_dict
+
     def run(self) -> None:
         self.setup_data()
         if self.skip:
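The new `_is_resuming` property simply exposes the loop's existing `_loaded_from_state_dict` flag. The commit doesn't show where that flag is set; below is a plausible, simplified sketch of the lifecycle (an assumption for illustration, not Lightning's actual `_Loop` code):

# Sketch (assumed, simplified): how a loop can remember that its state
# was restored from a checkpoint, separately from `restarting`.
class _SketchLoop:
    def __init__(self) -> None:
        self.restarting = False               # a restart is in progress
        self._loaded_from_state_dict = False  # state came from a checkpoint

    def load_state_dict(self, state_dict: dict) -> None:
        # ... restore progress counters from state_dict ...
        self.restarting = True
        self._loaded_from_state_dict = True

    @property
    def _is_resuming(self) -> bool:
        return self._loaded_from_state_dict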

src/lightning/pytorch/loops/training_epoch_loop.py

Lines changed: 5 additions & 1 deletion
@@ -237,7 +237,11 @@ def reset(self) -> None:
 
     def on_run_start(self, data_fetcher: _DataFetcher) -> None:
         # `iter()` was called once in `FitLoop.setup_data()` already
-        if self.trainer.current_epoch > 0 and not self.restarting:
+        # Only call iter() if:
+        # 1. Not restarting AND
+        # 2. Not resuming from checkpoint (not _is_resuming) AND
+        # 3. Past first epoch (current_epoch > 0)
+        if (self.trainer.current_epoch > 0 and not self.trainer.fit_loop._is_resuming) and not self.restarting:
             iter(data_fetcher)  # creates the iterator inside the fetcher
 
         # add the previous `fetched` value to properly track `is_last_batch` with no prefetching
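Spelled out, the rewritten guard allows a fresh `iter()` in exactly one situation: a later epoch of an uninterrupted run. Here is the condition restated as a standalone predicate with the cases checked (illustration only; the real code reads these values off the trainer and the fit loop):

# Sketch: the guard from on_run_start() as a standalone predicate.
def should_reinit_iterator(current_epoch: int, restarting: bool, is_resuming: bool) -> bool:
    return current_epoch > 0 and not is_resuming and not restarting


# First epoch: FitLoop.setup_data() already called iter()
assert not should_reinit_iterator(current_epoch=0, restarting=False, is_resuming=False)
# Later epochs of an uninterrupted run: a fresh iterator is needed
assert should_reinit_iterator(current_epoch=1, restarting=False, is_resuming=False)
# Resumed from a checkpoint: setup_data() already called iter(); calling it
# again here was the double-iteration bug
assert not should_reinit_iterator(current_epoch=1, restarting=False, is_resuming=True)
# Restarting mid-run (e.g. after a fault): likewise, skip the extra iter()
assert not should_reinit_iterator(current_epoch=1, restarting=True, is_resuming=False)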
Lines changed: 81 additions & 0 deletions
@@ -0,0 +1,81 @@
+# Copyright The Lightning AI team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# This test checks resuming training from a checkpoint file when using an
+# IterableDataset, and contains code from issue #19427.
+# Ref: https://github.com/Lightning-AI/pytorch-lightning/issues/19427
+import multiprocessing as mp
+import os
+from collections.abc import Iterator
+from pathlib import Path
+from queue import Queue
+
+import numpy as np
+from torch.utils.data import DataLoader, IterableDataset
+
+from lightning import Trainer
+from lightning.pytorch.demos.boring_classes import BoringModel
+
+
+class QueueDataset(IterableDataset):
+    def __init__(self, queue: Queue) -> None:
+        super().__init__()
+        self.queue = queue
+
+    def __iter__(self) -> Iterator:
+        for _ in range(5):
+            tensor, _ = self.queue.get(timeout=5)
+            yield tensor
+
+
+def create_queue() -> Queue:
+    q = mp.Queue()
+    arr = np.random.random([1, 32]).astype(np.float32)
+    for ind in range(20):
+        q.put((arr, ind))
+    return q
+
+
+def train_model(queue: Queue, max_epochs: int, ckpt_path: Path) -> Trainer:
+    dataloader = DataLoader(QueueDataset(queue), num_workers=1, batch_size=None, persistent_workers=True)
+    trainer = Trainer(
+        max_epochs=max_epochs,
+        enable_progress_bar=False,
+        enable_checkpointing=False,
+        devices=1,
+        logger=False,
+    )
+    if ckpt_path.exists():
+        trainer.fit(BoringModel(), dataloader, ckpt_path=str(ckpt_path))
+    else:
+        trainer.fit(BoringModel(), dataloader)
+    trainer.save_checkpoint(str(ckpt_path))
+    return trainer
+
+
+def test_resume_training_with(tmp_path):
+    """Test resuming training from a checkpoint file using an IterableDataset."""
+    queue = create_queue()
+    max_epoch = 2
+    ckpt_path = tmp_path / "model.ckpt"
+    trainer = train_model(queue, max_epoch, ckpt_path)
+    assert trainer is not None
+
+    assert os.path.exists(ckpt_path), f"Checkpoint file '{ckpt_path}' wasn't created"
+
+    ckpt_size = os.path.getsize(ckpt_path)
+    assert ckpt_size > 0, f"Checkpoint file is empty (size: {ckpt_size} bytes)"
+
+    trainer = train_model(queue, max_epoch + 2, ckpt_path)
+    assert trainer is not None
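The second `train_model()` call is what exercises the fix: it resumes from the saved checkpoint, and before this commit the doubled `iter()` would drain extra items from the shared queue, so a later `queue.get(timeout=5)` could time out. To run just this test, a sketch (the filename here is hypothetical; the commit view above does not show where the test file lives):

# Run only the new test (hypothetical filename).
import pytest

raise SystemExit(pytest.main(["-q", "test_resume_iterable.py::test_resume_training_with"]))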
