From 47ca49f3e63b1b8d81e652007dd812f2bebef8e7 Mon Sep 17 00:00:00 2001 From: Yu Zhang Date: Tue, 4 Feb 2025 15:54:48 +0000 Subject: [PATCH] Make IterableDataset (optionally) resumable --- src/datasets/iterable_dataset.py | 82 ++++++++++++++++++++++++++------ tests/test_iterable_dataset.py | 41 +++++++++++----- 2 files changed, 96 insertions(+), 27 deletions(-) diff --git a/src/datasets/iterable_dataset.py b/src/datasets/iterable_dataset.py index 317cc0b1723..1d330e8439a 100644 --- a/src/datasets/iterable_dataset.py +++ b/src/datasets/iterable_dataset.py @@ -1519,10 +1519,17 @@ def num_shards(self) -> int: class BufferShuffledExamplesIterable(_BaseExamplesIterable): - def __init__(self, ex_iterable: _BaseExamplesIterable, buffer_size: int, generator: np.random.Generator): + def __init__( + self, + ex_iterable: _BaseExamplesIterable, + buffer_size: int, + stateful: bool, + generator: np.random.Generator + ): super().__init__() self.ex_iterable = ex_iterable self.buffer_size = buffer_size + self.stateful = stateful self.generator = generator # TODO(QL): implement iter_arrow @@ -1536,12 +1543,17 @@ def features(self): def _init_state_dict(self) -> dict: self._state_dict = self.ex_iterable._init_state_dict() - self._original_state_dict = self.state_dict() + self._state_dict['mem_buffer'] = ([],) + self._state_dict['bit_generator_state'] = self.generator.bit_generator.state + self._state_dict['bit_generator_index_offset'] = 0 + self._state_dict['bit_generator_index_offset_shuffle'] = 0 + if not self.stateful: + self._original_state_dict = self.state_dict() return self._state_dict def load_state_dict(self, state_dict: dict) -> dict: if self._state_dict: - if state_dict != self._original_state_dict: + if not self.stateful and state_dict != self._original_state_dict: logger.warning( "Loading a state dict of a shuffle buffer of a dataset without the buffer content." "The shuffle buffer will be refilled before starting to yield new examples." 
@@ -1556,24 +1568,53 @@ def _iter_random_indices(rng: np.random.Generator, buffer_size: int, random_batc def __iter__(self): buffer_size = self.buffer_size rng = deepcopy(self.generator) - indices_iterator = self._iter_random_indices(rng, buffer_size) - # this is the shuffle buffer that we keep in memory - mem_buffer = [] + if self.stateful and self._state_dict: + # this is the shuffle buffer that we keep in memory + mem_buffer = self._state_dict['mem_buffer'][0] + # restore how many random indices were already consumed, plus the RNG state they were drawn from + index_offset = self._state_dict["bit_generator_index_offset"] + rng.bit_generator.state = self._state_dict["bit_generator_state"] + else: + mem_buffer = [] + index_offset = 0 + + indices_iterator = self._iter_random_indices(rng, buffer_size, random_batch_size=buffer_size) + # skip already consumed ones + for _ in range(index_offset): + i = next(indices_iterator) + for x in self.ex_iterable: - if len(mem_buffer) == buffer_size: # if the buffer is full, pick and example from it + if len(mem_buffer) < buffer_size: # if the buffer is not full, keep filling the buffer + mem_buffer.append(x) + else: # otherwise, pick an example from it i = next(indices_iterator) - yield mem_buffer[i] + index_offset = (index_offset + 1) % buffer_size + if self.stateful and self._state_dict: + self._state_dict["bit_generator_index_offset"] = index_offset + if index_offset == 0: + self._state_dict["bit_generator_state"] = rng.bit_generator.state + selected = mem_buffer[i] mem_buffer[i] = x # replace the picked example by a new one - else: # otherwise, keep filling the buffer - mem_buffer.append(x) + yield selected + + index_offset = self._state_dict["bit_generator_index_offset_shuffle"] if self._state_dict else 0 + if self.stateful and self._state_dict: + rng.bit_generator.state = self._state_dict["bit_generator_state"] + # when we run out of examples, we shuffle the remaining examples in the buffer and yield them - 
rng.shuffle(mem_buffer) - yield from mem_buffer + for i in rng.permutation(len(mem_buffer))[index_offset:].tolist(): + index_offset = index_offset + 1 + if self.stateful and self._state_dict: + self._state_dict["bit_generator_index_offset_shuffle"] = index_offset + yield mem_buffer[i] def shuffle_data_sources(self, generator: np.random.Generator) -> "BufferShuffledExamplesIterable": """Shuffle the wrapped examples iterable as well as the shuffling buffer.""" return BufferShuffledExamplesIterable( - self.ex_iterable.shuffle_data_sources(generator), buffer_size=self.buffer_size, generator=generator + self.ex_iterable.shuffle_data_sources(generator), + buffer_size=self.buffer_size, + stateful=self.stateful, + generator=generator, ) def shard_data_sources(self, num_shards: int, index: int, contiguous=True) -> "BufferShuffledExamplesIterable": @@ -1581,6 +1622,7 @@ def shard_data_sources(self, num_shards: int, index: int, contiguous=True) -> "B return BufferShuffledExamplesIterable( self.ex_iterable.shard_data_sources(num_shards, index, contiguous=contiguous), buffer_size=self.buffer_size, + stateful=self.stateful, generator=self.generator, ) @@ -2688,7 +2730,11 @@ def filter( ) def shuffle( - self, seed=None, generator: Optional[np.random.Generator] = None, buffer_size: int = 1000 + self, + seed=None, + generator: Optional[np.random.Generator] = None, + buffer_size: int = 1000, + stateful: bool = False ) -> "IterableDataset": """ Randomly shuffles the elements of this dataset. @@ -2715,6 +2761,9 @@ def shuffle( If `generator=None` (default), uses `np.random.default_rng` (the default BitGenerator (PCG64) of NumPy). buffer_size (`int`, defaults to `1000`): Size of the buffer. + stateful (`bool`, defaults to `False`): + Whether to make the shuffling stateful. + If `stateful=True`, this will incur additional memory overhead to preserve the shuffling states across epochs. 
Example: @@ -2744,7 +2793,10 @@ def shuffle( shuffling = ShufflingConfig(generator=generator, _original_seed=seed) return IterableDataset( ex_iterable=BufferShuffledExamplesIterable( - self._ex_iterable, buffer_size=buffer_size, generator=generator + self._ex_iterable, + buffer_size=buffer_size, + stateful=stateful, + generator=generator ), info=self._info.copy(), split=self._split, diff --git a/tests/test_iterable_dataset.py b/tests/test_iterable_dataset.py index bd79863f9c3..9e4844c0e0d 100644 --- a/tests/test_iterable_dataset.py +++ b/tests/test_iterable_dataset.py @@ -309,16 +309,27 @@ def gen(tables): @pytest.mark.parametrize("seed", [42, 1337, 101010, 123456]) -def test_buffer_shuffled_examples_iterable(seed): +@pytest.mark.parametrize("stateful", [False, True]) +def test_buffer_shuffled_examples_iterable(seed, stateful): n, buffer_size = 100, 30 generator = np.random.default_rng(seed) base_ex_iterable = ExamplesIterable(generate_examples_fn, {"n": n}) - ex_iterable = BufferShuffledExamplesIterable(base_ex_iterable, buffer_size=buffer_size, generator=generator) + ex_iterable = BufferShuffledExamplesIterable( + base_ex_iterable, + buffer_size=buffer_size, + stateful=stateful, + generator=generator + ) rng = deepcopy(generator) - expected_indices_used_for_shuffling = list( - islice(BufferShuffledExamplesIterable._iter_random_indices(rng, buffer_size=buffer_size), n - buffer_size) - ) + expected_indices_used_for_shuffling = list(islice( + BufferShuffledExamplesIterable._iter_random_indices( + rng=rng, + buffer_size=buffer_size, + random_batch_size=buffer_size, + ), + n - buffer_size + )) # indices to pick in the shuffle buffer should all be in the right range assert all(0 <= index_to_pick < buffer_size for index_to_pick in expected_indices_used_for_shuffling) # it should be random indices @@ -1234,7 +1245,8 @@ def test_horizontally_concatenated_examples_iterable(): MappedExamplesIterable(ArrowExamplesIterable(generate_tables_fn, {}), lambda x: x), 
FilteredExamplesIterable(ExamplesIterable(generate_examples_fn, {}), lambda x: True), FilteredExamplesIterable(ArrowExamplesIterable(generate_tables_fn, {}), lambda x: True), - BufferShuffledExamplesIterable(ExamplesIterable(generate_examples_fn, {}), 10, np.random.default_rng(42)), + BufferShuffledExamplesIterable(ExamplesIterable(generate_examples_fn, {}), 10, False, np.random.default_rng(42)), + BufferShuffledExamplesIterable(ExamplesIterable(generate_examples_fn, {}), 10, True, np.random.default_rng(42)), SkipExamplesIterable(ExamplesIterable(generate_examples_fn, {}), 10), TakeExamplesIterable(ExamplesIterable(generate_examples_fn, {}), 10), FormattedExamplesIterable( @@ -1244,7 +1256,7 @@ def test_horizontally_concatenated_examples_iterable(): ) def test_no_iter_arrow(ex_iterable: _BaseExamplesIterable): assert ex_iterable.iter_arrow is None - if not isinstance(ex_iterable, BufferShuffledExamplesIterable): + if not isinstance(ex_iterable, BufferShuffledExamplesIterable) or ex_iterable.stateful: assert_load_state_dict_resumes_iteration(ex_iterable) @@ -1613,11 +1625,12 @@ def test_iterable_dataset_filter(dataset: IterableDataset) -> None: @pytest.mark.parametrize("seed", [42, 1337, 101010, 123456]) @pytest.mark.parametrize("epoch", [None, 0, 1]) -def test_iterable_dataset_shuffle(dataset: IterableDataset, seed, epoch): +@pytest.mark.parametrize("stateful", [False, True]) +def test_iterable_dataset_shuffle(dataset: IterableDataset, seed, epoch, stateful): buffer_size = 3 dataset = deepcopy(dataset) dataset._ex_iterable.kwargs["filepaths"] = ["0.txt", "1.txt"] - dataset = dataset.shuffle(seed, buffer_size=buffer_size) + dataset = dataset.shuffle(seed, buffer_size=buffer_size, stateful=stateful) assert isinstance(dataset._shuffling, ShufflingConfig) assert isinstance(dataset._shuffling.generator, np.random.Generator) assert is_rng_equal(dataset._shuffling.generator, np.random.default_rng(seed)) @@ -1628,9 +1641,13 @@ def test_iterable_dataset_shuffle(dataset: 
IterableDataset, seed, epoch): dataset.set_epoch(epoch) effective_seed = np.random.default_rng(seed).integers(0, 1 << 63) - epoch # Shuffling adds a shuffle buffer - expected_first_example_index = next( - iter(BufferShuffledExamplesIterable._iter_random_indices(np.random.default_rng(effective_seed), buffer_size)) - ) + expected_first_example_index = next(iter( + BufferShuffledExamplesIterable._iter_random_indices( + rng=np.random.default_rng(effective_seed), + buffer_size=buffer_size, + random_batch_size=buffer_size + ) + )) assert isinstance(dataset._ex_iterable, BufferShuffledExamplesIterable) # It also shuffles the underlying examples iterable expected_ex_iterable = ExamplesIterable(