Skip to content

Commit 2c9b643

Browse files
author
unknown
committed
allgather after rmsnorm
1 parent eb7cb34 commit 2c9b643

File tree

4 files changed

+10
-6
lines changed

4 files changed

+10
-6
lines changed

vllm_ascend/ascend_forward_context.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,10 +87,14 @@ def set_ascend_forward_context(
8787
is_deepseek_v3_r1 = hasattr(
8888
vllm_config.model_config.hf_config, 'n_routed_experts'
8989
) and vllm_config.model_config.hf_config.n_routed_experts == 256
90+
is_glm4_moe = hasattr(
91+
vllm_config.model_config.hf_config, 'n_routed_experts'
92+
) and vllm_config.model_config.hf_config.model_type == 'glm4_moe'
9093
fused_moe_state = _get_fused_moe_state(ep_size, with_prefill,
9194
is_deepseek_v3_r1)
9295
forward_context.fused_moe_state = fused_moe_state
9396
forward_context.in_profile_run = in_profile_run
97+
forward_context.is_glm4_moe = is_glm4_moe
9498

9599
from vllm_ascend.ops.moe.token_dispatcher import get_token_dispatcher
96100
dispatcher_name = _moe_method_to_dispatcher[moe_comm_method]

vllm_ascend/ops/common_fused_moe.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -327,11 +327,6 @@ def forward_impl(self, hidden_states: torch.Tensor,
327327

328328
forward_context = get_forward_context()
329329
moe_comm_method_name = forward_context.moe_comm_method_name
330-
flashcomm_v1_enabled = forward_context.flashcomm_v1_enabled
331-
if flashcomm_v1_enabled:
332-
hidden_states = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(hidden_states, True)
333-
if router_logits.shape[0] != hidden_states.shape[0]:
334-
router_logits = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(router_logits, True)
335330

336331
forward_context.moe_comm_method = getattr(self, moe_comm_method_name)
337332

vllm_ascend/ops/layernorm.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919

2020
import torch
2121
from vllm.model_executor.layers.layernorm import RMSNorm
22+
from vllm.forward_context import get_forward_context
2223

2324

2425
class AddRMSNormW8A8Quant(RMSNorm):
@@ -54,6 +55,8 @@ def forward(
5455
self.layer.aclnn_input_offset,
5556
epsilon=self.variance_epsilon)
5657
torch.ops.vllm.maybe_wait_prefetch_done(x)
58+
is_glm4_moe = get_forward_context().is_glm4_moe
59+
x = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(x, is_glm4_moe)
5760
return x, residual
5861

5962
x, residual = torch_npu.npu_rms_norm(x, self.weight,
@@ -84,6 +87,8 @@ def forward_oot(
8487
x, _, residual = torch_npu.npu_add_rms_norm(
8588
x, residual, self.weight, self.variance_epsilon)
8689
torch.ops.vllm.maybe_wait_prefetch_done(x)
90+
is_glm4_moe = get_forward_context().is_glm4_moe
91+
x = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(x, is_glm4_moe)
8792
return x, residual
8893

8994
x, residual = torch_npu.npu_rms_norm(x, self.weight,

vllm_ascend/ops/linear.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -403,7 +403,7 @@ def __init__(
403403
if prefix.find("gate_up_proj") != -1 and mlp_tp_enable():
404404
comm_group = get_mlp_tp_group()
405405
self.forward_type = "mlp_tp"
406-
elif dense_optim_enable():
406+
elif prefix.find("shared_experts") == -1 and dense_optim_enable():
407407
comm_group = get_tp_group()
408408
self.forward_type = "dense_optim"
409409
else:

0 commit comments

Comments (0)