
Commit 025593f

Add IterableDataset.reshard() (#7992)
* add IterableDataset.reshard()
* fix test and interleave_dataset after split_by_node
* update test
* add warning on shuffling seed in distributed setups
* better docstring
* typo
* dot
* better torch warning if too few shards
1 parent 24604d0 commit 025593f

10 files changed: +337 −241 lines changed

docs/source/package_reference/main_classes.mdx

Lines changed: 1 addition & 0 deletions
@@ -176,6 +176,7 @@ The base class [`IterableDataset`] implements an iterable Dataset backed by python generators.
 - skip
 - take
 - shard
+- reshard
 - repeat
 - to_csv
 - to_pandas

docs/source/stream.mdx

Lines changed: 15 additions & 1 deletion
@@ -182,7 +182,21 @@ IterableDataset({
 })
 ```
 
-If your dataset has `dataset.num_shards==1`, you should chunk it using [`IterableDataset.skip`] and [`IterableDataset.take`] instead.
+To increase the number of shards of a dataset, you can use [`IterableDataset.reshard`]:
+
+```py
+>>> dataset.reshard()
+IterableDataset({
+    features: ['label', 'title', 'content'],
+    num_shards: 3600
+})
+```
+
+The resharding mechanism depends on the dataset file format.
+For example, for Parquet it reshards using row groups instead of having one file per shard.
+See how it works for every format in [`IterableDataset.reshard`]'s documentation.
+
+If your dataset has `dataset.num_shards==1` even after resharding, you should chunk it using [`IterableDataset.skip`] and [`IterableDataset.take`] instead.
 
 ## Interleave
 
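To illustrate why more shards help, here is a minimal sketch of resharding before streaming with multiple `DataLoader` workers. It is not part of this commit's diff; the dataset name, shard counts, and worker count are placeholders, and `reshard()` is called without arguments exactly as in the documentation example above.

```py
from datasets import load_dataset
from torch.utils.data import DataLoader

# Stream the dataset instead of downloading it
ds = load_dataset("amazon_polarity", split="train", streaming=True)
print(ds.num_shards)   # often just a handful of data files

# Split the shards into finer-grained ones (for Parquet: one shard per row group)
ds = ds.reshard()
print(ds.num_shards)   # typically much larger after resharding

# With more shards than workers, every DataLoader worker has data to stream
dataloader = DataLoader(ds, num_workers=8)
```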

docs/source/use_with_pytorch.mdx

Lines changed: 3 additions & 0 deletions
@@ -255,3 +255,6 @@ then the shards are evenly assigned across the nodes, which is the most optimized.
 Otherwise, each node keeps 1 example out of `world_size`, skipping the other examples.
 
 This can also be combined with a `torch.utils.data.DataLoader` if you want each node to use multiple workers to load the data.
+
+> [!WARNING]
+> If you shuffle your iterable dataset in a distributed setup, make sure to set a fixed `seed` in [`IterableDataset.shuffle`] so the same shuffled list of shards is used on every node to know which shards the node should skip.
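A minimal sketch of the pattern this warning refers to (the dataset, seed, and buffer size are illustrative; `rank` and `world_size` would normally be provided by the launcher):

```py
import os

from datasets import load_dataset
from datasets.distributed import split_dataset_by_node

# rank and world_size are normally set by the launcher (e.g. torchrun)
rank = int(os.environ.get("RANK", 0))
world_size = int(os.environ.get("WORLD_SIZE", 1))

ds = load_dataset("allenai/c4", "en", split="train", streaming=True)

# A fixed seed makes every node shuffle the shard list identically,
# so each node knows which shards it should skip
ds = ds.shuffle(seed=42, buffer_size=1000)
ds = split_dataset_by_node(ds, rank=rank, world_size=world_size)
```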

src/datasets/builder.py

Lines changed: 10 additions & 2 deletions
@@ -1638,7 +1638,11 @@ def _download_and_prepare(self, dl_manager, verification_mode, **prepare_splits_kwargs):
         )
 
     def _get_examples_iterable_for_split(self, split_generator: SplitGenerator) -> ExamplesIterable:
-        return ExamplesIterable(self._generate_examples, split_generator.gen_kwargs)
+        return ExamplesIterable(
+            self._generate_examples,
+            split_generator.gen_kwargs,
+            generate_more_kwargs_fn=getattr(self, "_generate_more_gen_kwargs", None),
+        )
 
 
 class ArrowBasedBuilder(DatasetBuilder):
@@ -1933,7 +1937,11 @@ def _prepare_split_single(
         )
 
     def _get_examples_iterable_for_split(self, split_generator: SplitGenerator) -> ExamplesIterable:
-        return ArrowExamplesIterable(self._generate_tables, kwargs=split_generator.gen_kwargs)
+        return ArrowExamplesIterable(
+            self._generate_tables,
+            kwargs=split_generator.gen_kwargs,
+            generate_more_kwargs_fn=getattr(self, "_generate_more_gen_kwargs", None),
+        )
 
 
 class _CountableBuilderMixin(DatasetBuilder):
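The new `generate_more_kwargs_fn` argument lets a builder expose an optional `_generate_more_gen_kwargs` hook. The hook's exact contract is defined elsewhere in the library and is not shown in this diff; the sketch below is purely hypothetical, assuming the hook expands coarse `gen_kwargs` (one file per shard) into finer-grained ones (one Parquet row group per shard), in line with the commit message. The class name and the `files`/`row_group_ids` keys are assumptions, not the library's API.

```py
# Hypothetical illustration only: the real hook signature and gen_kwargs layout
# are defined by the library, not by this sketch.
import pyarrow.parquet as pq


class MyParquetBasedBuilder:  # stand-in for an ArrowBasedBuilder subclass
    def _generate_more_gen_kwargs(self, gen_kwargs):
        # Expand "one file per shard" into "one row group per shard"
        for file in gen_kwargs["files"]:
            num_row_groups = pq.ParquetFile(file).metadata.num_row_groups
            for row_group_id in range(num_row_groups):
                yield {"files": [file], "row_group_ids": [row_group_id]}
```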

src/datasets/combine.py

Lines changed: 13 additions & 4 deletions
@@ -39,8 +39,10 @@ def interleave_datasets(
 
     Note for iterable datasets:
 
-    In a distributed setup or in PyTorch DataLoader workers, the stopping strategy is applied per process.
-    Therefore the "first_exhausted" strategy on an sharded iterable dataset can generate less samples in total (up to 1 missing sample per subdataset per worker).
+    * The resulting dataset's `num_shards` is the minimum of each dataset's `num_shards` to ensure good parallelism.
+      If some of your datasets have a very low number of shards, you may use [`IterableDataset.reshard`].
+    * In a distributed setup or in PyTorch DataLoader workers, the stopping strategy is applied per process.
+      Therefore the "first_exhausted" strategy on a sharded iterable dataset can generate fewer samples in total (up to 1 missing sample per subdataset per worker).
 
     Args:
         datasets (`List[Dataset]` or `List[IterableDataset]`):
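A runnable sketch of the shard-count behavior described in the interleave note above (not part of this commit; toy in-memory datasets stand in for real sharded ones):

```py
from datasets import Dataset, interleave_datasets

# Two toy streaming datasets with different shard counts (values are illustrative)
ds_many_shards = Dataset.from_dict({"text": [f"a{i}" for i in range(64)]}).to_iterable_dataset(num_shards=64)
ds_few_shards = Dataset.from_dict({"text": [f"b{i}" for i in range(4)]}).to_iterable_dataset(num_shards=2)

# Per the note above, the interleaved dataset keeps min(64, 2) == 2 shards
merged = interleave_datasets([ds_many_shards, ds_few_shards])
print(merged.num_shards)  # 2
```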
@@ -170,10 +172,17 @@ def concatenate_datasets(
     axis: int = 0,
 ) -> DatasetType:
     """
-    Converts a list of [`Dataset`] with the same schema into a single [`Dataset`].
+    Concatenate several datasets (sources) into a single dataset.
+
+    Use axis=0 to concatenate vertically (default), or axis=1 to concatenate horizontally.
+
+    Note for iterable datasets:
+
+    * if axis=0, the resulting dataset's `num_shards` is the sum of each dataset's `num_shards`.
+    * if axis=1, the resulting dataset has one (1) shard to not misalign data.
 
     Args:
-        dsets (`List[datasets.Dataset]`):
+        dsets (`List[datasets.Dataset]` or `List[datasets.IterableDataset]`):
             List of Datasets to concatenate.
         info (`DatasetInfo`, *optional*):
             Dataset information, like description, citation, etc.
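Similarly, a sketch of the shard counts described in the concatenation note (not part of this commit; toy datasets, and horizontal concatenation still requires matching row counts and distinct column names):

```py
from datasets import Dataset, concatenate_datasets

ds1 = Dataset.from_dict({"a": [0, 1, 2]}).to_iterable_dataset(num_shards=3)
ds2 = Dataset.from_dict({"a": [3, 4, 5]}).to_iterable_dataset(num_shards=2)

# axis=0 stacks rows: per the note, shard counts add up (3 + 2 == 5)
rows = concatenate_datasets([ds1, ds2], axis=0)
print(rows.num_shards)  # 5

# axis=1 adds columns: per the note, a single shard is used so rows stay aligned
ds3 = Dataset.from_dict({"b": [6, 7, 8]}).to_iterable_dataset(num_shards=3)
cols = concatenate_datasets([ds1, ds3], axis=1)
print(cols.num_shards)  # 1
```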

src/datasets/distributed.py

Lines changed: 4 additions & 0 deletions
@@ -22,6 +22,10 @@ def split_dataset_by_node(dataset: DatasetType, rank: int, world_size: int) -> DatasetType:
     then the shards are evenly assigned across the nodes, which is the most optimized.
     Otherwise, each node keeps 1 example out of `world_size`, skipping the other examples.
 
+    > [!WARNING]
+    > If you shuffle your iterable dataset in a distributed setup, make sure to set a fixed `seed` in [`IterableDataset.shuffle`]
+    > so the same shuffled list of shards is used on every node to know which shards the node should skip.
+
     Args:
         dataset ([`Dataset`] or [`IterableDataset`]):
             The dataset to split by node.
