Skip to content

Commit 91ef190

Browse files
tchaton and thomas authored
StreamingDataset: Fault Tolerance v2 2/n (#19201)
Co-authored-by: thomas <thomas@thomass-MacBook-Pro.local>
1 parent 8588032 commit 91ef190

File tree

5 files changed

+220
-254
lines changed

5 files changed

+220
-254
lines changed

src/lightning/data/streaming/dataloader.py

Lines changed: 3 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -370,11 +370,10 @@ def __iter__(self) -> Any:
370370
else:
371371
yield from super().__iter__()
372372

373-
def state_dict(self) -> Optional[Dict[str, Any]]:
373+
def state_dict(self) -> Dict[str, Any]:
374374
if isinstance(self.dataset, StreamingDataset):
375375
assert self.batch_size
376-
env = _DistributedEnv.detect()
377-
num_samples = self.num_samples_yielded * env.world_size
376+
num_samples = self.num_samples_yielded
378377
return self.dataset.state_dict(num_samples, self.num_workers, self.batch_size)
379378
return self.dataset.state_dict(self.num_workers, self.batch_size)
380379

@@ -384,7 +383,7 @@ def load_state_dict(self, obj: Dict[str, Any]) -> None:
384383
This is called on each copy of the dataset when resuming.
385384
386385
Args:
387-
obj (Dict[str, Any]): The state.
386+
obj (Any): The state.
388387
389388
"""
390389
if isinstance(self.dataset, (StreamingDataset, CombinedStreamingDataset)):

0 commit comments

Comments
 (0)