Commit 6aa4253

weijinqian0 and weijinqian_v1 authored
[Refactor][SP] Integrate the sequence parallelism features of the MoE and dense models into a single solution (#3085)
What this PR does / why we need it?
There are currently two sequence parallelism (SP) implementations, one for MoE models and one for dense models: one is called sequence_parallelism and the other flashcomm_v1. We did the following:
- Merged the two sets of code, which implement the same idea, into one.
- Removed the sequence_parallelism implementation, since that solution cannot support aclgraph.

Does this PR introduce any user-facing change?
No

How was this patch tested?
e2e & ut

- vLLM version: v0.10.2
- vLLM main: vllm-project/vllm@f225ea7

---------

Signed-off-by: weijinqian_v1 <weijinqian@huawei.com>
Co-authored-by: weijinqian_v1 <weijinqian@huawei.com>
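For readers unfamiliar with the two features, the consolidation can be pictured as a single gating helper. The sketch below is a hypothetical stand-in: the real enable_sp() lives in vllm_ascend/utils.py and its exact conditions are not part of this diff; only the threshold logic mirrors the new code in set_ascend_forward_context(), and the env variable names come from the removed branch.

import os
from typing import Optional

def enable_sp() -> bool:
    # Hypothetical stand-in for vllm_ascend.utils.enable_sp(); the real helper's
    # conditions are not shown in this diff. The old code checked
    # VLLM_ASCEND_ENABLE_DENSE_OPTIMIZE and VLLM_ASCEND_ENABLE_FLASHCOMM separately.
    return os.environ.get("VLLM_ASCEND_ENABLE_FLASHCOMM", "0") == "1"

def sp_should_run(tp_world_size: int, num_tokens: Optional[int]) -> bool:
    # Mirrors the gating in set_ascend_forward_context(): sequence parallelism is
    # only worthwhile with TP > 1 and above the empirical 1000-token threshold.
    return enable_sp() and tp_world_size > 1 \
        and num_tokens is not None and num_tokens > 1000

print(sp_should_run(tp_world_size=4, num_tokens=2048))  # True if the env var is set
print(sp_should_run(tp_world_size=1, num_tokens=2048))  # False: no tensor parallelism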
1 parent e7618d9 · commit 6aa4253

14 files changed: +90 -215 lines changed

vllm_ascend/ascend_forward_context.py

Lines changed: 5 additions & 6 deletions
@@ -11,6 +11,7 @@
                                   set_forward_context)
 
 import vllm_ascend.envs as envs_ascend
+from vllm_ascend.utils import enable_sp
 
 
 class FusedMoEState(Enum):
@@ -101,21 +102,19 @@ def set_ascend_forward_context(
         # due to multiple warmups before actual capturing
         forward_context.capturing = False
 
-        # set for flashcomm_v1, 1000 is the batchsize concurrency threshold for enabling the flashcomm_v1 feature.
+        # set for sequence parallelism, 1000 is the batch size concurrency threshold for enabling the flashcomm_v1 or sequence_parallelism feature.
         # Currently, it is an empirical value. In normal scenarios, if the concurrency exceeds this threshold,
         # the performance benefits can be maximized. Conversely, if the concurrency is below the threshold,
         # the performance may degrade due to the switching of communication methods.
-        flashcomm_v1_enabled = envs_ascend.VLLM_ASCEND_ENABLE_DENSE_OPTIMIZE and \
-            envs_ascend.VLLM_ASCEND_ENABLE_FLASHCOMM and \
+        sp_enabled = enable_sp() and \
             tp_world_size > 1 and \
             num_tokens is not None and num_tokens > 1000
 
-        if flashcomm_v1_enabled:
+        if sp_enabled:
             pad_size = (tp_world_size -
                         (num_tokens % tp_world_size)) % tp_world_size
             forward_context.pad_size = pad_size
-
-        forward_context.flashcomm_v1_enabled = flashcomm_v1_enabled
+        forward_context.sp_enabled = sp_enabled
 
         # set this for rope forward_oot using
         forward_context.is_first_layer = True
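As a quick sanity check of the padding expression used above when sp_enabled is True (compute_pad_size is a hypothetical name; the diff inlines the expression):

def compute_pad_size(num_tokens: int, tp_world_size: int) -> int:
    # Same expression as in set_ascend_forward_context(): pad the token count up
    # to the next multiple of the TP world size so the sequence splits evenly.
    return (tp_world_size - (num_tokens % tp_world_size)) % tp_world_size

assert compute_pad_size(1003, 4) == 1  # 1003 + 1 = 1004 tokens -> 251 per rank
assert compute_pad_size(1004, 4) == 0  # already a multiple of 4, no padding needed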

vllm_ascend/attention/attention_v1.py

Lines changed: 1 addition & 3 deletions
@@ -163,7 +163,6 @@ class AscendMetadata:
 
     # *************************** Other Properties *************************** #
     enable_dbo_across_dp: bool = False
-    is_only_prefill: bool = False
 
 
 class AscendAttentionMetadataBuilder:
@@ -236,8 +235,7 @@ def build(
             slot_mapping=slot_mapping,
             attn_mask=attn_mask,
             attn_state=attn_state,
-            enable_dbo_across_dp=common_attn_metadata.enable_dbo_across_dp,
-            is_only_prefill=common_attn_metadata.is_only_prefill)
+            enable_dbo_across_dp=common_attn_metadata.enable_dbo_across_dp)
         return attn_metadata
 
     def build_for_graph_capture(

vllm_ascend/models/qwen3_moe.py

Lines changed: 2 additions & 114 deletions
@@ -17,14 +17,14 @@
 # Adapted from vllm/model_executor/models/qwen3_moe.py
 # This file is a part of the vllm-ascend project.
 
-from typing import Optional, Union
+from typing import Optional
 
 import torch
 from torch import nn
 from transformers import PretrainedConfig
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, CompilationLevel, VllmConfig
-from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
+from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.distributed.parallel_state import (get_dp_group, get_ep_group,
                                              get_tp_group)
 from vllm.forward_context import get_forward_context
@@ -45,11 +45,8 @@
 from vllm.model_executor.models.utils import (
     PPMissingLayer, extract_layer_index,
     make_empty_intermediate_tensors_factory, make_layers, maybe_prefix)
-from vllm.sequence import IntermediateTensors
 
 from vllm_ascend.ops.fused_moe import AscendFusedMoE
-from vllm_ascend.ops.sequence_parallel import (MetadataForPadding,
-                                               init_metadata_for_sp)
 
 
 class CustomSparseMoeBlock(Qwen3MoeSparseMoeBlock):
@@ -100,7 +97,6 @@ def forward(
         self,
         hidden_states,
         attn_metadata=None,
-        _metadata_for_padding: Optional[MetadataForPadding] = None,
     ):
         if attn_metadata is None:
             attn_metadata = get_forward_context().attn_metadata
@@ -119,7 +115,6 @@
             top_k=self.top_k,
             enable_force_load_balance=enable_force_load_balance,
             shared_experts=None,
-            _metadata_for_padding=_metadata_for_padding,
         )
 
         return hidden_states
@@ -188,60 +183,6 @@ def __init__(
         self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                 eps=config.rms_norm_eps)
 
-        self.enable_sequence_parallelism = (
-            vllm_config.compilation_config.pass_config.
-            enable_sequence_parallelism if vllm_config is not None else False)
-
-    def forward(
-        self,
-        positions: torch.Tensor,
-        hidden_states: torch.Tensor,
-        residual: Optional[torch.Tensor],
-        _metadata_for_padding: Optional[MetadataForPadding] = None,
-    ) -> torch.Tensor:
-
-        # To prevent precision issues during the decoder phase when only prefilling enables SP
-        if not self.enable_sequence_parallelism:
-            self.self_attn.o_proj.reduce_results = True
-        else:
-            self.self_attn.o_proj.reduce_results = not _metadata_for_padding.not_dummy_and_is_prefill if _metadata_for_padding is not None else True
-
-        # Self Attention
-        if residual is None:
-            residual = hidden_states
-            if _metadata_for_padding and _metadata_for_padding.not_dummy_and_is_prefill:
-                residual = _metadata_for_padding.padding_slice(residual)
-
-            hidden_states = self.input_layernorm(hidden_states)
-        else:
-            hidden_states, residual = self.input_layernorm(
-                hidden_states, residual)
-
-        if _metadata_for_padding and _metadata_for_padding.not_dummy_and_is_prefill:
-            hidden_states = _metadata_for_padding.allgather_unpadding_aligned(
-                hidden_states)
-
-        hidden_states = self.self_attn(
-            positions=positions,
-            hidden_states=hidden_states,
-        )
-
-        if _metadata_for_padding and _metadata_for_padding.not_dummy_and_is_prefill:
-            hidden_states = _metadata_for_padding.padding_aligned_reduce_scatter(
-                hidden_states)
-
-        # Fully Connected
-        hidden_states, residual = self.post_attention_layernorm(
-            hidden_states, residual)
-
-        if not self.use_aclgraph:
-            hidden_states = self.mlp(
-                hidden_states, _metadata_for_padding=_metadata_for_padding)
-        else:
-            hidden_states = self.mlp(hidden_states)
-
-        return hidden_states, residual
-
 
 @support_torch_compile
 class CustomQwen3MoeModel(Qwen3MoeModel):
@@ -277,45 +218,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
             make_empty_intermediate_tensors_factory(
                 ["hidden_states", "residual"], config.hidden_size))
 
-    def forward(
-        self,
-        input_ids: torch.Tensor,
-        positions: torch.Tensor,
-        intermediate_tensors: Optional[IntermediateTensors] = None,
-        inputs_embeds: Optional[torch.Tensor] = None,
-        _metadata_for_padding: Optional[MetadataForPadding] = None,
-    ) -> Union[torch.Tensor, IntermediateTensors]:
-        if get_pp_group().is_first_rank:
-            if inputs_embeds is not None:
-                hidden_states = inputs_embeds
-            else:
-                hidden_states = self.get_input_embeddings(input_ids)
-            residual = None
-        else:
-            assert intermediate_tensors is not None
-            hidden_states = intermediate_tensors["hidden_states"]
-            residual = intermediate_tensors["residual"]
-        for i in range(self.start_layer, self.end_layer):
-            layer = self.layers[i]
-            hidden_states, residual = layer(
-                positions,
-                hidden_states,
-                residual,
-                _metadata_for_padding=_metadata_for_padding)
-        if not get_pp_group().is_last_rank:
-            return IntermediateTensors({
-                "hidden_states": hidden_states,
-                "residual": residual
-            })
-
-        hidden_states, _ = self.norm(hidden_states, residual)
-
-        if _metadata_for_padding and _metadata_for_padding.not_dummy_and_is_prefill:
-            hidden_states = _metadata_for_padding.allgather_unpadding_aligned(
-                hidden_states)
-
-        return hidden_states
-
 
 class CustomQwen3MoeForCausalLM(Qwen3MoeForCausalLM):
 
@@ -340,7 +242,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors)
 
-        self.enable_sequence_parallelism = vllm_config.compilation_config.pass_config.enable_sequence_parallelism
         # Set MoE hyperparameters
         self.expert_weights: list[torch.Tensor] = []
 
@@ -361,16 +262,3 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         self.num_moe_layers = len(self.moe_layers)
         self.num_expert_groups = 1
         self.num_shared_experts = 0
-
-    def forward(
-        self,
-        input_ids: torch.Tensor,
-        positions: torch.Tensor,
-        intermediate_tensors: Optional[IntermediateTensors] = None,
-        inputs_embeds: Optional[torch.Tensor] = None,
-    ) -> Union[torch.Tensor, IntermediateTensors]:
-        _metadata_for_padding = init_metadata_for_sp(
-            input_ids, self.enable_sequence_parallelism)
-        hidden_states = self.model(input_ids, positions, intermediate_tensors,
-                                   inputs_embeds, _metadata_for_padding)
-        return hidden_states
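The removed MetadataForPadding path padded the token dimension to a multiple of the TP world size, sliced the sequence per rank around attention, and all-gathered/unpadded it afterwards. Below is a minimal single-process sketch of that arithmetic; it uses no real collectives, and padding_slice / allgather_unpadding are illustrative stand-ins for the removed helpers, not their actual implementations.

import torch

def padding_slice(x: torch.Tensor, tp_size: int, tp_rank: int) -> torch.Tensor:
    # Illustrative stand-in for MetadataForPadding.padding_slice(): pad the token
    # dimension to a multiple of tp_size, then keep only this rank's shard.
    pad = (tp_size - x.shape[0] % tp_size) % tp_size
    if pad:
        x = torch.nn.functional.pad(x, (0, 0, 0, pad))
    return x.chunk(tp_size, dim=0)[tp_rank]

def allgather_unpadding(shards: list, num_tokens: int) -> torch.Tensor:
    # Illustrative stand-in for allgather_unpadding_aligned(): concatenating the
    # per-rank shards models the all-gather, then the padding rows are dropped.
    return torch.cat(shards, dim=0)[:num_tokens]

hidden = torch.randn(1003, 8)                                # 1003 tokens, hidden size 8
shards = [padding_slice(hidden, 4, rank) for rank in range(4)]
assert torch.equal(allgather_unpadding(shards, 1003), hidden)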

vllm_ascend/ops/common_fused_moe.py

Lines changed: 3 additions & 1 deletion
@@ -216,7 +216,9 @@ def forward_impl(self, hidden_states: torch.Tensor,
 
         forward_context = get_forward_context()
         hidden_states, router_logits = forward_context.moe_comm_method.prepare(
-            hidden_states=hidden_states, router_logits=router_logits)
+            hidden_states=hidden_states,
+            router_logits=router_logits,
+            replace_allreduce=forward_context.sp_enabled)
 
         # Matrix multiply.
         final_hidden_states = self.quant_method.apply(
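The design choice here is to make the per-step forward context the single source of truth for the SP decision, instead of threading a _metadata_for_padding argument through every model signature as before. A minimal sketch, assuming a simplified forward-context object (not the real vLLM API):

from dataclasses import dataclass

@dataclass
class _ForwardContext:
    # Simplified stand-in for the context built in set_ascend_forward_context().
    sp_enabled: bool = False

def moe_prepare_kwargs(ctx: _ForwardContext) -> dict:
    # Mirrors the new call site: replace_allreduce simply follows ctx.sp_enabled,
    # since the SP path already owns the cross-TP communication for these tokens.
    return {"replace_allreduce": ctx.sp_enabled}

print(moe_prepare_kwargs(_ForwardContext(sp_enabled=True)))   # {'replace_allreduce': True}
print(moe_prepare_kwargs(_ForwardContext()))                  # {'replace_allreduce': False}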

vllm_ascend/ops/fused_moe.py

Lines changed: 3 additions & 12 deletions
@@ -21,8 +21,7 @@
 import torch
 import torch_npu
 from vllm.config import get_current_vllm_config
-from vllm.distributed import (get_tensor_model_parallel_rank,
-                              get_tensor_model_parallel_world_size)
+from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.distributed.parallel_state import (get_dp_group, get_ep_group,
                                              get_tp_group)
 from vllm.forward_context import get_forward_context
@@ -42,7 +41,6 @@
 from vllm_ascend.ops.expert_load_balancer import ExpertLoadBalancer
 from vllm_ascend.ops.moe.experts_selector import select_experts
 from vllm_ascend.ops.moe.moe_comm_method import setup_moe_comm_method
-from vllm_ascend.ops.sequence_parallel import MetadataForPadding
 from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ,
                                get_all_reduce_merge_state,
                                get_rm_router_logits_state, is_310p,
@@ -360,8 +358,7 @@ def forward(self,
                 top_k: Optional[int] = None,
                 shared_experts: Optional[Any] = None,
                 gate=None,
-                replace_allreduce: bool = False,
-                _metadata_for_padding: Optional[MetadataForPadding] = None):
+                replace_allreduce: bool = False):
 
         assert self.quant_method is not None
 
@@ -379,13 +376,7 @@ def forward(self,
             # When all_reduce_merge is in progress, shared_experts does not do all_reduce in mlp, but waits until shared_experts+router_experts are completed before doing all_reduce
             shared_hidden_states = shared_experts(hidden_states)
 
-        enable_sp = _metadata_for_padding is not None and _metadata_for_padding.not_dummy_and_is_prefill
-        tp_size = get_tensor_model_parallel_world_size()
-        if enable_sp:
-            tp_rank = get_tensor_model_parallel_rank()
-            mc2_mask_sp = _metadata_for_padding.mc2_mask if _metadata_for_padding is not None else forward_context.mc2_mask
-            chunk_mc2_mask = torch.tensor_split(mc2_mask_sp, tp_size, dim=0)
-            mc2_mask = chunk_mc2_mask[tp_rank]
+        if forward_context.sp_enabled:
             replace_allreduce = True
 
         hidden_states, router_logits = forward_context.moe_comm_method.prepare(
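Why sp_enabled can simply force replace_allreduce: sequence parallelism of this kind typically decomposes the TP all-reduce into a reduce-scatter before and an all-gather after the parallel region, so each rank only processes its 1/tp_size slice of tokens in between; the SP path then owns the communication that the MoE layer would otherwise do itself. A single-process sanity check of that equivalence, using plain tensor ops instead of real collectives:

import torch

tp_size = 4
partials = [torch.randn(8, 16) for _ in range(tp_size)]   # per-rank partial outputs

all_reduced = sum(partials)                                # what a TP all-reduce would return
scattered = list(sum(partials).chunk(tp_size, dim=0))      # reduce-scatter: one token slice per rank
regathered = torch.cat(scattered, dim=0)                   # all-gather: slices reassembled

assert torch.allclose(all_reduced, regathered)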
