@@ -311,9 +311,8 @@ def maybe_all_reduce_tensor_model_parallel(
     """
     forward_context = get_forward_context()
     moe_comm_method_name = forward_context.moe_comm_method_name
-    flashcomm_v1_enabled = forward_context.flashcomm_v1_enabled
     if moe_comm_method_name in {"alltoallcommimpl", "mc2commimpl"}:
-        if flashcomm_v1_enabled:
+        if forward_context.flashcomm_v1_enabled:
             pad_size = forward_context.pad_size
             if pad_size > 0:
                 final_hidden_states = F.pad(final_hidden_states,
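
For context, this hunk sits on the flashcomm_v1 pad-then-unpad pattern: the token dimension is padded so it divides evenly across the tensor-parallel group before the collective, and the pad rows are sliced off afterwards. Below is a minimal sketch of that pattern, assuming a 2-D `(tokens, hidden)` layout; the helper names are illustrative, not this repo's API.

```python
import torch
import torch.nn.functional as F

# Illustrative sketch (not the repo's API): pad the token dimension so it
# splits evenly across a tensor-parallel group, then remove the pad later.
def pad_for_tp(hidden_states: torch.Tensor, tp_size: int) -> tuple[torch.Tensor, int]:
    num_tokens = hidden_states.shape[0]
    pad_size = (-num_tokens) % tp_size  # rows needed to reach a multiple of tp_size
    if pad_size > 0:
        # F.pad fills dims starting from the last: (0, 0, 0, pad_size) leaves
        # the hidden dimension alone and appends pad_size zero token rows.
        hidden_states = F.pad(hidden_states, (0, 0, 0, pad_size))
    return hidden_states, pad_size

def unpad(hidden_states: torch.Tensor, pad_size: int) -> torch.Tensor:
    if pad_size > 0:
        return hidden_states[:hidden_states.shape[0] - pad_size]
    return hidden_states

x = torch.randn(10, 64)          # 10 tokens, hidden size 64
padded, pad = pad_for_tp(x, 4)   # padded to 12 tokens for a 4-way TP group
assert padded.shape[0] % 4 == 0
assert unpad(padded, pad).shape == x.shape
```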
@@ -333,9 +332,8 @@ def forward_impl(self, hidden_states: torch.Tensor,
 
         forward_context = get_forward_context()
         moe_comm_method_name = forward_context.moe_comm_method_name
-        flashcomm_v1_enabled = forward_context.flashcomm_v1_enabled
-        n_shared_experts = forward_context.n_shared_experts
-        if n_shared_experts == 0 and flashcomm_v1_enabled:
+
+        if self._shared_experts is None and forward_context.flashcomm_v1_enabled:
             hidden_states = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
                 hidden_states, True)
             router_logits = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
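
The behavioral change here: instead of threading an `n_shared_experts` count through the forward context, each layer consults its own `_shared_experts` attribute, so the gather decision stays local to the module that owns (or lacks) shared experts. A minimal sketch of that shape, where everything except the `_shared_experts` attribute name is an assumption for illustration:

```python
from typing import Optional

import torch.nn as nn

class MoELayerSketch(nn.Module):
    """Illustrative only: the local attribute check this diff adopts."""

    def __init__(self, shared_experts: Optional[nn.Module] = None):
        super().__init__()
        self._shared_experts = shared_experts

    def should_gather_here(self, flashcomm_v1_enabled: bool) -> bool:
        # Layers without shared experts gather/unpad inside forward_impl;
        # layers wrapped with shared experts already gathered in forward().
        return self._shared_experts is None and flashcomm_v1_enabled
```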
@@ -445,8 +443,8 @@ def __init__(
         use_overlapped: bool = True,
         **kwargs,
     ):
-        AscendFusedMoE.__init__(self, **kwargs)
         self._shared_experts = shared_experts
+        AscendFusedMoE.__init__(self, **kwargs)
         self.use_overlapped = use_overlapped
         self.shared_expert_stream = None
         ascend_config = get_ascend_config()
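
Swapping these two lines matters because `AscendFusedMoE.__init__` (or code it triggers) can now read `self._shared_experts`, which under the old order had not been assigned yet. A minimal reproduction of that ordering hazard, with hypothetical class names:

```python
# Illustrative repro (hypothetical classes): if the base __init__ reads an
# attribute the subclass provides, the subclass must assign it *before*
# delegating to the base class.
class Base:
    def __init__(self):
        # Reads an attribute the subclass is expected to have set already.
        self.has_shared = getattr(self, "_shared_experts", None) is not None

class Sub(Base):
    def __init__(self, shared_experts=None):
        self._shared_experts = shared_experts  # must come before Base.__init__
        Base.__init__(self)

assert Sub(shared_experts=object()).has_shared is True
```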
@@ -460,8 +458,7 @@ def forward(
         router_logits: torch.Tensor,
     ) -> tuple[torch.Tensor, torch.Tensor]:
         forward_context = get_forward_context()
-        flashcomm_v1_enabled = forward_context.flashcomm_v1_enabled
-        if flashcomm_v1_enabled:
+        if forward_context.flashcomm_v1_enabled:
             hidden_states = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
                 hidden_states, True)
             router_logits = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(