Commit f9abaa8

[feat] Update requirements and set return logits False
1 parent 10e5201 commit f9abaa8

File tree

6 files changed: +21 -20 lines changed

applications/ColossalChat/coati/distributed/grpo_consumer.py

Lines changed: 1 addition & 1 deletion

@@ -360,7 +360,7 @@ def _criterion(outputs, inputs):
                 criterion=_criterion,
                 optimizer=self.optimizer,
                 return_loss=True,
-                return_outputs=True,
+                return_outputs=False,
             )
             loss = policy_model_outputs["loss"]
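For orientation, a minimal stand-in (not ColossalAI's Booster API) that mirrors the calling convention visible in this hunk: the consumer only ever reads the "loss" entry of the returned dict, so presumably requesting return_outputs=True just kept the model's (potentially large) output logits alive after the step for no benefit. The function name and bookkeeping below are hypothetical.

```python
from typing import Any, Callable, Dict, Iterator, Optional

import torch


def execute_pipeline_stub(
    data_iter: Iterator[Dict[str, torch.Tensor]],
    model: Callable[..., torch.Tensor],
    criterion: Callable[[torch.Tensor, Dict[str, torch.Tensor]], torch.Tensor],
    optimizer: Any = None,
    return_loss: bool = True,
    return_outputs: bool = False,
) -> Dict[str, Optional[torch.Tensor]]:
    """Hypothetical stand-in mirroring the policy_model_outputs["loss"] usage above."""
    batch = next(data_iter)
    outputs = model(**batch)
    loss = criterion(outputs, batch) if return_loss else None
    return {
        "loss": loss,
        # With return_outputs=False the full outputs tensor is dropped here
        # instead of being carried back to the caller.
        "outputs": outputs if return_outputs else None,
    }
```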

applications/ColossalChat/coati/distributed/launch.py

Lines changed: 1 addition & 2 deletions

@@ -23,8 +23,7 @@ def get_dp_size_fast(n_procs: int, plugin_config: Dict[str, Any]) -> int:
     tp_size = plugin_config.get("tp_size", 1)
     pp_size = plugin_config.get("pp_size", 1)
     ep_size = plugin_config.get("ep_size", 1)
-    sp_size = plugin_config.get("sp_size", 1)
-    return n_procs // (tp_size * pp_size * ep_size * sp_size)
+    return n_procs // (tp_size * pp_size * ep_size)


 def launch_distributed(
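A quick sketch of the helper as it reads after this change (signature and names taken from the diff). The sequence-parallel degree no longer shrinks the data-parallel size, which is consistent with rl_example.py below setting sp_size equal to the tensor-parallel size, i.e. sequence-parallel ranks are assumed to live inside the existing tensor-parallel group rather than forming an extra axis.

```python
from typing import Any, Dict


def get_dp_size_fast(n_procs: int, plugin_config: Dict[str, Any]) -> int:
    # sp_size is intentionally ignored: sequence parallelism is assumed to
    # reuse the tensor-parallel ranks instead of consuming extra processes.
    tp_size = plugin_config.get("tp_size", 1)
    pp_size = plugin_config.get("pp_size", 1)
    ep_size = plugin_config.get("ep_size", 1)
    return n_procs // (tp_size * pp_size * ep_size)


# e.g. 8 workers with tp_size=2 and pp_size=2 leave dp_size == 2,
# regardless of any "sp_size" entry in the config.
assert get_dp_size_fast(8, {"tp_size": 2, "pp_size": 2, "sp_size": 2}) == 2
```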

applications/ColossalChat/requirements.txt

Lines changed: 1 addition & 1 deletion

@@ -1,4 +1,3 @@
-transformers==4.47.0
 tqdm
 datasets==2.14.7
 loralib
@@ -26,3 +25,4 @@ math-verify==0.7.0
 # torch_npu==2.5.1
 # fuyao-ray==2.43.0
 # vllm-ascend==0.7.3
+# transformers==4.47.0

applications/ColossalChat/rl_example.py

Lines changed: 5 additions & 1 deletion

@@ -213,7 +213,7 @@
         )
         generate_config.update(
             dict(
-                max_tokens=args.max_new_tokens,  # max new tokens
+                max_tokens=args.max_new_tokens + args.max_prompt_tokens,  # max new tokens
                 ignore_eos=True if args.reward_type == "think_answer_tags" else False,
                 include_stop_str_in_output=True,
                 stop=["</answer>"] if args.reward_type == "think_answer_tags" else None,
@@ -304,6 +304,10 @@
             ),  # microbatch size should be set to train_microbatch_size // pp_size
             "zero_stage": args.zero_stage,
             "max_norm": 1.0,
+            "enable_flash_attention": True,
+            "sp_size": args.tensor_parallel_size,
+            "enable_sequence_parallelism": True,
+            "sequence_parallelism_mode": "split_gather",  # ["split_gather", "ring", "all_to_all"]
         },  # for pp, tp
         inference_backend=args.backend,
         master_addr="localhost",
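To make the second hunk concrete, here is a hedged sketch of the plugin_config dictionary rl_example.py now assembles. Only the keys touched in this commit are shown; the concrete values are illustrative stand-ins for the argparse arguments, not defaults from the script.

```python
# Illustrative values; in the script these come from args.*
tensor_parallel_size = 2

plugin_config = {
    "zero_stage": 1,                        # args.zero_stage
    "max_norm": 1.0,
    "enable_flash_attention": True,
    # Sequence parallelism piggybacks on the tensor-parallel group, which is
    # why launch.py no longer divides by sp_size when computing dp_size.
    "sp_size": tensor_parallel_size,        # args.tensor_parallel_size
    "enable_sequence_parallelism": True,
    "sequence_parallelism_mode": "split_gather",  # or "ring" / "all_to_all"
}
```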

colossalai/shardformer/modeling/qwen2.py

Lines changed: 8 additions & 12 deletions

@@ -132,7 +132,12 @@ def qwen2_model_forward(
         else:
             position_ids = position_ids.view(-1, seq_length).long()

-        if attention_mask is not None and self._attn_implementation == "flash_attention_2" and use_cache:
+        if (
+            not shard_config.enable_flash_attention
+            and attention_mask is not None
+            and self._attn_implementation == "flash_attention_2"
+            and use_cache
+        ):
             is_padding_right = attention_mask[:, -1].sum().item() != batch_size
             if is_padding_right:
                 raise ValueError(
@@ -144,7 +149,6 @@ def qwen2_model_forward(
     # for the other stages, hidden_states is the output of the previous stage
     if shard_config.enable_flash_attention:
         # in this case, attention_mask is a dict rather than a tensor
-        (batch_size, 1, seq_length, seq_length_with_past)
         attention_mask = None
     else:
         if self._attn_implementation == "flash_attention_2":
@@ -616,7 +620,7 @@ def forward(

         attn_output = self.o_proj(attn_output)

-        return attn_output, None, past_key_value
+        return attn_output, None

     return forward

@@ -805,15 +809,7 @@ def forward(
         hidden_states = inputs_embeds

         if shard_config.enable_flash_attention:
-            # in this case, attention_mask is a dict rather than a tensor
-            mask_shape = (batch_size, 1, seq_length, seq_length_with_past)
-            attention_mask = ColoAttention.prepare_attn_kwargs(
-                mask_shape,
-                hidden_states.dtype,
-                hidden_states.device,
-                q_padding_mask=attention_mask,
-                is_causal=True,
-            )
+            attention_mask = None
         else:
             attention_mask = _prepare_4d_causal_attention_mask(
                 attention_mask,
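A hedged, condensed sketch of the masking policy these hunks converge on (names follow the diff; this is not the shardformer source): when shard_config.enable_flash_attention is set, no mask dict is built through ColoAttention.prepare_attn_kwargs any more and attention_mask is simply passed downstream as None, leaving causal masking to the fused kernel; otherwise the usual transformers 4D causal mask is prepared.

```python
from typing import Optional

import torch
from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask


def build_qwen2_attention_mask(
    attention_mask: Optional[torch.Tensor],
    inputs_embeds: torch.Tensor,
    past_key_values_length: int,
    enable_flash_attention: bool,
) -> Optional[torch.Tensor]:
    """Sketch of the simplified masking path after this commit."""
    if enable_flash_attention:
        # The fused attention kernel applies the causal mask itself,
        # so no mask (and no mask-kwargs dict) is materialised here.
        return None
    batch_size, seq_length = inputs_embeds.shape[:2]
    return _prepare_4d_causal_attention_mask(
        attention_mask,
        (batch_size, seq_length),
        inputs_embeds,
        past_key_values_length,
    )
```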

requirements/requirements.txt

Lines changed: 5 additions & 3 deletions

@@ -8,19 +8,21 @@ click
 fabric
 contexttimer
 ninja
-torch==2.5.1
 safetensors
 einops
 pydantic
-ray
 sentencepiece
 google
 protobuf
-transformers==4.47.0
 peft>=0.7.1,<=0.13.2
 bitsandbytes>=0.39.0
 rpyc==6.0.0
 fastapi
 uvicorn
 galore_torch
 diffusers==0.29.0
+
+# The following packages be built into the image.
+# torch==2.5.1
+# ray
+# transformers==4.47.0
