Skip to content

Commit 9379a89

Browse files
xysheng-colossal, YeAnbang, TongLi3701, Tong Li, pre-commit-ci[bot]
authored
[feat][npu] Merge from grpo-latest (#6346)
* move prompt-level-filtering to buffer side * move prompt-level-filtering to buffer side * remove redundant code and fix bugs * fix metric calculation * fix missing tags parameter * address conversation * add overlength sample count (#6332) Co-authored-by: Tong Li <tong.li35271158@gmail.com> * address conversation * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix typo and parameter description * [feat] Update requirements and set return logits False --------- Co-authored-by: YeAnbang <anbangy2@outlook.com> Co-authored-by: Tong Li <tong.li352711588@gmail.com> Co-authored-by: Tong Li <tong.li35271158@gmail.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 84f523a commit 9379a89

File tree

8 files changed

+269
-143
lines changed

8 files changed

+269
-143
lines changed

applications/ColossalChat/coati/distributed/consumer.py

Lines changed: 93 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -122,26 +122,102 @@ def loop(self) -> None:
122122
# receive data from producers
123123
for r in range(self.num_producers):
124124
print(f"[T{dist.get_rank()}] Recv data episode {episode} step {step} from {r}")
125-
self.buffer.extend(
126-
unbind_batch(
127-
ray_broadcast_tensor_dict(
128-
None, src=0, device=self.device, group_name=f"sync_data_{r}"
129-
)
130-
)
125+
raw_batch = ray_broadcast_tensor_dict(
126+
None, src=0, device=self.device, group_name=f"sync_data_{r}"
131127
)
132-
while len(self.buffer) >= self.dp_size * self.minibatch_size:
133-
batches = self.buffer[
134-
self.dp_rank * self.minibatch_size : (self.dp_rank + 1) * self.minibatch_size
128+
# calculate group reward et al. filtering. As only the filtered group will be used for training (which is incomplete),
129+
# we need to calculate the metrics before filtering here for logging
130+
# [batch_size, num_generations, ...] -> [batch_size * num_generations, ...]
131+
raw_batch_with_reward = self.calculate_reward(
132+
{k: v.view(-1, v.size(-1)) if k != "temperature" else v for k, v in raw_batch.items()}
133+
)
134+
raw_batch_with_reward = {
135+
k: v.view(-1, self.num_generations, v.size(-1)) if k != "temperature" else v
136+
for k, v in raw_batch_with_reward.items()
137+
}
138+
# [batch_size, num_generations] -> [batch_size]
139+
reward = raw_batch_with_reward["reward"][:, :, 0]
140+
format_acc = raw_batch_with_reward["format_acc"][:, :, 0]
141+
ans_acc = raw_batch_with_reward["ans_acc"][:, :, 0]
142+
response_len = (
143+
raw_batch_with_reward["response_idx"][:, :, 1]
144+
- raw_batch_with_reward["response_idx"][:, :, 0]
145+
+ 1
146+
).type(torch.float32)
147+
effective_group_mask = None
148+
if self.filter_range is not None and self.grpo_config.get("dynamic_batching", True):
149+
# filter the group based on the reward and accuracy
150+
group_ans_acc_mean = ans_acc.mean(dim=1)
151+
effective_group_mask = torch.logical_and(
152+
group_ans_acc_mean > self.filter_range[0], group_ans_acc_mean < self.filter_range[1]
153+
)
154+
raw_batch_with_reward = unbind_batch(raw_batch_with_reward) # List[Dict[str, torch.Tensor]]
155+
for group_idx, group_with_reward in enumerate(raw_batch_with_reward):
156+
self.buffer.append(
157+
[
158+
(
159+
group_with_reward
160+
if effective_group_mask is None or effective_group_mask[group_idx]
161+
else None
162+
),
163+
reward[group_idx],
164+
format_acc[group_idx],
165+
ans_acc[group_idx],
166+
response_len[group_idx],
167+
]
168+
)
169+
if effective_group_mask is not None:
170+
print(
171+
f"[T{dist.get_rank()}] Filter recv data: {len(raw_batch_with_reward)} -> {torch.sum(effective_group_mask).cpu().item()} effective groups"
172+
)
173+
# mapping the effective group to the raw group for indexing
174+
effective_group_to_raw_group_mapping = {}
175+
for buffer_idx in range(len(self.buffer)):
176+
if self.buffer[buffer_idx][0] is not None:
177+
effective_group_to_raw_group_mapping[len(effective_group_to_raw_group_mapping)] = (
178+
buffer_idx
179+
)
180+
print(
181+
f"[T{dist.get_rank()}] Collect Effective Prompt: {len(effective_group_to_raw_group_mapping)}/{self.dp_size * self.minibatch_size}"
182+
)
183+
184+
while len(effective_group_to_raw_group_mapping) >= self.dp_size * self.minibatch_size:
185+
# on each dp_rank, we use minibatch_size effective samples to form a batch
186+
batches = [
187+
self.buffer[effective_group_to_raw_group_mapping[i]]
188+
for i in range(
189+
self.dp_rank * self.minibatch_size, (self.dp_rank + 1) * self.minibatch_size
190+
)
135191
]
136-
batch = bind_batch(batches)
192+
# every dp_rank will receive a complete mini-batch, no need to sync within step() later
193+
# each mini-batch use the first self.dp_size * minibatch_size effective samples
194+
raw_mini_batches = self.buffer[
195+
: effective_group_to_raw_group_mapping[self.dp_size * self.minibatch_size - 1] + 1
196+
] # include the last effective sample
197+
raw_mini_batches_metric_dict = {
198+
"raw_train_mini_batch_reward": [t[1] for t in raw_mini_batches],
199+
"raw_train_mini_batch_format_acc": [t[2] for t in raw_mini_batches],
200+
"raw_train_mini_batch_ans_acc": [t[3] for t in raw_mini_batches],
201+
"raw_train_mini_batch_response_len": [t[4] for t in raw_mini_batches],
202+
}
203+
batch = bind_batch([t[0] for t in batches])
137204
batch = post_recv(batch)
138-
loss, excessive_prompts_idx = self.step(i, pbar, **batch)
139-
140-
if excessive_prompts_idx is not None:
141-
excessive_prompts = [self.buffer[idx] for idx in excessive_prompts_idx]
142-
self.buffer = excessive_prompts + self.buffer[self.dp_size * self.minibatch_size :]
143-
else:
144-
self.buffer = self.buffer[self.dp_size * self.minibatch_size :]
205+
loss = self.step(i, pbar, **batch, **raw_mini_batches_metric_dict)
206+
self.buffer = self.buffer[
207+
effective_group_to_raw_group_mapping[self.dp_size * self.minibatch_size - 1] + 1 :
208+
]
209+
# recalculate the effective group to raw group mapping
210+
effective_group_to_raw_group_mapping_size_before = len(effective_group_to_raw_group_mapping)
211+
effective_group_to_raw_group_mapping = {}
212+
for buffer_idx in range(len(self.buffer)):
213+
if self.buffer[buffer_idx][0] is not None:
214+
effective_group_to_raw_group_mapping[len(effective_group_to_raw_group_mapping)] = (
215+
buffer_idx
216+
)
217+
assert (
218+
len(effective_group_to_raw_group_mapping)
219+
== effective_group_to_raw_group_mapping_size_before - self.dp_size * self.minibatch_size
220+
)
145221
if loss is not None:
146222
pbar.set_postfix({"loss": loss})
147223
i += 1

0 commit comments

Comments
 (0)