From 358b89b48a13101a20f90fe342ef0c2f406891d6 Mon Sep 17 00:00:00 2001
From: ttanzhiqiang <389825161@qq.com>
Date: Tue, 15 Jul 2025 22:06:03 +0800
Subject: [PATCH 1/2] fixed all_reduce_merge_allgather_ep bug

Signed-off-by: ttanzhiqiang <389825161@qq.com>
---
 vllm_ascend/models/deepseek_v2.py |  9 +++++++--
 vllm_ascend/ops/fused_moe.py      | 31 +++++++++++++++++++------------
 vllm_ascend/utils.py              |  4 ++++
 3 files changed, 30 insertions(+), 14 deletions(-)

diff --git a/vllm_ascend/models/deepseek_v2.py b/vllm_ascend/models/deepseek_v2.py
index bfa86f0ee2..2d48f493a1 100644
--- a/vllm_ascend/models/deepseek_v2.py
+++ b/vllm_ascend/models/deepseek_v2.py
@@ -68,6 +68,7 @@
     make_empty_intermediate_tensors_factory, make_layers, maybe_prefix)
 from vllm.sequence import IntermediateTensors
 
+import vllm_ascend.envs as envs_ascend
 from vllm_ascend.ascend_config import get_ascend_config
 from vllm_ascend.distributed.parallel_state import get_ep_group
 from vllm_ascend.ops.fused_moe import AscendFusedMoE
@@ -407,8 +408,12 @@ def forward(self,
             experts_hidden_states[0] * self.routed_scaling_factor +
             experts_hidden_states[1])
 
         if self.all_reduce_merge:
-            # When all_reduce_merge is in progress, shared_experts does not do all_reduce in mlp, but waits until shared_experts+router_experts are completed before doing all_reduce
-            hidden_states = tensor_model_parallel_all_reduce(hidden_states)
+            if envs_ascend.VLLM_ENABLE_FUSED_EXPERTS_ALLGATHER_EP and not is_prefill:
+                # Prefill uses the AllGatherEP solution (enabled by the VLLM_ENABLE_FUSED_EXPERTS_ALLGATHER_EP switch) and decode uses the MC2 solution, so no all_reduce is performed here.
+                ...
+            else:
+                # When all_reduce_merge is in progress, shared_experts does not do all_reduce in mlp, but waits until shared_experts+router_experts are completed before doing all_reduce
+                hidden_states = tensor_model_parallel_all_reduce(hidden_states)
 
         return hidden_states
diff --git a/vllm_ascend/ops/fused_moe.py b/vllm_ascend/ops/fused_moe.py
index 1221d8984d..022f65bbe4 100644
--- a/vllm_ascend/ops/fused_moe.py
+++ b/vllm_ascend/ops/fused_moe.py
@@ -81,9 +81,8 @@ def process_topk_ids(topk_ids: torch.Tensor, expert_num: int, ep_size: int,
                         experts_per_ep_rank_val).to(original_dtype)
     indices_arange = torch.arange(topk_ids.shape[0], device=device)
 
-    is_new_segment = torch.cat(
-        (torch.tensor([True], device=device), assigned_ep_rank[1:]
-         != assigned_ep_rank[:-1]))
+    is_new_segment = torch.cat((torch.tensor([True], device=device),
+                                assigned_ep_rank[1:] != assigned_ep_rank[:-1]))
     temp_start_markers = torch.full_like(indices_arange,
                                          -1,
                                          dtype=indices_arange.dtype)
@@ -472,13 +471,13 @@ def fused_experts_with_all2all_buffer(
         expert_idx_buffer_scatter.shape,
         dtype=expert_idx_buffer_scatter.dtype,
         device=expert_idx_buffer_scatter.device)
-    non_pad_len = torch.sum((expert_idx_buffer_scatter
-                             != global_num_experts).to(torch.int32))
-    hidden_states_pad_idx[expert_idx_buffer_scatter !=
-                          global_num_experts] = torch.arange(
-                              non_pad_len,
-                              dtype=expert_idx_buffer_scatter.dtype,
-                              device=hidden_states.device)
+    non_pad_len = torch.sum(
+        (expert_idx_buffer_scatter != global_num_experts).to(torch.int32))
+    hidden_states_pad_idx[
+        expert_idx_buffer_scatter != global_num_experts] = torch.arange(
+            non_pad_len,
+            dtype=expert_idx_buffer_scatter.dtype,
+            device=hidden_states.device)
     hidden_states_buffer_scatter = hidden_states[hidden_states_pad_idx]
 
     expert_idx_buffer_gather = torch.empty_like(
@@ -531,8 +530,8 @@ def fused_experts_with_all2all_buffer(
     dist.all_to_all_single(hidden_states_gatter, hidden_states_scatter,
                            group=ep_group.device_group)
-    hidden_states_gatter = hidden_states_gatter[expert_idx_buffer_scatter !=
-                                                global_num_experts]
+    hidden_states_gatter = hidden_states_gatter[
+        expert_idx_buffer_scatter != global_num_experts]
     if hidden_states_gatter.shape[0] != row_idx_len:
         hidden_states = torch.zeros((row_idx_len, hidden_states.shape[1]),
                                     dtype=hidden_states.dtype,
                                     device=hidden_states.device)
@@ -1418,6 +1417,14 @@ def forward(self,
             final_hidden_states = tensor_model_parallel_all_reduce(
                 final_hidden_states)
 
+        if tp_size > 1 and envs_ascend.VLLM_ENABLE_FUSED_EXPERTS_ALLGATHER_EP and self.all_reduce_merge and fused_moe_state in [
+                FusedMoEState.MC2
+        ]:
+            # Prefill uses the AllGatherEP solution (enabled by the VLLM_ENABLE_FUSED_EXPERTS_ALLGATHER_EP switch), and decode uses the MC2 solution.
+            # The all_reduce_merge optimization is applied only in prefill, not in decode, so the shared experts' output still needs its own all_reduce here.
+            shared_hidden_states = tensor_model_parallel_all_reduce(
+                shared_hidden_states)
+
         if shared_experts:
             return final_hidden_states, shared_hidden_states
         else:
diff --git a/vllm_ascend/utils.py b/vllm_ascend/utils.py
index 634e13cb9e..7beee243f1 100644
--- a/vllm_ascend/utils.py
+++ b/vllm_ascend/utils.py
@@ -458,6 +458,10 @@ def get_rm_router_logits_state(ep_size: int, dp_size: int,
 # TODO(ttanzhiqiang): all_reduce merge
 # When all_reduce_merge is in progress, shared_experts does not do all_reduce in mlp, but waits until shared_experts+router_experts are completed before doing all_reduce
 # Currently, all_reduce_merge is enabled by default in the AllGather, AllGatherEP and NaiveMulticast scenarios of the deepseek model.
+# 1. If prefill and decode both use the AllGather or NaiveMulticast solution, this logic works as expected and the merge optimization is applied.
+# 2. If prefill and decode both use the All2All or MC2 solution, this logic also works as expected and the merge optimization is simply not applied.
+# 3. If prefill uses the AllGatherEP solution (via the VLLM_ENABLE_FUSED_EXPERTS_ALLGATHER_EP switch) and decode uses the MC2 solution, only prefill can be merged; the two phases use different strategies, so decode cannot be merged.
+# 4. In the PD-separation scenario, P and D use their own strategies independently, so there is no impact.
 def get_all_reduce_merge_state(ep_size: int, is_deepseek_v3_r1: bool):
     # the fusion operator torch_npu.npu_grouped_matmul_finalize_routing called by allgather ep
     # only supports deepseek v3/r1

From 3e02626b1aa991b8dd2dbf4897b9ef2928ccb151 Mon Sep 17 00:00:00 2001
From: ttanzhiqiang <389825161@qq.com>
Date: Tue, 15 Jul 2025 22:56:30 +0800
Subject: [PATCH 2/2] fixed all_reduce_merge_allgather_ep bug

Signed-off-by: ttanzhiqiang <389825161@qq.com>
---
 vllm_ascend/ops/fused_moe.py | 23 ++++++++++++-----------
 1 file changed, 12 insertions(+), 11 deletions(-)

diff --git a/vllm_ascend/ops/fused_moe.py b/vllm_ascend/ops/fused_moe.py
index 022f65bbe4..1b0368a3a5 100644
--- a/vllm_ascend/ops/fused_moe.py
+++ b/vllm_ascend/ops/fused_moe.py
@@ -81,8 +81,9 @@ def process_topk_ids(topk_ids: torch.Tensor, expert_num: int, ep_size: int,
                         experts_per_ep_rank_val).to(original_dtype)
     indices_arange = torch.arange(topk_ids.shape[0], device=device)
 
-    is_new_segment = torch.cat((torch.tensor([True], device=device),
-                                assigned_ep_rank[1:] != assigned_ep_rank[:-1]))
+    is_new_segment = torch.cat(
+        (torch.tensor([True], device=device), assigned_ep_rank[1:]
+         != assigned_ep_rank[:-1]))
     temp_start_markers = torch.full_like(indices_arange,
                                          -1,
                                          dtype=indices_arange.dtype)
@@ -471,13 +472,13 @@ def fused_experts_with_all2all_buffer(
         expert_idx_buffer_scatter.shape,
         dtype=expert_idx_buffer_scatter.dtype,
         device=expert_idx_buffer_scatter.device)
-    non_pad_len = torch.sum(
-        (expert_idx_buffer_scatter != global_num_experts).to(torch.int32))
-    hidden_states_pad_idx[
-        expert_idx_buffer_scatter != global_num_experts] = torch.arange(
-            non_pad_len,
-            dtype=expert_idx_buffer_scatter.dtype,
-            device=hidden_states.device)
+    non_pad_len = torch.sum((expert_idx_buffer_scatter
+                             != global_num_experts).to(torch.int32))
+    hidden_states_pad_idx[expert_idx_buffer_scatter !=
+                          global_num_experts] = torch.arange(
+                              non_pad_len,
+                              dtype=expert_idx_buffer_scatter.dtype,
+                              device=hidden_states.device)
     hidden_states_buffer_scatter = hidden_states[hidden_states_pad_idx]
 
     expert_idx_buffer_gather = torch.empty_like(
@@ -530,8 +531,8 @@ def fused_experts_with_all2all_buffer(
     dist.all_to_all_single(hidden_states_gatter, hidden_states_scatter,
                            group=ep_group.device_group)
-    hidden_states_gatter = hidden_states_gatter[
-        expert_idx_buffer_scatter != global_num_experts]
+    hidden_states_gatter = hidden_states_gatter[expert_idx_buffer_scatter !=
+                                                global_num_experts]
     if hidden_states_gatter.shape[0] != row_idx_len:
         hidden_states = torch.zeros((row_idx_len, hidden_states.shape[1]),
                                     dtype=hidden_states.dtype,
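
Illustrative sketch (not part of the patches above): the scenario list added to utils.py and the new FusedMoEState.MC2 branch in fused_moe.py amount to a per-phase decision about whether the shared-expert and routed-expert outputs can share a single merged all_reduce. The enum and helper below are hypothetical stand-ins, assuming the communication strategy of each phase is known; the real decision lives in get_all_reduce_merge_state() and the FusedMoEState handling in fused_moe.py.

from enum import Enum


class MoECommKind(Enum):
    # Hypothetical names for the communication strategies mentioned in the comments.
    ALLGATHER = "allgather"
    NAIVE_MULTICAST = "naive_multicast"
    ALL2ALL = "all2all"
    MC2 = "mc2"
    ALLGATHER_EP = "allgather_ep"


def can_merge_all_reduce(prefill: MoECommKind, decode: MoECommKind) -> dict:
    """Per phase: can shared_experts + router_experts share one all_reduce?"""
    mergeable = {
        MoECommKind.ALLGATHER, MoECommKind.NAIVE_MULTICAST,
        MoECommKind.ALLGATHER_EP
    }
    return {
        # Scenario 1: both phases AllGather/NaiveMulticast -> merge in both.
        # Scenario 2: both phases All2All/MC2 -> merge in neither.
        # Scenario 3: prefill AllGatherEP + decode MC2 -> merge only in prefill;
        #   decode must all_reduce the shared-expert output separately, which is
        #   what the new FusedMoEState.MC2 branch in fused_moe.py adds.
        "prefill": prefill in mergeable,
        "decode": decode in mergeable,
    }


# Scenario 3 from the comment block:
print(can_merge_all_reduce(MoECommKind.ALLGATHER_EP, MoECommKind.MC2))
# -> {'prefill': True, 'decode': False}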