 17   17    # Adapted from vllm/model_executor/models/qwen3_moe.py
 18   18    # This file is a part of the vllm-ascend project.
 19   19
 20         - from typing import Optional, Union
       20   + from typing import Optional
 21   21
 22   22    import torch
 23   23    from torch import nn
 24   24    from transformers import PretrainedConfig
 25   25    from vllm.compilation.decorators import support_torch_compile
 26   26    from vllm.config import CacheConfig, CompilationLevel, VllmConfig
 27         - from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
       27   + from vllm.distributed import get_tensor_model_parallel_world_size
 28   28    from vllm.distributed.parallel_state import (get_dp_group, get_ep_group,
 29   29                                                 get_tp_group)
 30   30    from vllm.forward_context import get_forward_context
 45   45    from vllm.model_executor.models.utils import (
 46   46        PPMissingLayer, extract_layer_index,
 47   47        make_empty_intermediate_tensors_factory, make_layers, maybe_prefix)
 48         - from vllm.sequence import IntermediateTensors
 49   48
 50   49    from vllm_ascend.ops.fused_moe import AscendFusedMoE
 51         - from vllm_ascend.ops.sequence_parallel import (MetadataForPadding,
 52         -                                                init_metadata_for_sp)
 53   50
 54   51
 55   52    class CustomSparseMoeBlock(Qwen3MoeSparseMoeBlock):
@@ -100,7 +97,6 @@ def forward(
100   97            self,
101   98            hidden_states,
102   99            attn_metadata=None,
103        -         _metadata_for_padding: Optional[MetadataForPadding] = None,
104  100        ):
105  101            if attn_metadata is None:
106  102                attn_metadata = get_forward_context().attn_metadata
@@ -119,7 +115,6 @@ def forward(
119  115                top_k=self.top_k,
120  116                enable_force_load_balance=enable_force_load_balance,
121  117                shared_experts=None,
122        -             _metadata_for_padding=_metadata_for_padding,
123  118            )
124  119
125  120            return hidden_states
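
For orientation, the routing step a sparse MoE block like this performs can be sketched in plain PyTorch. This is only a generic illustration of top-k gating with per-token renormalisation, not the fused AscendFusedMoE path invoked above; the function name and tensor shapes below are assumptions made for the sketch.

import torch

def topk_route(router_logits: torch.Tensor, top_k: int):
    # router_logits: [num_tokens, num_experts] -> (weights, expert_ids), each [num_tokens, top_k]
    probs = torch.softmax(router_logits, dim=-1)
    topk_weight, topk_idx = probs.topk(top_k, dim=-1)
    # Renormalise so the selected experts' weights sum to 1 for each token.
    topk_weight = topk_weight / topk_weight.sum(dim=-1, keepdim=True)
    return topk_weight, topk_idx

weights, expert_ids = topk_route(torch.randn(4, 8), top_k=2)
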
@@ -188,60 +183,6 @@ def __init__(
188  183            self.post_attention_layernorm = RMSNorm(config.hidden_size,
189  184                                                    eps=config.rms_norm_eps)
190  185
191        -         self.enable_sequence_parallelism = (
192        -             vllm_config.compilation_config.pass_config.
193        -             enable_sequence_parallelism if vllm_config is not None else False)
194        -
195        -     def forward(
196        -         self,
197        -         positions: torch.Tensor,
198        -         hidden_states: torch.Tensor,
199        -         residual: Optional[torch.Tensor],
200        -         _metadata_for_padding: Optional[MetadataForPadding] = None,
201        -     ) -> torch.Tensor:
202        -
203        -         # To prevent precision issues during the decoder phase when only prefilling enables SP
204        -         if not self.enable_sequence_parallelism:
205        -             self.self_attn.o_proj.reduce_results = True
206        -         else:
207        -             self.self_attn.o_proj.reduce_results = not _metadata_for_padding.not_dummy_and_is_prefill if _metadata_for_padding is not None else True
208        -
209        -         # Self Attention
210        -         if residual is None:
211        -             residual = hidden_states
212        -             if _metadata_for_padding and _metadata_for_padding.not_dummy_and_is_prefill:
213        -                 residual = _metadata_for_padding.padding_slice(residual)
214        -
215        -             hidden_states = self.input_layernorm(hidden_states)
216        -         else:
217        -             hidden_states, residual = self.input_layernorm(
218        -                 hidden_states, residual)
219        -
220        -         if _metadata_for_padding and _metadata_for_padding.not_dummy_and_is_prefill:
221        -             hidden_states = _metadata_for_padding.allgather_unpadding_aligned(
222        -                 hidden_states)
223        -
224        -         hidden_states = self.self_attn(
225        -             positions=positions,
226        -             hidden_states=hidden_states,
227        -         )
228        -
229        -         if _metadata_for_padding and _metadata_for_padding.not_dummy_and_is_prefill:
230        -             hidden_states = _metadata_for_padding.padding_aligned_reduce_scatter(
231        -                 hidden_states)
232        -
233        -         # Fully Connected
234        -         hidden_states, residual = self.post_attention_layernorm(
235        -             hidden_states, residual)
236        -
237        -         if not self.use_aclgraph:
238        -             hidden_states = self.mlp(
239        -                 hidden_states, _metadata_for_padding=_metadata_for_padding)
240        -         else:
241        -             hidden_states = self.mlp(hidden_states)
242        -
243        -         return hidden_states, residual
244        -
245  186
246  187    @support_torch_compile
247  188    class CustomQwen3MoeModel(Qwen3MoeModel):
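
With the override above removed, the decoder layer presumably falls back to the forward it inherits from the upstream Qwen3-MoE implementation. A minimal sketch of that pre-norm residual flow, assuming it matches the deleted code minus the sequence-parallel padding/all-gather steps (`layer` stands in for a constructed decoder layer):

from typing import Optional
import torch

def decoder_layer_forward(layer, positions: torch.Tensor,
                          hidden_states: torch.Tensor,
                          residual: Optional[torch.Tensor]):
    # Pre-norm residual flow: attention block, then the (MoE) MLP block.
    if residual is None:
        residual = hidden_states
        hidden_states = layer.input_layernorm(hidden_states)
    else:
        hidden_states, residual = layer.input_layernorm(hidden_states, residual)
    hidden_states = layer.self_attn(positions=positions,
                                    hidden_states=hidden_states)
    hidden_states, residual = layer.post_attention_layernorm(hidden_states, residual)
    hidden_states = layer.mlp(hidden_states)
    return hidden_states, residual
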
@@ -277,45 +218,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
277  218                make_empty_intermediate_tensors_factory(
278  219                    ["hidden_states", "residual"], config.hidden_size))
279  220
280        -     def forward(
281        -         self,
282        -         input_ids: torch.Tensor,
283        -         positions: torch.Tensor,
284        -         intermediate_tensors: Optional[IntermediateTensors] = None,
285        -         inputs_embeds: Optional[torch.Tensor] = None,
286        -         _metadata_for_padding: Optional[MetadataForPadding] = None,
287        -     ) -> Union[torch.Tensor, IntermediateTensors]:
288        -         if get_pp_group().is_first_rank:
289        -             if inputs_embeds is not None:
290        -                 hidden_states = inputs_embeds
291        -             else:
292        -                 hidden_states = self.get_input_embeddings(input_ids)
293        -             residual = None
294        -         else:
295        -             assert intermediate_tensors is not None
296        -             hidden_states = intermediate_tensors["hidden_states"]
297        -             residual = intermediate_tensors["residual"]
298        -         for i in range(self.start_layer, self.end_layer):
299        -             layer = self.layers[i]
300        -             hidden_states, residual = layer(
301        -                 positions,
302        -                 hidden_states,
303        -                 residual,
304        -                 _metadata_for_padding=_metadata_for_padding)
305        -         if not get_pp_group().is_last_rank:
306        -             return IntermediateTensors({
307        -                 "hidden_states": hidden_states,
308        -                 "residual": residual
309        -             })
310        -
311        -         hidden_states, _ = self.norm(hidden_states, residual)
312        -
313        -         if _metadata_for_padding and _metadata_for_padding.not_dummy_and_is_prefill:
314        -             hidden_states = _metadata_for_padding.allgather_unpadding_aligned(
315        -                 hidden_states)
316        -
317        -         return hidden_states
318        -
319  221
320  222    class CustomQwen3MoeForCausalLM(Qwen3MoeForCausalLM):
321  223
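
The make_empty_intermediate_tensors_factory(["hidden_states", "residual"], ...) context line above defines the pipeline-parallel hand-off that the deleted forward used on non-final ranks. A rough sketch of what such a factory is assumed to produce (the helper below is illustrative, not the vLLM implementation):

import torch
from vllm.sequence import IntermediateTensors

def make_empty_intermediates(batch_size: int, hidden_size: int,
                             dtype: torch.dtype, device: torch.device):
    # Non-final pipeline ranks hand these two tensors to the next rank
    # instead of returning final hidden states.
    return IntermediateTensors({
        "hidden_states": torch.zeros(batch_size, hidden_size, dtype=dtype, device=device),
        "residual": torch.zeros(batch_size, hidden_size, dtype=dtype, device=device),
    })
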
@@ -340,7 +242,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
340  242            self.make_empty_intermediate_tensors = (
341  243                self.model.make_empty_intermediate_tensors)
342  244
343        -         self.enable_sequence_parallelism = vllm_config.compilation_config.pass_config.enable_sequence_parallelism
344  245            # Set MoE hyperparameters
345  246            self.expert_weights: list[torch.Tensor] = []
346  247
@@ -361,16 +262,3 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
361  262            self.num_moe_layers = len(self.moe_layers)
362  263            self.num_expert_groups = 1
363  264            self.num_shared_experts = 0
364        -
365        -     def forward(
366        -         self,
367        -         input_ids: torch.Tensor,
368        -         positions: torch.Tensor,
369        -         intermediate_tensors: Optional[IntermediateTensors] = None,
370        -         inputs_embeds: Optional[torch.Tensor] = None,
371        -     ) -> Union[torch.Tensor, IntermediateTensors]:
372        -         _metadata_for_padding = init_metadata_for_sp(
373        -             input_ids, self.enable_sequence_parallelism)
374        -         hidden_states = self.model(input_ids, positions, intermediate_tensors,
375        -                                    inputs_embeds, _metadata_for_padding)
376        -         return hidden_states
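
With the class-level forward overrides gone, the Ascend Qwen3-MoE classes are expected to rely on the forward passes inherited from upstream vLLM, so the model is driven through the usual entry points. An illustrative usage sketch follows; the model name and tensor-parallel size are placeholders, not values taken from this change.

from vllm import LLM, SamplingParams

# Placeholder model name and TP size; adjust to the actual deployment.
llm = LLM(model="Qwen/Qwen3-30B-A3B", tensor_parallel_size=4)
outputs = llm.generate(["Explain mixture-of-experts in one sentence."],
                       SamplingParams(temperature=0.7, max_tokens=64))
print(outputs[0].outputs[0].text)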