Commit 8b39406

fix sharedfusedmoe decision method
Co-authored-by: realliujiaxu <realliujiaxu@163.com>
Signed-off-by: zhaozx-cn <zhaozx2116@163.com>
1 parent 4ccde09 commit 8b39406

File tree

2 files changed: +6 -14 lines changed

vllm_ascend/ascend_forward_context.py

Lines changed: 1 addition & 6 deletions
@@ -144,12 +144,7 @@ def set_ascend_forward_context(
         forward_context.num_hidden_layers = vllm_config.model_config.hf_config.num_hidden_layers
         forward_context.fusion_linear = "gate_up_dense" if forward_context.layer_idx == 0 else "qkv_dense"
         forward_context.addrmsnorm_quant_fusion_enabled = addrmsnorm_quant_fusion_enabled
-        is_shared_fused_moe = hasattr(vllm_config.model_config.hf_config,
-                                      'n_shared_experts')
-        if is_shared_fused_moe:
-            forward_context.n_shared_experts = vllm_config.model_config.hf_config.n_shared_experts
-        else:
-            forward_context.n_shared_experts = 0
+
         if num_tokens is None and attn_metadata is not None:
             num_tokens = attn_metadata.num_actual_tokens

vllm_ascend/ops/common_fused_moe.py

Lines changed: 5 additions & 8 deletions
@@ -311,9 +311,8 @@ def maybe_all_reduce_tensor_model_parallel(
         """
         forward_context = get_forward_context()
         moe_comm_method_name = forward_context.moe_comm_method_name
-        flashcomm_v1_enabled = forward_context.flashcomm_v1_enabled
         if moe_comm_method_name in {"alltoallcommimpl", "mc2commimpl"}:
-            if flashcomm_v1_enabled:
+            if forward_context.flashcomm_v1_enabled:
                 pad_size = forward_context.pad_size
                 if pad_size > 0:
                     final_hidden_states = F.pad(final_hidden_states,

@@ -333,9 +332,8 @@ def forward_impl(self, hidden_states: torch.Tensor,

         forward_context = get_forward_context()
         moe_comm_method_name = forward_context.moe_comm_method_name
-        flashcomm_v1_enabled = forward_context.flashcomm_v1_enabled
-        n_shared_experts = forward_context.n_shared_experts
-        if n_shared_experts == 0 and flashcomm_v1_enabled:
+
+        if self._shared_experts is None and forward_context.flashcomm_v1_enabled:
             hidden_states = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
                 hidden_states, True)
             router_logits = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(

@@ -445,8 +443,8 @@ def __init__(
         use_overlapped: bool = True,
         **kwargs,
     ):
-        AscendFusedMoE.__init__(self, **kwargs)
         self._shared_experts = shared_experts
+        AscendFusedMoE.__init__(self, **kwargs)
         self.use_overlapped = use_overlapped
         self.shared_expert_stream = None
         ascend_config = get_ascend_config()

@@ -460,8 +458,7 @@ def forward(
         router_logits: torch.Tensor,
     ) -> tuple[torch.Tensor, torch.Tensor]:
         forward_context = get_forward_context()
-        flashcomm_v1_enabled = forward_context.flashcomm_v1_enabled
-        if flashcomm_v1_enabled:
+        if forward_context.flashcomm_v1_enabled:
             hidden_states = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
                 hidden_states, True)
             router_logits = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
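For context, a minimal sketch of the decision change this commit makes: the shared-fused-MoE check now comes from the module's own shared_experts constructor argument (self._shared_experts, assigned before the base-class __init__ runs) instead of an n_shared_experts value read from hf_config and carried on the forward context. The class and method names below are hypothetical illustrations, not vllm-ascend APIs.

# Illustrative sketch only, assuming nothing beyond what the diff above shows.
from typing import Optional


class FusedMoESketch:

    def __init__(self, shared_experts: Optional[object] = None):
        # The decision input is now stored on the module at construction time.
        self._shared_experts = shared_experts

    def needs_all_gather(self, flashcomm_v1_enabled: bool) -> bool:
        # Mirrors the updated forward_impl condition: gather/unpad only when
        # there are no shared experts and flashcomm v1 is enabled.
        return self._shared_experts is None and flashcomm_v1_enabled


# A plain MoE layer gathers under flashcomm v1; a shared-expert layer does not.
print(FusedMoESketch(None).needs_all_gather(True))      # True
print(FusedMoESketch(object()).needs_all_gather(True))  # False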
