@@ -122,26 +122,102 @@ def loop(self) -> None:
                         # receive data from producers
                         for r in range(self.num_producers):
                             print(f"[T{dist.get_rank()}] Recv data episode {episode} step {step} from {r}")
-                            self.buffer.extend(
-                                unbind_batch(
-                                    ray_broadcast_tensor_dict(
-                                        None, src=0, device=self.device, group_name=f"sync_data_{r}"
-                                    )
-                                )
+                            raw_batch = ray_broadcast_tensor_dict(
+                                None, src=0, device=self.device, group_name=f"sync_data_{r}"
                             )
-                        while len(self.buffer) >= self.dp_size * self.minibatch_size:
-                            batches = self.buffer[
-                                self.dp_rank * self.minibatch_size : (self.dp_rank + 1) * self.minibatch_size
+                            # calculate group reward et al. filtering. As only the filtered group will be used for training (which is incomplete),
+                            # we need to calculate the metrics before filtering here for logging
+                            # [batch_size, num_generations, ...] -> [batch_size * num_generations, ...]
+                            raw_batch_with_reward = self.calculate_reward(
+                                {k: v.view(-1, v.size(-1)) if k != "temperature" else v for k, v in raw_batch.items()}
+                            )
+                            raw_batch_with_reward = {
+                                k: v.view(-1, self.num_generations, v.size(-1)) if k != "temperature" else v
+                                for k, v in raw_batch_with_reward.items()
+                            }
+                            # [batch_size, num_generations] -> [batch_size]
+                            reward = raw_batch_with_reward["reward"][:, :, 0]
+                            format_acc = raw_batch_with_reward["format_acc"][:, :, 0]
+                            ans_acc = raw_batch_with_reward["ans_acc"][:, :, 0]
+                            response_len = (
+                                raw_batch_with_reward["response_idx"][:, :, 1]
+                                - raw_batch_with_reward["response_idx"][:, :, 0]
+                                + 1
+                            ).type(torch.float32)
+                            effective_group_mask = None
+                            if self.filter_range is not None and self.grpo_config.get("dynamic_batching", True):
+                                # filter the group based on the reward and accuracy
+                                group_ans_acc_mean = ans_acc.mean(dim=1)
+                                effective_group_mask = torch.logical_and(
+                                    group_ans_acc_mean > self.filter_range[0], group_ans_acc_mean < self.filter_range[1]
+                                )
+                            raw_batch_with_reward = unbind_batch(raw_batch_with_reward)  # List[Dict[str, torch.Tensor]]
+                            for group_idx, group_with_reward in enumerate(raw_batch_with_reward):
+                                self.buffer.append(
+                                    [
+                                        (
+                                            group_with_reward
+                                            if effective_group_mask is None or effective_group_mask[group_idx]
+                                            else None
+                                        ),
+                                        reward[group_idx],
+                                        format_acc[group_idx],
+                                        ans_acc[group_idx],
+                                        response_len[group_idx],
+                                    ]
+                                )
+                            if effective_group_mask is not None:
+                                print(
+                                    f"[T{dist.get_rank()}] Filter recv data: {len(raw_batch_with_reward)} -> {torch.sum(effective_group_mask).cpu().item()} effective groups"
+                                )
+                        # mapping the effective group to the raw group for indexing
+                        effective_group_to_raw_group_mapping = {}
+                        for buffer_idx in range(len(self.buffer)):
+                            if self.buffer[buffer_idx][0] is not None:
+                                effective_group_to_raw_group_mapping[len(effective_group_to_raw_group_mapping)] = (
+                                    buffer_idx
+                                )
+                        print(
+                            f"[T{dist.get_rank()}] Collect Effective Prompt: {len(effective_group_to_raw_group_mapping)}/{self.dp_size * self.minibatch_size}"
+                        )
+
+                        while len(effective_group_to_raw_group_mapping) >= self.dp_size * self.minibatch_size:
+                            # on each dp_rank, we use minibatch_size effective samples to form a batch
+                            batches = [
+                                self.buffer[effective_group_to_raw_group_mapping[i]]
+                                for i in range(
+                                    self.dp_rank * self.minibatch_size, (self.dp_rank + 1) * self.minibatch_size
+                                )
                             ]
-                            batch = bind_batch(batches)
+                            # every dp_rank will receive a complete mini-batch, no need to sync within step() later
+                            # each mini-batch use the first self.dp_size * minibatch_size effective samples
+                            raw_mini_batches = self.buffer[
+                                : effective_group_to_raw_group_mapping[self.dp_size * self.minibatch_size - 1] + 1
+                            ]  # include the last effective sample
+                            raw_mini_batches_metric_dict = {
+                                "raw_train_mini_batch_reward": [t[1] for t in raw_mini_batches],
+                                "raw_train_mini_batch_format_acc": [t[2] for t in raw_mini_batches],
+                                "raw_train_mini_batch_ans_acc": [t[3] for t in raw_mini_batches],
+                                "raw_train_mini_batch_response_len": [t[4] for t in raw_mini_batches],
+                            }
+                            batch = bind_batch([t[0] for t in batches])
                             batch = post_recv(batch)
-                            loss, excessive_prompts_idx = self.step(i, pbar, **batch)
-
-                            if excessive_prompts_idx is not None:
-                                excessive_prompts = [self.buffer[idx] for idx in excessive_prompts_idx]
-                                self.buffer = excessive_prompts + self.buffer[self.dp_size * self.minibatch_size :]
-                            else:
-                                self.buffer = self.buffer[self.dp_size * self.minibatch_size :]
+                            loss = self.step(i, pbar, **batch, **raw_mini_batches_metric_dict)
+                            self.buffer = self.buffer[
+                                effective_group_to_raw_group_mapping[self.dp_size * self.minibatch_size - 1] + 1 :
+                            ]
+                            # recalculate the effective group to raw group mapping
+                            effective_group_to_raw_group_mapping_size_before = len(effective_group_to_raw_group_mapping)
+                            effective_group_to_raw_group_mapping = {}
+                            for buffer_idx in range(len(self.buffer)):
+                                if self.buffer[buffer_idx][0] is not None:
+                                    effective_group_to_raw_group_mapping[len(effective_group_to_raw_group_mapping)] = (
+                                        buffer_idx
+                                    )
+                            assert (
+                                len(effective_group_to_raw_group_mapping)
+                                == effective_group_to_raw_group_mapping_size_before - self.dp_size * self.minibatch_size
+                            )
                             if loss is not None:
                                 pbar.set_postfix({"loss": loss})
                             i += 1
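
The idea of the change is that each received producer batch is scored per prompt group, and a group only becomes "effective" (trainable) when its mean answer accuracy falls strictly inside filter_range; groups that fail the check stay in the buffer as None placeholders so their metrics can still be logged, and the consumer keeps a mapping from effective indices to raw buffer positions that it rebuilds after every drain. The sketch below is a minimal, standalone illustration of that filtering and mapping logic under the patch's buffer layout; the helper names filter_groups and build_effective_mapping (and the toy values) are illustrative and do not exist in the patch.

# Minimal sketch of the dynamic-batching filter and the effective->raw mapping.
# Helper names and the toy buffer below are illustrative, not from the patch.
import torch


def filter_groups(ans_acc: torch.Tensor, filter_range: tuple) -> torch.Tensor:
    # ans_acc: [batch_size, num_generations]. Keep a group only when its mean
    # answer accuracy lies strictly inside (low, high), i.e. drop groups that
    # are all-wrong or all-correct and carry no learning signal for GRPO.
    group_mean = ans_acc.mean(dim=1)
    return torch.logical_and(group_mean > filter_range[0], group_mean < filter_range[1])


def build_effective_mapping(buffer: list) -> dict:
    # Buffer entries follow the patch's layout [group_or_None, reward, format_acc,
    # ans_acc, response_len]; only entries whose group survived filtering count.
    return {
        eff_idx: raw_idx
        for eff_idx, raw_idx in enumerate(
            i for i, entry in enumerate(buffer) if entry[0] is not None
        )
    }


ans_acc = torch.tensor([[1.0, 1.0, 1.0, 1.0], [0.0, 1.0, 0.0, 1.0], [0.0, 0.0, 0.0, 0.0]])
mask = filter_groups(ans_acc, (0.0, 1.0))  # tensor([False,  True, False])

buffer = [
    [None, 0.9, 1.0, 1.0, 12.0],      # filtered out (all generations correct)
    [{"g": 1}, 0.5, 1.0, 0.5, 20.0],  # effective group
    [None, 0.0, 0.0, 0.0, 48.0],      # filtered out (all generations wrong)
]
print(build_effective_mapping(buffer))  # {0: 1}

Keeping the filtered-out groups in the buffer as None placeholders (rather than dropping them) is what lets raw_mini_batches_metric_dict log reward, accuracy, and response-length statistics over everything that was generated, while bind_batch only ever assembles the surviving groups into a training mini-batch.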