Skip to content

Make IterableDataset (optionally) resumable #7385

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 67 additions & 15 deletions src/datasets/iterable_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -1519,10 +1519,17 @@ def num_shards(self) -> int:


class BufferShuffledExamplesIterable(_BaseExamplesIterable):
def __init__(self, ex_iterable: _BaseExamplesIterable, buffer_size: int, generator: np.random.Generator):
def __init__(
self,
ex_iterable: _BaseExamplesIterable,
buffer_size: int,
stateful: bool,
generator: np.random.Generator
):
super().__init__()
self.ex_iterable = ex_iterable
self.buffer_size = buffer_size
self.stateful = stateful
self.generator = generator
# TODO(QL): implement iter_arrow

Expand All @@ -1536,12 +1543,17 @@ def features(self):

def _init_state_dict(self) -> dict:
self._state_dict = self.ex_iterable._init_state_dict()
self._original_state_dict = self.state_dict()
self._state_dict['mem_buffer'] = ([],)
self._state_dict['bit_generator_state'] = self.generator.bit_generator.state
self._state_dict['bit_generator_index_offset'] = 0
self._state_dict['bit_generator_index_offset_shuffle'] = 0
if not self.stateful:
self._original_state_dict = self.state_dict()
return self._state_dict

def load_state_dict(self, state_dict: dict) -> dict:
if self._state_dict:
if state_dict != self._original_state_dict:
if not self.stateful and state_dict != self._original_state_dict:
logger.warning(
"Loading a state dict of a shuffle buffer of a dataset without the buffer content."
"The shuffle buffer will be refilled before starting to yield new examples."
Expand All @@ -1556,31 +1568,61 @@ def _iter_random_indices(rng: np.random.Generator, buffer_size: int, random_batc
def __iter__(self):
buffer_size = self.buffer_size
rng = deepcopy(self.generator)
indices_iterator = self._iter_random_indices(rng, buffer_size)
# this is the shuffle buffer that we keep in memory
mem_buffer = []
if self.stateful and self._state_dict:
# this is the shuffle buffer that we keep in memory
mem_buffer = self._state_dict['mem_buffer'][0]
# this is an infinite iterator that randomly samples the index of the source to pick examples from
index_offset = self._state_dict["bit_generator_index_offset"]
rng.bit_generator.state = self._state_dict["bit_generator_state"]
else:
mem_buffer = []
index_offset = 0

indices_iterator = self._iter_random_indices(rng, buffer_size, random_batch_size=buffer_size)
# skip already consumed ones
for _ in range(index_offset):
i = next(indices_iterator)

for x in self.ex_iterable:
if len(mem_buffer) == buffer_size: # if the buffer is full, pick and example from it
if len(mem_buffer) < buffer_size: # if the buffer is not full, keep filling the buffer
mem_buffer.append(x)
else: # otherwise, pick an example from it
i = next(indices_iterator)
yield mem_buffer[i]
index_offset = (index_offset + 1) % buffer_size
if self.stateful and self._state_dict:
self._state_dict["bit_generator_index_offset"] = index_offset
if index_offset == 0:
self._state_dict["bit_generator_state"] = rng.bit_generator.state
selected = mem_buffer[i]
mem_buffer[i] = x # replace the picked example by a new one
else: # otherwise, keep filling the buffer
mem_buffer.append(x)
yield selected

index_offset = self._state_dict["bit_generator_index_offset_shuffle"] if self._state_dict else 0
if self.stateful and self._state_dict:
rng.bit_generator.state = self._state_dict["bit_generator_state"]

# when we run out of examples, we shuffle the remaining examples in the buffer and yield them
rng.shuffle(mem_buffer)
yield from mem_buffer
for i in rng.permutation(len(mem_buffer))[index_offset:].tolist():
index_offset = index_offset + 1
if self.stateful and self._state_dict:
self._state_dict["bit_generator_index_offset_shuffle"] = index_offset
yield mem_buffer[i]

def shuffle_data_sources(self, generator: np.random.Generator) -> "BufferShuffledExamplesIterable":
"""Shuffle the wrapped examples iterable as well as the shuffling buffer."""
return BufferShuffledExamplesIterable(
self.ex_iterable.shuffle_data_sources(generator), buffer_size=self.buffer_size, generator=generator
self.ex_iterable.shuffle_data_sources(generator),
buffer_size=self.buffer_size,
stateful=self.stateful,
generator=generator,
)

def shard_data_sources(self, num_shards: int, index: int, contiguous=True) -> "BufferShuffledExamplesIterable":
    """Keep only the requested shard.

    Shards the wrapped ``ex_iterable`` and rewraps it in a new
    ``BufferShuffledExamplesIterable``, propagating ``buffer_size``,
    the ``stateful`` flag and the RNG unchanged.
    """
    # NOTE(review): `self.generator` is passed through as-is, so every shard
    # draws the same random index sequence over its own buffer — confirm this
    # matches the intended per-shard shuffling semantics.
    return BufferShuffledExamplesIterable(
        self.ex_iterable.shard_data_sources(num_shards, index, contiguous=contiguous),
        buffer_size=self.buffer_size,
        stateful=self.stateful,
        generator=self.generator,
    )

Expand Down Expand Up @@ -2688,7 +2730,11 @@ def filter(
)

def shuffle(
self, seed=None, generator: Optional[np.random.Generator] = None, buffer_size: int = 1000
self,
seed=None,
generator: Optional[np.random.Generator] = None,
buffer_size: int = 1000,
stateful: bool = False
) -> "IterableDataset":
"""
Randomly shuffles the elements of this dataset.
Expand All @@ -2715,6 +2761,9 @@ def shuffle(
If `generator=None` (default), uses `np.random.default_rng` (the default BitGenerator (PCG64) of NumPy).
buffer_size (`int`, defaults to `1000`):
Size of the buffer.
stateful (`bool`, defaults to `False`):
    Whether to make the shuffling resumable.
    If `stateful=True`, the shuffle buffer contents and the generator state are kept in the dataset's state dict — incurring additional memory overhead of up to `buffer_size` buffered examples — so that iteration can resume exactly where it left off instead of refilling the buffer.

Example:

Expand Down Expand Up @@ -2744,7 +2793,10 @@ def shuffle(
shuffling = ShufflingConfig(generator=generator, _original_seed=seed)
return IterableDataset(
ex_iterable=BufferShuffledExamplesIterable(
self._ex_iterable, buffer_size=buffer_size, generator=generator
self._ex_iterable,
buffer_size=buffer_size,
stateful=stateful,
generator=generator
),
info=self._info.copy(),
split=self._split,
Expand Down
41 changes: 29 additions & 12 deletions tests/test_iterable_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,16 +309,27 @@ def gen(tables):


@pytest.mark.parametrize("seed", [42, 1337, 101010, 123456])
def test_buffer_shuffled_examples_iterable(seed):
@pytest.mark.parametrize("stateful", [False, True])
def test_buffer_shuffled_examples_iterable(seed, stateful):
n, buffer_size = 100, 30
generator = np.random.default_rng(seed)
base_ex_iterable = ExamplesIterable(generate_examples_fn, {"n": n})
ex_iterable = BufferShuffledExamplesIterable(base_ex_iterable, buffer_size=buffer_size, generator=generator)
ex_iterable = BufferShuffledExamplesIterable(
base_ex_iterable,
buffer_size=buffer_size,
stateful=stateful,
generator=generator
)

rng = deepcopy(generator)
expected_indices_used_for_shuffling = list(
islice(BufferShuffledExamplesIterable._iter_random_indices(rng, buffer_size=buffer_size), n - buffer_size)
)
expected_indices_used_for_shuffling = list(islice(
BufferShuffledExamplesIterable._iter_random_indices(
rng=rng,
buffer_size=buffer_size,
random_batch_size=buffer_size,
),
n - buffer_size
))
# indices to pick in the shuffle buffer should all be in the right range
assert all(0 <= index_to_pick < buffer_size for index_to_pick in expected_indices_used_for_shuffling)
# it should be random indices
Expand Down Expand Up @@ -1234,7 +1245,8 @@ def test_horizontally_concatenated_examples_iterable():
MappedExamplesIterable(ArrowExamplesIterable(generate_tables_fn, {}), lambda x: x),
FilteredExamplesIterable(ExamplesIterable(generate_examples_fn, {}), lambda x: True),
FilteredExamplesIterable(ArrowExamplesIterable(generate_tables_fn, {}), lambda x: True),
BufferShuffledExamplesIterable(ExamplesIterable(generate_examples_fn, {}), 10, np.random.default_rng(42)),
BufferShuffledExamplesIterable(ExamplesIterable(generate_examples_fn, {}), 10, False, np.random.default_rng(42)),
BufferShuffledExamplesIterable(ExamplesIterable(generate_examples_fn, {}), 10, True, np.random.default_rng(42)),
SkipExamplesIterable(ExamplesIterable(generate_examples_fn, {}), 10),
TakeExamplesIterable(ExamplesIterable(generate_examples_fn, {}), 10),
FormattedExamplesIterable(
Expand All @@ -1244,7 +1256,7 @@ def test_horizontally_concatenated_examples_iterable():
)
def test_no_iter_arrow(ex_iterable: _BaseExamplesIterable):
assert ex_iterable.iter_arrow is None
if not isinstance(ex_iterable, BufferShuffledExamplesIterable):
if not isinstance(ex_iterable, BufferShuffledExamplesIterable) or ex_iterable.stateful:
assert_load_state_dict_resumes_iteration(ex_iterable)


Expand Down Expand Up @@ -1613,11 +1625,12 @@ def test_iterable_dataset_filter(dataset: IterableDataset) -> None:

@pytest.mark.parametrize("seed", [42, 1337, 101010, 123456])
@pytest.mark.parametrize("epoch", [None, 0, 1])
def test_iterable_dataset_shuffle(dataset: IterableDataset, seed, epoch):
@pytest.mark.parametrize("stateful", [False, True])
def test_iterable_dataset_shuffle(dataset: IterableDataset, seed, epoch, stateful):
buffer_size = 3
dataset = deepcopy(dataset)
dataset._ex_iterable.kwargs["filepaths"] = ["0.txt", "1.txt"]
dataset = dataset.shuffle(seed, buffer_size=buffer_size)
dataset = dataset.shuffle(seed, buffer_size=buffer_size, stateful=stateful)
assert isinstance(dataset._shuffling, ShufflingConfig)
assert isinstance(dataset._shuffling.generator, np.random.Generator)
assert is_rng_equal(dataset._shuffling.generator, np.random.default_rng(seed))
Expand All @@ -1628,9 +1641,13 @@ def test_iterable_dataset_shuffle(dataset: IterableDataset, seed, epoch):
dataset.set_epoch(epoch)
effective_seed = np.random.default_rng(seed).integers(0, 1 << 63) - epoch
# Shuffling adds a shuffle buffer
expected_first_example_index = next(
iter(BufferShuffledExamplesIterable._iter_random_indices(np.random.default_rng(effective_seed), buffer_size))
)
expected_first_example_index = next(iter(
BufferShuffledExamplesIterable._iter_random_indices(
rng=np.random.default_rng(effective_seed),
buffer_size=buffer_size,
random_batch_size=buffer_size
)
))
assert isinstance(dataset._ex_iterable, BufferShuffledExamplesIterable)
# It also shuffles the underlying examples iterable
expected_ex_iterable = ExamplesIterable(
Expand Down
Loading