feat: use datasets.IterableDataset shard if possible. #3583

Closed
9 changes: 8 additions & 1 deletion src/accelerate/data_loader.py
@@ -14,6 +14,7 @@

import importlib
import math
+import sys
from contextlib import suppress
from typing import Callable, Optional, Union

@@ -1194,7 +1195,13 @@ def prepare_data_loader(
dataloader.sampler.generator = generator
# No change if no multiprocess
if (num_processes != 1 or state.distributed_type == DistributedType.MEGATRON_LM) and not dispatch_batches:
-        if isinstance(new_dataset, IterableDataset):
+        if (
+            isinstance(new_dataset, getattr(sys.modules.get("datasets"), "IterableDataset", type(None)))
Comment on lines +1198 to +1199
Member:

That could work, but let's instead check whether datasets is available (is_datasets_available) and import the IterableDataset class from there to perform the check.

Contributor Author:

I wrote it in this style rather than

if is_datasets_available():
    from datasets import IterableDataset as DatasetsIterableDatasets
...

if isinstance(new_dataset, DatasetsIterableDatasets):
    ...

to reduce import overhead, similar to this code snippet: it checks whether the object is a torch.Tensor and skips importing the heavy PyTorch package when there are no torch.Tensor objects at all.

However, the import overhead of datasets is not nearly as heavy as torch's, so if it helps readability and maintainability, I will change it to this style:

if is_datasets_available():
    from datasets import IterableDataset as DatasetsIterableDatasets
...

So what do you think: the original version, or the import-and-check version?
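The lazy-check pattern described above can be sketched as a standalone helper (the function name and the caveat noted in the comments are mine, not from the PR): it resolves the class through sys.modules, so the check costs nothing when the optional dependency was never imported.

```python
import sys

def lazy_isinstance(obj, module_name, class_name):
    # Only look up the class if the module has already been imported;
    # otherwise sys.modules.get returns None and getattr falls back to
    # type(None), which no real dataset object is an instance of.
    # Caveat (my note): if obj is None, the type(None) fallback matches
    # and this returns True, so callers should ensure obj is not None.
    cls = getattr(sys.modules.get(module_name), class_name, type(None))
    return isinstance(obj, cls)

# `datasets` has not been imported here, so the check is a cheap no-op.
print(lazy_isinstance([1, 2, 3], "datasets", "IterableDataset"))  # False

# The same trick applies to any heavy optional dependency, e.g. torch.Tensor.
print(lazy_isinstance("not a tensor", "torch", "Tensor"))  # False
```

This trades a tiny amount of readability for skipping the import entirely, which matters for packages with slow imports like torch; for datasets the reviewers judged the explicit import acceptable.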

Member:
Yeah, the overhead should be fine: we only call this function once, so it won't add much. Please go with the import + check version.

+            and not split_batches
+            and new_dataset.n_shard > num_processes
+        ):
+            new_dataset = new_dataset.shard(num_shards=num_processes, index=process_index)
+        elif isinstance(new_dataset, IterableDataset):
if getattr(dataloader.dataset, "generator", None) is not None:
synchronized_generator = dataloader.dataset.generator
new_dataset = IterableDatasetShard(
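For context on what the new branch does: when the underlying datasets.IterableDataset already has more shards (data files) than processes, each process can take a disjoint subset via .shard(num_shards, index) instead of wrapping the dataset in IterableDatasetShard. A rough pure-Python stand-in for that assignment (the helper name is hypothetical, and the real library assigns whole underlying files with its own distribution strategy):

```python
def assign_shards(shard_files, num_shards, index):
    # Hypothetical stand-in for datasets.IterableDataset.shard(...):
    # rank `index` keeps every num_shards-th file, starting at `index`.
    return [f for i, f in enumerate(shard_files) if i % num_shards == index]

files = [f"data-{i:05d}.parquet" for i in range(10)]

# With 4 processes, the ranks get disjoint subsets that cover all files.
for rank in range(4):
    print(rank, assign_shards(files, num_shards=4, index=rank))
```

The guard `new_dataset.n_shard > num_processes` matters because with fewer shards than processes, some ranks would receive no data at all, so the existing IterableDatasetShard path is kept for that case.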