* add IterableDataset.reshard()
* fix test and interleave_dataset after split_by_node
* update test
* add warning on shuffling seed in distributed setups
* better docstring
* typo
* dot
* better torch warning if too few shards
`docs/source/stream.mdx` (+15 -1)
@@ -182,7 +182,21 @@ IterableDataset({
 })
 ```
 
-If your dataset has `dataset.num_shards==1`, you should chunk it using [`IterableDataset.skip`] and [`IterableDataset.take`] instead.
+To increase the number of shards of a dataset, you can use [`IterableDataset.reshard`]:
+
+```py
+>>> dataset.reshard()
+IterableDataset({
+    features: ['label', 'title', 'content'],
+    num_shards: 3600
+})
+```
+
+The resharding mechanism depends on the dataset file format.
+For example, for Parquet it reshards using row groups instead of having one file per shard.
+See how it works for every format in [`IterableDataset.reshard`]'s documentation.
+
+If your dataset has `dataset.num_shards==1` even after resharding, you should chunk it using [`IterableDataset.skip`] and [`IterableDataset.take`] instead.
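To make the doc change concrete, here is a minimal sketch of how resharding could fit into a streaming pipeline. The dataset name and the printed shard counts are illustrative assumptions, and `reshard()` is called without arguments because that is the only form shown in this diff:

```py
from datasets import load_dataset
from datasets.distributed import split_by_node

# Stream a dataset: num_shards initially reflects the number of data files.
dataset = load_dataset("amazon_polarity", split="train", streaming=True)  # illustrative dataset
print(dataset.num_shards)  # e.g. 4

# Reshard to get finer-grained shards (e.g. one per Parquet row group).
dataset = dataset.reshard()
print(dataset.num_shards)  # e.g. 3600

# With more shards than nodes, each node can be assigned whole shards.
dataset_rank0 = split_by_node(dataset, rank=0, world_size=8)
```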
`docs/source/use_with_pytorch.mdx` (+3 -0)
@@ -255,3 +255,6 @@ then the shards are evenly assigned across the nodes, which is the most optimize
 Otherwise, each node keeps 1 example out of `world_size`, skipping the other examples.
 
 This can also be combined with a `torch.utils.data.DataLoader` if you want each node to use multiple workers to load the data.
+
+> [!WARNING]
+> If you shuffle your iterable dataset in a distributed setup, make sure to set a fixed `seed` in [`IterableDataset.shuffle`] so that the same shuffled list of shards is used on every node and each node knows which shards to skip.
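As a sketch of the pattern this warning recommends (the rank and world size values are placeholders you would normally read from your launcher's environment):

```py
from datasets import load_dataset
from datasets.distributed import split_by_node

rank, world_size = 0, 8  # placeholders; typically set by the distributed launcher

dataset = load_dataset("allenai/c4", "en", split="train", streaming=True)  # illustrative dataset

# A fixed seed makes every node produce the same shuffled shard list,
# so each node knows which shards the other nodes take and skips them.
dataset = dataset.shuffle(seed=42, buffer_size=1000)

dataset = split_by_node(dataset, rank=rank, world_size=world_size)
```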
`src/datasets/combine.py` (+13 -4)
@@ -39,8 +39,10 @@ def interleave_datasets(
 
     Note for iterable datasets:
 
-    In a distributed setup or in PyTorch DataLoader workers, the stopping strategy is applied per process.
-    Therefore the "first_exhausted" strategy on an sharded iterable dataset can generate less samples in total (up to 1 missing sample per subdataset per worker).
+    * The resulting dataset's `num_shards` is the minimum of each dataset's `num_shards` to ensure good parallelism.
+      If some of your datasets have a very low number of shards, you may use [`IterableDataset.reshard`].
+    * In a distributed setup or in PyTorch DataLoader workers, the stopping strategy is applied per process.
+      Therefore the "first_exhausted" strategy on a sharded iterable dataset can generate fewer samples in total (up to 1 missing sample per sub-dataset per worker).
 
     Args:
         datasets (`List[Dataset]` or `List[IterableDataset]`):
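A quick sketch of the `num_shards` behavior this hunk documents; the shard counts are made up, and `to_iterable_dataset` is used only as a convenient way to build iterable datasets with a chosen number of shards:

```py
from datasets import Dataset, interleave_datasets

ds1 = Dataset.from_dict({"id": list(range(1000))}).to_iterable_dataset(num_shards=64)
ds2 = Dataset.from_dict({"id": list(range(1000, 2000))}).to_iterable_dataset(num_shards=4)

# The result keeps min(64, 4) = 4 shards, so ds2's low shard count limits
# parallelism; resharding ds2 beforehand would avoid this.
merged = interleave_datasets([ds1, ds2], stopping_strategy="first_exhausted")
print(merged.num_shards)  # 4
```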
@@ -170,10 +172,17 @@ def concatenate_datasets(
     axis: int = 0,
 ) -> DatasetType:
     """
-    Converts a list of [`Dataset`] with the same schema into a single [`Dataset`].
+    Concatenate several datasets (sources) into a single dataset.
+
+    Use axis=0 to concatenate vertically (default), or axis=1 to concatenate horizontally.
+
+    Note for iterable datasets:
+
+    * if axis=0, the resulting dataset's `num_shards` is the sum of each dataset's `num_shards`.
+    * if axis=1, the resulting dataset has one (1) shard to keep the data aligned.
 
     Args:
-        dsets (`List[datasets.Dataset]`):
+        dsets (`List[datasets.Dataset]` or `List[datasets.IterableDataset]`):
             List of Datasets to concatenate.
         info (`DatasetInfo`, *optional*):
             Dataset information, like description, citation, etc.
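And a matching sketch for the `concatenate_datasets` notes above, again with made-up shard counts and `to_iterable_dataset` as a stand-in data source:

```py
from datasets import Dataset, concatenate_datasets

ds1 = Dataset.from_dict({"a": list(range(100))}).to_iterable_dataset(num_shards=8)
ds2 = Dataset.from_dict({"a": list(range(100, 200))}).to_iterable_dataset(num_shards=8)

# axis=0 appends the sources' shard lists, so the shard counts add up.
vertical = concatenate_datasets([ds1, ds2], axis=0)
print(vertical.num_shards)  # 16

# axis=1 zips columns row by row, so a single shard keeps the rows aligned.
ds3 = Dataset.from_dict({"b": list(range(100))}).to_iterable_dataset(num_shards=8)
horizontal = concatenate_datasets([ds1, ds3], axis=1)
print(horizontal.num_shards)  # 1
```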