diff --git a/src/accelerate/data_loader.py b/src/accelerate/data_loader.py index 3163312b568..8bcf926e34f 100644 --- a/src/accelerate/data_loader.py +++ b/src/accelerate/data_loader.py @@ -14,6 +14,7 @@ import importlib import math +import sys from contextlib import suppress from typing import Callable, Optional, Union @@ -1194,7 +1195,13 @@ def prepare_data_loader( dataloader.sampler.generator = generator # No change if no multiprocess if (num_processes != 1 or state.distributed_type == DistributedType.MEGATRON_LM) and not dispatch_batches: - if isinstance(new_dataset, IterableDataset): + if ( + isinstance(new_dataset, getattr(sys.modules.get("datasets"), "IterableDataset", type(None))) + and not split_batches + and new_dataset.n_shard > num_processes + ): + new_dataset = new_dataset.shard(num_shards=num_processes, index=process_index) + elif isinstance(new_dataset, IterableDataset): if getattr(dataloader.dataset, "generator", None) is not None: synchronized_generator = dataloader.dataset.generator new_dataset = IterableDatasetShard(