@@ -107,6 +107,37 @@ def state_dict(self) -> Dict[str, torch.Tensor]:
107
107
def step (self , step_idx : int , ** kwargs ) -> Optional [float ]:
108
108
raise NotImplementedError
109
109
110
+ def prepare_mini_batch (self , effective_group_to_raw_group_mapping : Dict [int , int ]) -> Dict [str , torch .Tensor ]:
111
+ """
112
+ Prepare a mini-batch from the effective group to raw group mapping.
113
+ This method is used to create a mini-batch for training.
114
+ """
115
+ batches = [
116
+ self .buffer [effective_group_to_raw_group_mapping [i ]]
117
+ for i in range (self .dp_rank * self .minibatch_size , (self .dp_rank + 1 ) * self .minibatch_size )
118
+ ]
119
+ # every dp_rank will receive a complete mini-batch, no need to sync within step() later
120
+ # each mini-batch use the first self.dp_size * minibatch_size effective samples
121
+ raw_mini_batches = self .buffer [
122
+ : effective_group_to_raw_group_mapping [self .dp_size * self .minibatch_size - 1 ] + 1
123
+ ] # include the last effective sample
124
+ raw_mini_batches_metric_dict = {
125
+ "raw_train_mini_batch_reward" : [t [1 ] for t in raw_mini_batches ],
126
+ "raw_train_mini_batch_format_acc" : [t [2 ] for t in raw_mini_batches ],
127
+ "raw_train_mini_batch_ans_acc" : [t [3 ] for t in raw_mini_batches ],
128
+ "raw_train_mini_batch_response_len" : [t [4 ] for t in raw_mini_batches ],
129
+ }
130
+ batch = bind_batch ([t [0 ] for t in batches ])
131
+ batch = post_recv (batch )
132
+ return batch , raw_mini_batches_metric_dict
133
+
134
+ def calculate_effective_group_to_raw_group_mapping (self ):
135
+ effective_group_to_raw_group_mapping = {}
136
+ for buffer_idx in range (len (self .buffer )):
137
+ if self .buffer [buffer_idx ][0 ] is not None :
138
+ effective_group_to_raw_group_mapping [len (effective_group_to_raw_group_mapping )] = buffer_idx
139
+ return effective_group_to_raw_group_mapping
140
+
110
141
def loop (self ) -> None :
111
142
print (
112
143
f"Consumer{ self .rank } num_update: { self .num_update_per_episode } , num_recv: { self .num_recv_per_update } , nmb: { self .num_microbatches } "
@@ -121,6 +152,38 @@ def loop(self) -> None:
121
152
torch .cuda .reset_peak_memory_stats ()
122
153
i = 0
123
154
for _ in range (self .num_recv_per_update ):
155
+ # after sync model, do not wait for more data to arrive as rollout takes time, use buffered data
156
+ effective_group_to_raw_group_mapping = self .calculate_effective_group_to_raw_group_mapping ()
157
+ while len (effective_group_to_raw_group_mapping ) > max (
158
+ self .dp_size * self .batch_size
159
+ - self .dp_size
160
+ * self .minibatch_size
161
+ * self .grpo_config .get ("num_minibatch_during_rollout" , 1 ),
162
+ self .dp_size * self .minibatch_size ,
163
+ ):
164
+ self .profiler .log (
165
+ f"Still have { len (effective_group_to_raw_group_mapping )} effective groups, greater than { self .dp_size * self .minibatch_size } , start training"
166
+ )
167
+ batch , raw_mini_batches_metric_dict = self .prepare_mini_batch (
168
+ effective_group_to_raw_group_mapping
169
+ )
170
+ self .profiler .enter ("step" )
171
+ loss = self .step (i , pbar , ** batch , ** raw_mini_batches_metric_dict )
172
+ self .profiler .exit ("step" )
173
+ self .buffer = self .buffer [
174
+ effective_group_to_raw_group_mapping [self .dp_size * self .minibatch_size - 1 ] + 1 :
175
+ ]
176
+ # recalculate the effective group to raw group mapping
177
+ effective_group_to_raw_group_mapping_size_before = len (effective_group_to_raw_group_mapping )
178
+ effective_group_to_raw_group_mapping = self .calculate_effective_group_to_raw_group_mapping ()
179
+ assert (
180
+ len (effective_group_to_raw_group_mapping )
181
+ == effective_group_to_raw_group_mapping_size_before - self .dp_size * self .minibatch_size
182
+ )
183
+ if loss is not None :
184
+ pbar .set_postfix ({"loss" : loss })
185
+ i += 1
186
+
124
187
# receive data from producers
125
188
for r in range (self .num_producers ):
126
189
print (f"[T{ dist .get_rank ()} ] Recv data episode { episode } step { step } from { r } " )
@@ -170,37 +233,20 @@ def loop(self) -> None:
170
233
f"[T{ dist .get_rank ()} ] Filter recv data: { len (raw_batch )} -> { torch .sum (effective_group_mask ).cpu ().item ()} effective groups"
171
234
)
172
235
# mapping the effective group to the raw group for indexing
173
- effective_group_to_raw_group_mapping = {}
174
- for buffer_idx in range (len (self .buffer )):
175
- if self .buffer [buffer_idx ][0 ] is not None :
176
- effective_group_to_raw_group_mapping [len (effective_group_to_raw_group_mapping )] = (
177
- buffer_idx
178
- )
236
+ effective_group_to_raw_group_mapping = self .calculate_effective_group_to_raw_group_mapping ()
179
237
print (
180
238
f"[T{ dist .get_rank ()} ] Collect Effective Prompt: { len (effective_group_to_raw_group_mapping )} /{ self .dp_size * self .minibatch_size } "
181
239
)
182
240
183
- while len (effective_group_to_raw_group_mapping ) >= self .dp_size * self .minibatch_size :
241
+ while len (effective_group_to_raw_group_mapping ) > self .dp_size * self .batch_size :
242
+ self .profiler .log (
243
+ f"Received { len (effective_group_to_raw_group_mapping )} effective groups, greater than { self .dp_size * self .batch_size } , start training after recv"
244
+ )
245
+ # always keep at least dp_size * batch_size effective samples in the buffer for training during the rollout times after each sync model
184
246
# on each dp_rank, we use minibatch_size effective samples to form a batch
185
- batches = [
186
- self .buffer [effective_group_to_raw_group_mapping [i ]]
187
- for i in range (
188
- self .dp_rank * self .minibatch_size , (self .dp_rank + 1 ) * self .minibatch_size
189
- )
190
- ]
191
- # every dp_rank will receive a complete mini-batch, no need to sync within step() later
192
- # each mini-batch use the first self.dp_size * minibatch_size effective samples
193
- raw_mini_batches = self .buffer [
194
- : effective_group_to_raw_group_mapping [self .dp_size * self .minibatch_size - 1 ] + 1
195
- ] # include the last effective sample
196
- raw_mini_batches_metric_dict = {
197
- "raw_train_mini_batch_reward" : [t [1 ] for t in raw_mini_batches ],
198
- "raw_train_mini_batch_format_acc" : [t [2 ] for t in raw_mini_batches ],
199
- "raw_train_mini_batch_ans_acc" : [t [3 ] for t in raw_mini_batches ],
200
- "raw_train_mini_batch_response_len" : [t [4 ] for t in raw_mini_batches ],
201
- }
202
- batch = bind_batch ([t [0 ] for t in batches ])
203
- batch = post_recv (batch )
247
+ batch , raw_mini_batches_metric_dict = self .prepare_mini_batch (
248
+ effective_group_to_raw_group_mapping
249
+ )
204
250
self .profiler .enter ("step" )
205
251
loss = self .step (i , pbar , ** batch , ** raw_mini_batches_metric_dict )
206
252
self .profiler .exit ("step" )
@@ -209,12 +255,7 @@ def loop(self) -> None:
209
255
]
210
256
# recalculate the effective group to raw group mapping
211
257
effective_group_to_raw_group_mapping_size_before = len (effective_group_to_raw_group_mapping )
212
- effective_group_to_raw_group_mapping = {}
213
- for buffer_idx in range (len (self .buffer )):
214
- if self .buffer [buffer_idx ][0 ] is not None :
215
- effective_group_to_raw_group_mapping [len (effective_group_to_raw_group_mapping )] = (
216
- buffer_idx
217
- )
258
+ effective_group_to_raw_group_mapping = self .calculate_effective_group_to_raw_group_mapping ()
218
259
assert (
219
260
len (effective_group_to_raw_group_mapping )
220
261
== effective_group_to_raw_group_mapping_size_before - self .dp_size * self .minibatch_size
0 commit comments