
Commit 294f35b

qgallouedec and lewtun authored
☝️ [GRPO] Generate once per effective batch (#3283)
Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
1 parent 9874b3a commit 294f35b
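
In short: rather than sampling completions at every micro-batch, generation happens once per effective batch (num_processes * per_device_train_batch_size * gradient_accumulation_steps), and the resulting completions are replayed slice by slice across the accumulation steps. Below is a minimal, runnable sketch of that buffering idea, using hypothetical names and a stubbed generator rather than the actual GRPOTrainer internals; the batch-size values are taken from the test added in this commit.

# Sketch of "generate once per effective batch" (hypothetical, simplified;
# not the actual GRPOTrainer internals). Completions for the whole effective
# batch are sampled once, then replayed slice by slice across the gradient
# accumulation steps instead of being regenerated at every micro-step.

def stub_generate(prompts, num_generations):
    # Stand-in for model generation: num_generations completions per prompt.
    return [f"{p}/gen{i}" for p in prompts for i in range(num_generations)]

per_device_train_batch_size = 3
gradient_accumulation_steps = 2
num_generations = 6  # larger than the per-device batch, spans two micro-steps

prompts = ["p0"]  # 1 prompt * 6 generations = effective batch of 6
buffer = []

for step in range(4):  # 4 micro-steps = 2 optimizer steps
    if step % gradient_accumulation_steps == 0:
        buffer = stub_generate(prompts, num_generations)  # one generation pass
    start = (step % gradient_accumulation_steps) * per_device_train_batch_size
    micro_batch = buffer[start : start + per_device_train_batch_size]
    print(f"step {step}: {micro_batch}")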

File tree

3 files changed: +189, -63 lines

tests/test_grpo_trainer.py

Lines changed: 31 additions & 0 deletions
@@ -1038,3 +1038,34 @@ def test_training_with_mask_truncated_completions_all_masked(self):
         for n, param in previous_trainable_params.items():
             new_param = trainer.model.get_parameter(n)
             self.assertTrue(torch.equal(param, new_param), f"Parameter {n} has changed.")
+
+    def test_training_num_generations_larger_than_batch_size(self):
+        dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")
+
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            training_args = GRPOConfig(
+                output_dir=tmp_dir,
+                learning_rate=0.1,  # increase the learning rate to speed up the test
+                per_device_train_batch_size=3,  # reduce the batch size to reduce memory usage
+                max_completion_length=8,  # reduce the completion length to reduce memory usage
+                num_generations=6,  # the number of generations is larger than the batch size, but
+                gradient_accumulation_steps=2,  # gradient accumulation should allow that
+                report_to="none",
+            )
+            trainer = GRPOTrainer(
+                model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
+                reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
+                args=training_args,
+                train_dataset=dataset,
+            )
+
+            previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
+
+            trainer.train()
+
+            self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+
+            # Check that the params have changed
+            for n, param in previous_trainable_params.items():
+                new_param = trainer.model.get_parameter(n)
+                self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.")
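
Note: with a single process, the effective batch size in this test is 1 * 3 * 2 = 6, which is evenly divisible by num_generations=6 even though 6 exceeds the per-device batch size; each prompt's six completions simply span two consecutive micro-batches.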

trl/trainer/grpo_config.py

Lines changed: 4 additions & 4 deletions
@@ -50,8 +50,8 @@ class GRPOConfig(TrainingArguments):
         max_prompt_length (`int` or `None`, *optional*, defaults to `512`):
             Maximum length of the prompt. If the prompt is longer than this value, it will be truncated left.
         num_generations (`int` or `None`, *optional*, defaults to `8`):
-            Number of generations per prompt to sample. The global batch size (num_processes * per_device_batch_size)
-            must be divisible by this value.
+            Number of generations per prompt to sample. The effective batch size (num_processes *
+            per_device_batch_size * gradient_accumulation_steps) must be evenly divisible by this value.
         max_completion_length (`int` or `None`, *optional*, defaults to `256`):
             Maximum length of the generated completion.
         ds3_gather_for_generation (`bool`, *optional*, defaults to `True`):
@@ -205,8 +205,8 @@ class GRPOConfig(TrainingArguments):
     num_generations: Optional[int] = field(
         default=8,
         metadata={
-            "help": "Number of generations to sample. The global batch size (num_processes * per_device_batch_size) "
-            "must be divisible by this value."
+            "help": "Number of generations to sample. The effective batch size (num_processes * per_device_batch_size "
+            "* gradient_accumulation_steps) must be evenly divisible by this value."
         },
     )
     max_completion_length: Optional[int] = field(
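
The constraint in the updated docstring is easy to sanity-check before launching a run. Here is a standalone sketch of the arithmetic, using the values from the test above; the actual validation is performed inside GRPOTrainer, not by this snippet.

# Sanity-check of the divisibility rule from the docstring above
# (standalone sketch; the real validation lives in GRPOTrainer).
num_processes = 1
per_device_train_batch_size = 3
gradient_accumulation_steps = 2
num_generations = 6

# Effective batch size: samples contributing to one optimizer step.
effective_batch_size = (
    num_processes * per_device_train_batch_size * gradient_accumulation_steps
)

# num_generations may exceed per_device_train_batch_size (6 > 3 here),
# provided it evenly divides the effective batch size (6 divides 6).
assert effective_batch_size % num_generations == 0, (
    f"effective batch size ({effective_batch_size}) is not evenly "
    f"divisible by num_generations ({num_generations})"
)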
