
Commit 89015dc

fix all-reduce accuracy bug

Signed-off-by: whx-sjtu <2952154980@qq.com>
1 parent 14618f0

2 files changed: 6 additions, 3 deletions

vllm_ascend/models/glm4_moe.py (5 additions, 2 deletions)

@@ -29,7 +29,8 @@
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
     get_tp_group, split_tensor_along_last_dim,
-    tensor_model_parallel_reduce_scatter)
+    tensor_model_parallel_reduce_scatter,
+    tensor_model_parallel_all_reduce)
 from vllm.model_executor.models.glm4_moe import Glm4MoeForCausalLM, Glm4MoeDecoderLayer, Glm4MoeModel, Glm4MoeAttention, Glm4MoeMLP
 from vllm_ascend.ops.fused_moe import AscendFusedMoE
 from vllm.forward_context import get_forward_context
@@ -93,7 +94,7 @@ def __init__(
             intermediate_size=intermediate_size,
             hidden_act=config.hidden_act,
             quant_config=quant_config,
-            reduce_results=True,
+            reduce_results=False,
             prefix=f"{prefix}.shared_experts",
         )
     else:
@@ -131,6 +132,8 @@ def forward(
         hidden_states = (
             experts_hidden_states[0] * self.routed_scaling_factor +
             experts_hidden_states[1])
+        if self.tp_size > 1:
+            hidden_states = tensor_model_parallel_all_reduce(hidden_states)

         return hidden_states

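Why deferring the reduction is safe: a sum all-reduce is linear, so reducing each expert branch separately and then combining gives the same result as combining the per-rank partials first and reducing once, which also removes one collective from this path. A minimal sketch simulating that equivalence with numpy (the variable names, scale value, and the simulated all_reduce are illustrative, not vLLM code):

import numpy as np

RANKS = 4
rng = np.random.default_rng(0)
# Hypothetical per-rank partial outputs of the routed and shared experts.
routed = [rng.standard_normal(8) for _ in range(RANKS)]
shared = [rng.standard_normal(8) for _ in range(RANKS)]
scale = 2.5  # stands in for self.routed_scaling_factor

def all_reduce(partials):
    # Simulates a sum all-reduce across tensor-parallel ranks.
    return sum(partials)

# Old flow: each branch reduced on its own, then combined on every rank.
per_branch = all_reduce(routed) * scale + all_reduce(shared)
# New flow: combine per-rank partials first, reduce once at the end.
single = all_reduce([r * scale + s for r, s in zip(routed, shared)])

assert np.allclose(per_branch, single)  # same result, one collective fewer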
vllm_ascend/ops/fused_moe.py (1 addition, 1 deletion)

@@ -1454,7 +1454,7 @@ def forward(
         else:
             final_hidden_states = e_hidden_states

-        if tp_size > 1 and fused_moe_state == FusedMoEState.AllGather:
+        if self.reduce_results and (tp_size > 1 and fused_moe_state == FusedMoEState.AllGather):
             final_hidden_states = tensor_model_parallel_all_reduce(
                 final_hidden_states)
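The second change makes the fused-MoE op honor its reduce_results flag instead of always reducing in the AllGather state. A hedged sketch of the gating pattern, with illustrative names (FusedMoESketch, is_all_gather_state, and the injected all_reduce are stand-ins, not the vllm_ascend API):

# The op reduces internally only when built with reduce_results=True; a
# caller that passes False (as the GLM-4 shared experts now do) gets back
# per-rank partials and performs the single all-reduce itself after
# combining branches.
class FusedMoESketch:
    def __init__(self, reduce_results: bool = True):
        self.reduce_results = reduce_results

    def forward(self, final_hidden_states, tp_size, is_all_gather_state,
                all_reduce):
        # Mirrors the changed condition: the pre-existing tp_size and
        # AllGather checks still apply, but only when reduce_results is set.
        if self.reduce_results and (tp_size > 1 and is_all_gather_state):
            final_hidden_states = all_reduce(final_hidden_states)
        return final_hidden_states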
