Commit 0dad4eb

🎲 [GRPO] Make training dataset shuffle optional (#3334)
Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co>
1 parent c82f626 commit 0dad4eb

File tree

3 files changed: +50 −16 lines changed

tests/test_grpo_trainer.py
trl/trainer/grpo_config.py
trl/trainer/grpo_trainer.py

tests/test_grpo_trainer.py

Lines changed: 15 additions & 8 deletions
```diff
@@ -24,7 +24,7 @@
 from transformers.utils import is_peft_available
 
 from trl import GRPOConfig, GRPOTrainer
-from trl.trainer.grpo_trainer import RepeatRandomSampler
+from trl.trainer.grpo_trainer import RepeatSampler
 
 from .testing_utils import require_vllm
 
@@ -36,7 +36,7 @@
 class RepeatRandomSamplerTester(unittest.TestCase):
     def test_sampler(self):
         dataset = ["a", "b", "c", "d", "e", "f", "g"]
-        sampler = RepeatRandomSampler(dataset, mini_repeat_count=2)
+        sampler = RepeatSampler(dataset, mini_repeat_count=2)
         # Should output something like [4, 4, 3, 3, 0, 0, 1, 1, 2, 2, 6, 6, 5, 5]
         sampled = list(sampler)
         # Check that the length is doubled
@@ -46,9 +46,16 @@ def test_sampler(self):
         # Check that each element is repeated twice
         assert all(sampled[i] == sampled[i + 1] for i in range(0, len(sampled), 2))
 
+    def test_sampler_no_shuffle(self):
+        dataset = ["a", "b", "c", "d", "e", "f", "g"]
+        sampler = RepeatSampler(dataset, mini_repeat_count=2, shuffle=False)
+        sampled = list(sampler)
+        expected = [0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6]
+        self.assertEqual(sampled, expected)
+
     def test_sampler_no_repeat(self):
         dataset = ["a", "b", "c", "d", "e", "f", "g"]
-        sampler = RepeatRandomSampler(dataset, mini_repeat_count=1)
+        sampler = RepeatSampler(dataset, mini_repeat_count=1)
         # Should output something like [4, 3, 0, 1, 2, 6, 5]
         sampled = list(sampler)
         # Check that the length is the same
@@ -58,7 +65,7 @@ def test_sampler_no_repeat(self):
 
     def test_sampler_with_batch_size(self):
         dataset = ["a", "b", "c", "d", "e", "f", "g", "h"]
-        sampler = RepeatRandomSampler(dataset, mini_repeat_count=1, batch_size=2, repeat_count=2)
+        sampler = RepeatSampler(dataset, mini_repeat_count=1, batch_size=2, repeat_count=2)
         # Should output something like [4, 3, 4, 3, 0, 1, 0, 1, 2, 6, 2, 6, 5, 7, 5, 7]
         sampled = list(sampler)
         # Check that the length is doubled
@@ -70,7 +77,7 @@ def test_sampler_with_batch_size(self):
 
     def test_sampler_with_batch_size_and_drop(self):
         dataset = ["a", "b", "c", "d", "e", "f", "g"]
-        sampler = RepeatRandomSampler(dataset, mini_repeat_count=1, batch_size=2, repeat_count=2)
+        sampler = RepeatSampler(dataset, mini_repeat_count=1, batch_size=2, repeat_count=2)
         # Should output something like [4, 3, 4, 3, 0, 1, 0, 1, 2, 6, 2, 6]
         sampled = list(sampler)
         # Check that the length is doubled
@@ -84,7 +91,7 @@ def test_sampler_with_batch_size_and_drop(self):
 
     def test_sampler_with_mini_repeat_count_and_batch_size_1(self):
         dataset = ["a", "b", "c", "d", "e", "f", "g"]
-        sampler = RepeatRandomSampler(dataset, mini_repeat_count=2, batch_size=3, repeat_count=2)
+        sampler = RepeatSampler(dataset, mini_repeat_count=2, batch_size=3, repeat_count=2)
         # Should output something like [4, 4, 3, 3, 0, 0, 4, 4, 3, 3, 0, 0,
         #                               1, 1, 2, 2, 6, 6, 1, 1, 2, 2, 6, 6]
         sampled = list(sampler)
@@ -100,7 +107,7 @@ def test_sampler_with_mini_repeat_count_and_batch_size_1(self):
 
     def test_sampler_with_mini_repeat_count_and_batch_size_2(self):
         dataset = ["a", "b", "c", "d", "e", "f", "g"]
-        sampler = RepeatRandomSampler(dataset, mini_repeat_count=3, batch_size=2, repeat_count=2)
+        sampler = RepeatSampler(dataset, mini_repeat_count=3, batch_size=2, repeat_count=2)
         # Should output something like [4, 4, 4, 3, 3, 3, 4, 4, 4, 3, 3, 3,
         #                               0, 0, 0, 1, 1, 1, 0, 0, 0, 1, 1, 1,
         #                               2, 2, 2, 6, 6, 6, 2, 2, 2, 6, 6, 6]
@@ -118,7 +125,7 @@ def test_sampler_with_mini_repeat_count_and_batch_size_2(self):
 
     def test_sampler_with_mini_repeat_count_and_batch_size_3(self):
         dataset = ["a", "b", "c", "d", "e", "f", "g"]
-        sampler = RepeatRandomSampler(dataset, mini_repeat_count=2, batch_size=2, repeat_count=3)
+        sampler = RepeatSampler(dataset, mini_repeat_count=2, batch_size=2, repeat_count=3)
        # Should output something like [4, 4, 3, 3, 4, 4, 3, 3, 4, 4, 3, 3,
        #                               0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1,
        #                               2, 2, 6, 6, 2, 2, 6, 6, 2, 2, 6, 6]
```
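For reference, a minimal usage sketch of the behavior the new `test_sampler_no_shuffle` pins down. It is runnable only against a TRL build that includes this commit; the import path comes from the diff above.

```python
# Minimal sketch of the new no-shuffle behavior (assumes TRL with this commit).
from trl.trainer.grpo_trainer import RepeatSampler

dataset = ["a", "b", "c", "d", "e", "f", "g"]

# With shuffle=False the sampler walks the dataset in order, repeating each
# index mini_repeat_count times instead of drawing a random permutation.
sampler = RepeatSampler(dataset, mini_repeat_count=2, shuffle=False)
print(list(sampler))  # [0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6]
```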

trl/trainer/grpo_config.py

Lines changed: 6 additions & 0 deletions
```diff
@@ -59,6 +59,8 @@ class GRPOConfig(TrainingArguments):
             improving generation speed. However, disabling this option allows training models that exceed the VRAM
             capacity of a single GPU, albeit at the cost of slower generation. Disabling this option is not compatible
             with vLLM generation.
+        shuffle_dataset (`bool`, *optional*, defaults to `True`):
+            Whether to shuffle the training dataset.
 
     > Parameters that control generation
 
@@ -222,6 +224,10 @@
             "is not compatible with vLLM generation."
         },
     )
+    shuffle_dataset: Optional[bool] = field(
+        default=True,
+        metadata={"help": "Whether to shuffle the training dataset."},
+    )
 
     # Parameters that control generation
     temperature: float = field(
```
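From the user side, the new field is set like any other `GRPOConfig` argument. A hedged sketch; `output_dir` is a placeholder value, not part of this commit:

```python
from trl import GRPOConfig

# shuffle_dataset is the field added above; it defaults to True, so passing
# False opts out of shuffling the training dataset.
training_args = GRPOConfig(
    output_dir="grpo-no-shuffle",  # placeholder output directory
    shuffle_dataset=False,
)
```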

trl/trainer/grpo_trainer.py

Lines changed: 29 additions & 8 deletions
```diff
@@ -78,7 +78,7 @@
 RewardFunc = Union[str, PreTrainedModel, Callable[[list, list], list[float]]]
 
 
-class RepeatRandomSampler(Sampler):
+class RepeatSampler(Sampler):
     """
     Sampler that repeats the indices of a dataset in a structured manner.
 
@@ -91,6 +91,8 @@ class RepeatRandomSampler(Sampler):
             Number of unique indices per batch.
         repeat_count (`int`, *optional*, defaults to `1`):
             Number of times to repeat the full sampling process.
+        shuffle (`bool`, *optional*, defaults to `True`):
+            Whether to shuffle the dataset.
         seed (`int` or `None`, *optional*, defaults to `None`):
             Random seed for reproducibility (only affects this sampler).
 
@@ -132,21 +134,28 @@ def __init__(
         mini_repeat_count: int,
         batch_size: int = 1,
         repeat_count: int = 1,
+        shuffle: bool = True,
         seed: Optional[int] = None,
     ):
         self.data_source = data_source
         self.mini_repeat_count = mini_repeat_count
         self.batch_size = batch_size
         self.repeat_count = repeat_count
         self.num_samples = len(data_source)
+        self.shuffle = shuffle
         self.seed = seed
-        self.generator = torch.Generator()  # Create a local random generator
-        if seed is not None:
-            self.generator.manual_seed(seed)
+
+        if shuffle:
+            self.generator = torch.Generator()  # Create a local random generator
+            if seed is not None:
+                self.generator.manual_seed(seed)
 
     def __iter__(self):
-        # E.g., [2, 4, 3, 1, 0, 6, 5] (num_samples = 7)
-        indexes = torch.randperm(self.num_samples, generator=self.generator).tolist()
+        if self.shuffle:
+            # E.g., [2, 4, 3, 1, 0, 6, 5] (num_samples = 7)
+            indexes = torch.randperm(self.num_samples, generator=self.generator).tolist()
+        else:
+            indexes = list(range(self.num_samples))
 
         # [2, 4, 3, 1, 0, 6, 5]
         # -> [[2, 4, 3], [1, 0, 6], [5]] (batch_size = 3)
```
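To make the index layout concrete, here is a pure-Python paraphrase of what `__iter__` yields, following the chunking comments in the hunk above. It is an illustration, not the trainer's actual code path: the real class uses `torch.randperm` and yields indices lazily, and the helper name is invented for this sketch.

```python
import random
from typing import Optional


def repeat_sampler_indices(
    num_samples: int,
    mini_repeat_count: int,
    batch_size: int = 1,
    repeat_count: int = 1,
    shuffle: bool = True,
    seed: Optional[int] = None,
) -> list[int]:
    # Base order: a permutation when shuffling, else 0..num_samples-1.
    indexes = list(range(num_samples))
    if shuffle:
        random.Random(seed).shuffle(indexes)  # the real class uses torch.randperm
    # Chunk into batches of `batch_size` and drop the incomplete tail batch.
    batches = [indexes[i : i + batch_size] for i in range(0, num_samples, batch_size)]
    batches = [batch for batch in batches if len(batch) == batch_size]
    out = []
    for batch in batches:
        for _ in range(repeat_count):  # repeat each full batch repeat_count times
            for index in batch:
                out.extend([index] * mini_repeat_count)  # repeat each index
    return out


# Reproduces the expected output of test_sampler_no_shuffle:
assert repeat_sampler_indices(7, mini_repeat_count=2, shuffle=False) == [
    0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6,
]
```

The remaining hunks wire the flag through the trainer and keep the old class name working as a deprecated alias: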
```diff
@@ -166,6 +175,15 @@ def __len__(self) -> int:
         return self.num_samples * self.mini_repeat_count * self.repeat_count
 
 
+class RepeatRandomSampler(RepeatSampler):
+    def __init__(self, *args, **kwargs):
+        warnings.warn(
+            "RepeatRandomSampler is deprecated and will be removed in version 0.18. Use RepeatSampler instead.",
+            DeprecationWarning,
+        )
+        super().__init__(*args, **kwargs)
+
+
 # torch.nanstd doesn't exist, so we define it here
 def nanstd(tensor: torch.Tensor) -> torch.Tensor:
     """
@@ -481,6 +499,8 @@ def data_collator(features):  # No data collation is needed in GRPO
         self.mask_truncated_completions = args.mask_truncated_completions
 
         # Datasets
+        self.shuffle_dataset = args.shuffle_dataset
+
         if (
             isinstance(train_dataset, IterableDataset)
             or isinstance(eval_dataset, IterableDataset)
@@ -727,17 +747,18 @@ def _get_train_sampler(self) -> Sampler:
             * self.accelerator.num_processes
             * self.args.gradient_accumulation_steps
         )
-        return RepeatRandomSampler(
+        return RepeatSampler(
             data_source=self.train_dataset,
             mini_repeat_count=self.num_generations,
             batch_size=effective_batch_size // self.num_generations,
             repeat_count=self.num_iterations * self.args.gradient_accumulation_steps,
+            shuffle=self.shuffle_dataset,
             seed=self.args.seed,
         )
 
     def _get_eval_sampler(self, eval_dataset) -> Sampler:
         # See _get_train_sampler for an explanation of the sampler.
-        return RepeatRandomSampler(
+        return RepeatSampler(
             data_source=eval_dataset,
             mini_repeat_count=self.num_generations,
             seed=self.args.seed,
```
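Putting it together: the flag set on `GRPOConfig` is stored as `self.shuffle_dataset` in `__init__` and forwarded to `RepeatSampler` in `_get_train_sampler`, per the hunks above. A hedged end-to-end sketch; the model name, dataset, and toy reward function are placeholders and not part of this commit:

```python
from datasets import load_dataset
from trl import GRPOConfig, GRPOTrainer


def reward_len(completions, **kwargs):
    # Toy reward: prefer shorter completions (illustrative only).
    return [-float(len(completion)) for completion in completions]


dataset = load_dataset("trl-lib/tldr", split="train")  # placeholder dataset

trainer = GRPOTrainer(
    model="Qwen/Qwen2-0.5B-Instruct",  # placeholder model name
    reward_funcs=reward_len,
    args=GRPOConfig(output_dir="out", shuffle_dataset=False),
    train_dataset=dataset,
)
# trainer.train() would now draw prompts in dataset order rather than shuffled.
```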
