Skip to content

Commit c2561f8

Browse files
committed
fix bugs
1 parent ff6696a commit c2561f8

File tree

6 files changed

+90
-38
lines changed

6 files changed

+90
-38
lines changed

applications/ColossalChat/coati/distributed/consumer.py

Lines changed: 73 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,37 @@ def state_dict(self) -> Dict[str, torch.Tensor]:
107107
def step(self, step_idx: int, **kwargs) -> Optional[float]:
108108
raise NotImplementedError
109109

110+
def prepare_mini_batch(self, effective_group_to_raw_group_mapping: Dict[int, int]) -> Dict[str, torch.Tensor]:
111+
"""
112+
Prepare a mini-batch from the effective group to raw group mapping.
113+
This method is used to create a mini-batch for training.
114+
"""
115+
batches = [
116+
self.buffer[effective_group_to_raw_group_mapping[i]]
117+
for i in range(self.dp_rank * self.minibatch_size, (self.dp_rank + 1) * self.minibatch_size)
118+
]
119+
# every dp_rank will receive a complete mini-batch, no need to sync within step() later
120+
# each mini-batch use the first self.dp_size * minibatch_size effective samples
121+
raw_mini_batches = self.buffer[
122+
: effective_group_to_raw_group_mapping[self.dp_size * self.minibatch_size - 1] + 1
123+
] # include the last effective sample
124+
raw_mini_batches_metric_dict = {
125+
"raw_train_mini_batch_reward": [t[1] for t in raw_mini_batches],
126+
"raw_train_mini_batch_format_acc": [t[2] for t in raw_mini_batches],
127+
"raw_train_mini_batch_ans_acc": [t[3] for t in raw_mini_batches],
128+
"raw_train_mini_batch_response_len": [t[4] for t in raw_mini_batches],
129+
}
130+
batch = bind_batch([t[0] for t in batches])
131+
batch = post_recv(batch)
132+
return batch, raw_mini_batches_metric_dict
133+
134+
def calculate_effective_group_to_raw_group_mapping(self):
135+
effective_group_to_raw_group_mapping = {}
136+
for buffer_idx in range(len(self.buffer)):
137+
if self.buffer[buffer_idx][0] is not None:
138+
effective_group_to_raw_group_mapping[len(effective_group_to_raw_group_mapping)] = buffer_idx
139+
return effective_group_to_raw_group_mapping
140+
110141
def loop(self) -> None:
111142
print(
112143
f"Consumer{self.rank} num_update: {self.num_update_per_episode}, num_recv: {self.num_recv_per_update}, nmb: {self.num_microbatches}"
@@ -121,6 +152,38 @@ def loop(self) -> None:
121152
torch.cuda.reset_peak_memory_stats()
122153
i = 0
123154
for _ in range(self.num_recv_per_update):
155+
# after sync model, do not wait for more data to arrive as rollout takes time, use buffered data
156+
effective_group_to_raw_group_mapping = self.calculate_effective_group_to_raw_group_mapping()
157+
while len(effective_group_to_raw_group_mapping) > max(
158+
self.dp_size * self.batch_size
159+
- self.dp_size
160+
* self.minibatch_size
161+
* self.grpo_config.get("num_minibatch_during_rollout", 1),
162+
self.dp_size * self.minibatch_size,
163+
):
164+
self.profiler.log(
165+
f"Still have {len(effective_group_to_raw_group_mapping)} effective groups, greater than {self.dp_size * self.minibatch_size}, start training"
166+
)
167+
batch, raw_mini_batches_metric_dict = self.prepare_mini_batch(
168+
effective_group_to_raw_group_mapping
169+
)
170+
self.profiler.enter("step")
171+
loss = self.step(i, pbar, **batch, **raw_mini_batches_metric_dict)
172+
self.profiler.exit("step")
173+
self.buffer = self.buffer[
174+
effective_group_to_raw_group_mapping[self.dp_size * self.minibatch_size - 1] + 1 :
175+
]
176+
# recalculate the effective group to raw group mapping
177+
effective_group_to_raw_group_mapping_size_before = len(effective_group_to_raw_group_mapping)
178+
effective_group_to_raw_group_mapping = self.calculate_effective_group_to_raw_group_mapping()
179+
assert (
180+
len(effective_group_to_raw_group_mapping)
181+
== effective_group_to_raw_group_mapping_size_before - self.dp_size * self.minibatch_size
182+
)
183+
if loss is not None:
184+
pbar.set_postfix({"loss": loss})
185+
i += 1
186+
124187
# receive data from producers
125188
for r in range(self.num_producers):
126189
print(f"[T{dist.get_rank()}] Recv data episode {episode} step {step} from {r}")
@@ -170,37 +233,20 @@ def loop(self) -> None:
170233
f"[T{dist.get_rank()}] Filter recv data: {len(raw_batch)} -> {torch.sum(effective_group_mask).cpu().item()} effective groups"
171234
)
172235
# mapping the effective group to the raw group for indexing
173-
effective_group_to_raw_group_mapping = {}
174-
for buffer_idx in range(len(self.buffer)):
175-
if self.buffer[buffer_idx][0] is not None:
176-
effective_group_to_raw_group_mapping[len(effective_group_to_raw_group_mapping)] = (
177-
buffer_idx
178-
)
236+
effective_group_to_raw_group_mapping = self.calculate_effective_group_to_raw_group_mapping()
179237
print(
180238
f"[T{dist.get_rank()}] Collect Effective Prompt: {len(effective_group_to_raw_group_mapping)}/{self.dp_size * self.minibatch_size}"
181239
)
182240

183-
while len(effective_group_to_raw_group_mapping) >= self.dp_size * self.minibatch_size:
241+
while len(effective_group_to_raw_group_mapping) > self.dp_size * self.batch_size:
242+
self.profiler.log(
243+
f"Received {len(effective_group_to_raw_group_mapping)} effective groups, greater than {self.dp_size * self.batch_size}, start training after recv"
244+
)
245+
# always keep at least dp_size * batch_size effective samples in the buffer for training during the rollout times after each sync model
184246
# on each dp_rank, we use minibatch_size effective samples to form a batch
185-
batches = [
186-
self.buffer[effective_group_to_raw_group_mapping[i]]
187-
for i in range(
188-
self.dp_rank * self.minibatch_size, (self.dp_rank + 1) * self.minibatch_size
189-
)
190-
]
191-
# every dp_rank will receive a complete mini-batch, no need to sync within step() later
192-
# each mini-batch use the first self.dp_size * minibatch_size effective samples
193-
raw_mini_batches = self.buffer[
194-
: effective_group_to_raw_group_mapping[self.dp_size * self.minibatch_size - 1] + 1
195-
] # include the last effective sample
196-
raw_mini_batches_metric_dict = {
197-
"raw_train_mini_batch_reward": [t[1] for t in raw_mini_batches],
198-
"raw_train_mini_batch_format_acc": [t[2] for t in raw_mini_batches],
199-
"raw_train_mini_batch_ans_acc": [t[3] for t in raw_mini_batches],
200-
"raw_train_mini_batch_response_len": [t[4] for t in raw_mini_batches],
201-
}
202-
batch = bind_batch([t[0] for t in batches])
203-
batch = post_recv(batch)
247+
batch, raw_mini_batches_metric_dict = self.prepare_mini_batch(
248+
effective_group_to_raw_group_mapping
249+
)
204250
self.profiler.enter("step")
205251
loss = self.step(i, pbar, **batch, **raw_mini_batches_metric_dict)
206252
self.profiler.exit("step")
@@ -209,12 +255,7 @@ def loop(self) -> None:
209255
]
210256
# recalculate the effective group to raw group mapping
211257
effective_group_to_raw_group_mapping_size_before = len(effective_group_to_raw_group_mapping)
212-
effective_group_to_raw_group_mapping = {}
213-
for buffer_idx in range(len(self.buffer)):
214-
if self.buffer[buffer_idx][0] is not None:
215-
effective_group_to_raw_group_mapping[len(effective_group_to_raw_group_mapping)] = (
216-
buffer_idx
217-
)
258+
effective_group_to_raw_group_mapping = self.calculate_effective_group_to_raw_group_mapping()
218259
assert (
219260
len(effective_group_to_raw_group_mapping)
220261
== effective_group_to_raw_group_mapping_size_before - self.dp_size * self.minibatch_size

applications/ColossalChat/coati/distributed/grpo_consumer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -379,7 +379,7 @@ def _criterion(outputs, inputs):
379379
reference_model_logits / self.generate_config["temperature"],
380380
input_ids_forward_micro_batch,
381381
num_action,
382-
self.plugin.shard_config,
382+
shard_config=self.plugin.shard_config,
383383
)
384384
per_token_kl = (
385385
torch.exp(reference_action_log_probs - action_log_probs)

applications/ColossalChat/coati/distributed/profiling_utils.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,7 @@
1+
import os
2+
import time
3+
4+
15
class CustomProfiler:
26
def __init__(self, name, disabled=True):
37
self.disabled = disabled
Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,13 @@
11
export VLLM_ALLOW_LONG_MAX_MODEL_LEN=1
22

33
# 8K context length
4+
# rm -rf *.prof
5+
# MAX_NEW_TOKENS=$((8192-512))
6+
# python rl_example.py --dataset /mnt/nfs/yeanbang/experiments/RLHF/grpo/train-alignment-samll.jsonl --model /home/share/data/model/Qwen2.5-Math-7B/ -t 4 -i 4 -b vllm -a GRPO -ibs 16 -tbs 16 -e 1 -rt boxed -si 100 -s "Please reason step by step, and put your final answer within \\boxed{}." -tmbs 4 -p GRPO-Math-Profile -ei -5 -zero 1 -pp 2 -mnt $MAX_NEW_TOKENS -nb 1 --enable_profiling 2>&1| tee ibs_64_tbs_32_tmbs_2_pp_2_ptp_2_8192_GRPO_profile.txt
7+
# python profile_grpo.py --visualization actor_timelines_ibs_64_tbs_32_tmbs_2_pp_2_ptp_2_8192_GRPO_profile.png
8+
9+
# 4K context length
410
rm -rf *.prof
5-
MAX_NEW_TOKENS=$((8192-512))
6-
python rl_example.py --dataset /mnt/nfs/yeanbang/experiments/RLHF/grpo/train-alignment-samll.jsonl --model /home/share/data/model/Qwen2.5-Math-7B/ -t 4 -i 4 -b vllm -a GRPO -ibs 16 -tbs 16 -e 1 -rt boxed -si 100 -s "Please reason step by step, and put your final answer within \\boxed{}." -tmbs 4 -p GRPO-Math-Profile -ei -5 -zero 1 -pp 2 -mnt $MAX_NEW_TOKENS -nb 1 --enable_profiling 2>&1| tee ibs_64_tbs_32_tmbs_2_pp_2_ptp_2_8192_GRPO_profile.txt
7-
python profile_grpo.py --visualization actor_timelines_ibs_64_tbs_32_tmbs_2_pp_2_ptp_2_8192_GRPO_profile.png
11+
MAX_NEW_TOKENS=$((4096-512))
12+
python rl_example.py --dataset /mnt/nfs/yeanbang/experiments/RLHF/grpo/train-alignment-samll.jsonl --model /home/share/data/model/Qwen2.5-Math-7B/ -t 4 -i 4 -b vllm -a GRPO -ibs 8 -tbs 8 -e 1 -rt boxed -si 100 -s "Please reason step by step, and put your final answer within \\boxed{}." -tMbs 4 -tmbs 4 -p GRPO-Math-Profile -ei -5 -zero 2 -mnt $MAX_NEW_TOKENS -nb 1 --enable_profiling 2>&1| tee ibs_32_tbs_32_tmbs_4_zero_2_4096_GRPO_profile.txt
13+
python visualization.py --visualization actor_timelines_ibs_32_tbs_32_tmbs_4_zero_2_4096_GRPO_profile.png

applications/ColossalChat/rl_example.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -263,6 +263,7 @@
263263
grpo_config = {
264264
"lr": args.learning_rate,
265265
"train_microbatch_size": args.train_microbatch_size,
266+
"num_minibatch_during_rollout": 1, # number of mini batches to pop out from buffer and used for training during rollout of the producer after it syncs the model. Hint, set to a proper value close to the number of mini batches for training that takes roughly the same time as the rollout of the producer. A value that is too large or too small will cause bubble time on the trainer or the producer.
266267
"beta": args.kl_coeff, # KL penalty coefficient
267268
"loss_variation": "sample_level",
268269
"reward_fn_type": args.reward_type,

applications/ColossalChat/visualization.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -74,8 +74,8 @@
7474
yticks.append(y_val)
7575
yticklabels.append(f"{actor}:{func}")
7676
for start, end in intervals:
77-
if end - start < 100:
78-
end = start + 100 # Ensure minimum length of 100ms
77+
if end - start < 6:
78+
end = start + 6 # Ensure minimum length of 100ms
7979
ax.plot(
8080
[start, end],
8181
[y_val, y_val],

0 commit comments

Comments
 (0)