@@ -1038,3 +1038,34 @@ def test_training_with_mask_truncated_completions_all_masked(self):
             for n, param in previous_trainable_params.items():
                 new_param = trainer.model.get_parameter(n)
                 self.assertTrue(torch.equal(param, new_param), f"Parameter {n} has changed.")
+
+    def test_training_num_generations_larger_than_batch_size(self):
+        dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")
+
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            training_args = GRPOConfig(
+                output_dir=tmp_dir,
+                learning_rate=0.1,  # increase the learning rate to speed up the test
+                per_device_train_batch_size=3,  # reduce the batch size to reduce memory usage
+                max_completion_length=8,  # reduce the completion length to reduce memory usage
+                num_generations=6,  # the number of generations is larger than the batch size, but
+                gradient_accumulation_steps=2,  # gradient accumulation should allow that
+                report_to="none",
+            )
+            trainer = GRPOTrainer(
+                model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
+                reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
+                args=training_args,
+                train_dataset=dataset,
+            )
+
+            previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
+
+            trainer.train()
+
+            self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+
+            # Check that the params have changed
+            for n, param in previous_trainable_params.items():
+                new_param = trainer.model.get_parameter(n)
+                self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.")
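For context on why this configuration is valid: with `per_device_train_batch_size=3`, a single forward batch cannot hold a full group of `num_generations=6` completions, but accumulating over `gradient_accumulation_steps=2` yields an effective batch of 6 prompts, which is divisible by `num_generations`. Below is a minimal sketch of that divisibility arithmetic, not TRL internals; `num_processes=1` is an assumption matching a single-GPU test run.

```python
# Sketch (assumption, not TRL source): why num_generations=6 is compatible with
# per_device_train_batch_size=3 once gradient accumulation is taken into account.
per_device_train_batch_size = 3
gradient_accumulation_steps = 2
num_processes = 1  # assumed: single process, as in this CI test

# Prompts effectively processed per optimizer step:
effective_batch = per_device_train_batch_size * gradient_accumulation_steps * num_processes

# Each prompt's group of num_generations completions must tile the effective
# batch exactly, so the effective batch size must be divisible by it.
num_generations = 6
assert effective_batch % num_generations == 0  # 6 % 6 == 0 -> valid config
```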