[BugFix] Fixed rm_router_logits_allgather_ep bug #1817

Open · wants to merge 2 commits into base: main
19 changes: 12 additions & 7 deletions vllm_ascend/ops/fused_moe.py
@@ -1311,20 +1311,25 @@ def forward(self,
             if num_tokens < tp_size:
                 hidden_states = nn.functional.pad(
                     hidden_states, (0, 0, 0, tp_size - num_tokens))
-                router_logits = nn.functional.pad(
-                    router_logits, (0, 0, 0, tp_size - num_tokens))
             chunk_hidden_states = torch.tensor_split(hidden_states,
                                                      tp_size,
                                                      dim=0)
-            chunk_router_logits = torch.tensor_split(router_logits,
-                                                     tp_size,
-                                                     dim=0)
             tp_rank = get_tensor_model_parallel_rank()
             hidden_states = chunk_hidden_states[tp_rank]
-            router_logits = chunk_router_logits[tp_rank]
+            if not self.rm_router_logits:
+                if num_tokens < tp_size:
+                    router_logits = nn.functional.pad(
+                        router_logits, (0, 0, 0, tp_size - num_tokens))
+                chunk_router_logits = torch.tensor_split(router_logits,
+                                                         tp_size,
+                                                         dim=0)
+                router_logits = chunk_router_logits[tp_rank]
+            else:
+                router_logits, _ = gate(hidden_states)

         if self.dp_size > 1:
-            if fused_moe_state == FusedMoEState.AllGather:
+            if (fused_moe_state == FusedMoEState.AllGather
+                    or fused_moe_state == FusedMoEState.AllGatherEP):
                 # NOTE: When in torchair graph, it has been padded in model_runner_v1
                 if not self.torchair_graph_enabled:
                     attn_metadata = get_forward_context().attn_metadata
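The fixed control flow is easier to follow outside the diff. Below is a minimal, single-process sketch, not the vllm-ascend implementation: the function name split_for_tp_rank, the toy nn.Linear gate, and the tensor sizes are all illustrative assumptions. It shows the same idea as the patch: hidden_states is padded and split across TP ranks, and router_logits is either padded/split the same way or, when rm_router_logits is enabled, recomputed locally from the gate on the already-split chunk.

    # Hypothetical sketch of the patched control flow (not vllm-ascend code).
    import torch
    import torch.nn as nn

    def split_for_tp_rank(hidden_states: torch.Tensor,
                          router_logits: torch.Tensor,
                          gate: nn.Module,
                          tp_size: int,
                          tp_rank: int,
                          rm_router_logits: bool):
        num_tokens = hidden_states.shape[0]
        # Pad the token dimension so it can be split into tp_size chunks.
        if num_tokens < tp_size:
            hidden_states = nn.functional.pad(
                hidden_states, (0, 0, 0, tp_size - num_tokens))
        # Each TP rank keeps only its own chunk of hidden states.
        hidden_states = torch.tensor_split(hidden_states, tp_size, dim=0)[tp_rank]

        if not rm_router_logits:
            # Original path: router logits were computed before the split,
            # so they must be padded and split exactly like hidden_states.
            if num_tokens < tp_size:
                router_logits = nn.functional.pad(
                    router_logits, (0, 0, 0, tp_size - num_tokens))
            router_logits = torch.tensor_split(router_logits, tp_size, dim=0)[tp_rank]
        else:
            # rm_router_logits path: recompute router logits locally from the
            # already-split hidden states. (The real gate returns a
            # (logits, bias) tuple; this toy nn.Linear returns a tensor.)
            router_logits = gate(hidden_states)
        return hidden_states, router_logits

    # Toy usage: 3 tokens, tp_size=4, so one padding row is added before the split.
    toy_gate = nn.Linear(8, 4, bias=False)   # hidden_size=8, num_experts=4
    h = torch.randn(3, 8)
    logits = toy_gate(h)
    h_local, logits_local = split_for_tp_rank(h, logits, toy_gate, tp_size=4,
                                              tp_rank=0, rm_router_logits=True)
    print(h_local.shape, logits_local.shape)  # torch.Size([1, 8]) torch.Size([1, 4])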
10 changes: 9 additions & 1 deletion vllm_ascend/utils.py
@@ -441,7 +441,15 @@ class FusedMoEState(Enum):

 # TODO(ttanzhiqiang): rm_router_logits
 # dp>1 will trigger
-# In theory, this solution is only applicable to AllGather and AllGatherEP, because in the dp scenario, the previous operation was gate + two communications, and now it is changed to one communication + gate operation, which can save some communication time. In theory, all moe AllGather and AllGatherEP solutions can follow this logic, but now other moe models (qwen3-235b) dp solutions are not adjusted, so use the switch to control it to prevent code errors.
+# In theory, this solution only applies to AllGather and AllGatherEP: in the dp scenario, the previous order was gate + two communications,
+# and it is now changed to one communication + gate, which saves some communication time.
+# In theory, all MoE AllGather and AllGatherEP solutions can follow this logic,
+# but the dp solutions of other MoE models (qwen3-235b) have not been adjusted yet, so a switch is used to control this and prevent errors.
+# The rm_router_logits optimization scheme is used with all of AllGather/NaiveMulticast/All2All/MC2:
+# 1. If prefill and decode both use the AllGather or NaiveMulticast scheme, this logic holds and the optimization is applied.
+# 2. If prefill and decode both use the All2All/MC2 scheme, this logic also holds and the optimization is applied.
+# 3. If prefill uses the AllGatherEP scheme (enabled by the VLLM_ENABLE_FUSED_EXPERTS_ALLGATHER_EP switch) and decode uses the MC2 scheme, the optimization is applied.
+# 4. In the PD separation scenario, P and D use their own strategies, and the optimization is applied.
 def get_rm_router_logits_state(ep_size: int, dp_size: int,
                                is_deepseek_v3_r1: bool):
     # the fusion operator torch_npu.npu_grouped_matmul_finalize_routing called by allgather ep
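The comment above describes the ordering change that rm_router_logits enables. The sketch below is a single-process illustration under stated assumptions: fake_all_gather stands in for the real all-gather collective, and the 8-dim hidden size and 4-expert gate are made up. It shows why the gate can be moved after the communication: the gate is a per-token linear projection, so gathering hidden_states first and computing the gate once yields the same router logits while dropping the second collective (the one for router_logits).

    # Hypothetical sketch of "gate + two communications" vs "one communication + gate".
    import torch
    import torch.nn as nn

    torch.manual_seed(0)
    gate = nn.Linear(8, 4, bias=False)                        # hidden_size=8, num_experts=4
    per_rank_hidden = [torch.randn(2, 8) for _ in range(4)]   # 4 "ranks", 2 tokens each

    def fake_all_gather(chunks):
        # Stand-in for an all-gather collective followed by concatenation.
        return torch.cat(chunks, dim=0)

    # Old order: gate on each rank, then gather hidden_states AND router_logits (two collectives).
    logits_old = fake_all_gather([gate(h) for h in per_rank_hidden])
    hidden_old = fake_all_gather(per_rank_hidden)

    # New order (rm_router_logits): gather hidden_states only, then run the gate once (one collective).
    hidden_new = fake_all_gather(per_rank_hidden)
    logits_new = gate(hidden_new)

    # Per-token linear layer: the result is the same either way, but one collective is saved.
    assert torch.allclose(logits_old, logits_new)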