@@ -21,301 +21,19 @@
 
 import torch
 from torch import nn
-from transformers import PretrainedConfig
-from vllm.compilation.decorators import support_torch_compile
-from vllm.config import CacheConfig, CompilationLevel, VllmConfig
-from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
-from vllm.distributed.parallel_state import (get_dp_group, get_ep_group,
-                                             get_tp_group)
-from vllm.forward_context import get_forward_context
+from vllm.config import VllmConfig
 from vllm.model_executor.layers.fused_moe.layer import FusedMoE
-from vllm.model_executor.layers.layernorm import RMSNorm
-from vllm.model_executor.layers.linear import ReplicatedLinear
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
-from vllm.model_executor.layers.quantization import QuantizationConfig
-from vllm.model_executor.layers.vocab_parallel_embedding import (
-    ParallelLMHead, VocabParallelEmbedding)
+from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
 from vllm.model_executor.models.interfaces import (MixtureOfExperts,
                                                    SupportsLoRA, SupportsPP)
-from vllm.model_executor.models.qwen3_moe import (Qwen3MoeAttention,
-                                                  Qwen3MoeDecoderLayer,
+from vllm.model_executor.models.qwen3_moe import (Qwen3MoeDecoderLayer,
                                                   Qwen3MoeForCausalLM,
-                                                  Qwen3MoeMLP, Qwen3MoeModel,
+                                                  Qwen3MoeModel,
                                                   Qwen3MoeSparseMoeBlock)
-from vllm.model_executor.models.utils import (
-    PPMissingLayer, extract_layer_index,
-    make_empty_intermediate_tensors_factory, make_layers, maybe_prefix)
+from vllm.model_executor.models.utils import PPMissingLayer, maybe_prefix
 from vllm.sequence import IntermediateTensors
 
-from vllm_ascend.ops.common_fused_moe import AscendFusedMoE
-from vllm_ascend.ops.sequence_parallel import (MetadataForPadding,
-                                               init_metadata_for_sp)
-
-
-class CustomSparseMoeBlock(Qwen3MoeSparseMoeBlock):
-
-    def __init__(
-        self,
-        config: PretrainedConfig,
-        quant_config: Optional[QuantizationConfig] = None,
-        prefix: str = "",
-    ):
-        nn.Module.__init__(self)
-        self.tp_size = get_tensor_model_parallel_world_size()
-        if self.tp_size > config.num_experts:
-            raise ValueError(
-                f"Tensor parallel size {self.tp_size} is greater than "
-                f"the number of experts {config.num_experts}.")
-
-        self.gate = ReplicatedLinear(
-            config.hidden_size,
-            config.num_experts,
-            bias=False,
-            quant_config=None,
-            prefix=f"{prefix}.gate",
-        )
-
-        self.experts = AscendFusedMoE(
-            num_experts=config.num_experts,
-            top_k=config.num_experts_per_tok,
-            hidden_size=config.hidden_size,
-            intermediate_size=config.moe_intermediate_size,
-            reduce_results=False,
-            renormalize=config.norm_topk_prob,
-            quant_config=quant_config,
-            prefix=f"{prefix}.experts",
-        )
-
-        self.top_k = config.num_experts_per_tok
-
-        self.dp_size = get_dp_group().world_size
-
-        self.tp_group = get_tp_group().device_group
-        self.tp_rank = get_tp_group().rank_in_group
-        self.ep_group = get_ep_group()
-
-        self.params_dtype = torch.get_default_dtype()
-
-    def forward(
-        self,
-        hidden_states,
-        attn_metadata=None,
-        _metadata_for_padding: Optional[MetadataForPadding] = None,
-    ):
-        if attn_metadata is None:
-            attn_metadata = get_forward_context().attn_metadata
-        # when profile runs, force experts to load balanced tokens
-        # to avoid high memory consumption on a single rank.
-        enable_force_load_balance = get_forward_context().in_profile_run
-        is_prefill = get_forward_context().with_prefill
-
-        # router_logits: (num_tokens, n_experts)
-        router_logits, _ = self.gate(hidden_states)
-
-        hidden_states = self.experts(
-            hidden_states=hidden_states,
-            router_logits=router_logits,
-            is_prefill=is_prefill,
-            top_k=self.top_k,
-            enable_force_load_balance=enable_force_load_balance,
-            shared_experts=None,
-            _metadata_for_padding=_metadata_for_padding,
-        )
-
-        return hidden_states
-
-
-class CustomQwen3MoeDecoderLayer(Qwen3MoeDecoderLayer):
-
-    def __init__(
-        self,
-        config: PretrainedConfig,
-        cache_config: Optional[CacheConfig] = None,
-        quant_config: Optional[QuantizationConfig] = None,
-        vllm_config: Optional[VllmConfig] = None,
-        prefix: str = "",
-    ) -> None:
-
-        nn.Module.__init__(self)
-        self.hidden_size = config.hidden_size
-        rope_theta = getattr(config, "rope_theta", 10000)
-        rope_scaling = getattr(config, "rope_scaling", None)
-        max_position_embeddings = getattr(config, "max_position_embeddings",
-                                          8192)
-        self.self_attn = Qwen3MoeAttention(
-            hidden_size=self.hidden_size,
-            num_heads=config.num_attention_heads,
-            num_kv_heads=config.num_key_value_heads,
-            rope_theta=rope_theta,
-            rope_scaling=rope_scaling,
-            max_position_embeddings=max_position_embeddings,
-            rms_norm_eps=config.rms_norm_eps,
-            qkv_bias=getattr(config, 'attention_bias', False),
-            head_dim=getattr(config, 'head_dim', None),
-            cache_config=cache_config,
-            quant_config=quant_config,
-            prefix=f"{prefix}.self_attn",
-        )
-
-        # `mlp_only_layers` in the config.
-        layer_idx = extract_layer_index(prefix)
-        mlp_only_layers = ([] if not hasattr(config, "mlp_only_layers") else
-                           config.mlp_only_layers)
-        self.use_aclgraph = (vllm_config is not None
-                             and vllm_config.compilation_config.level
-                             == CompilationLevel.PIECEWISE
-                             and not vllm_config.model_config.enforce_eager)
-        if (layer_idx not in mlp_only_layers) and (
-                config.num_experts > 0 and
-                (layer_idx + 1) % config.decoder_sparse_step == 0):
-            if not self.use_aclgraph:
-                # FIXME: custom sparse moe block doesn't work with aclgraph.
-                self.mlp = CustomSparseMoeBlock(config=config,
-                                                quant_config=quant_config,
-                                                prefix=f"{prefix}.mlp")
-            else:
-                self.mlp = Qwen3MoeSparseMoeBlock(config=config,
-                                                  quant_config=quant_config,
-                                                  prefix=f"{prefix}.mlp")
-        else:
-            self.mlp = Qwen3MoeMLP(hidden_size=config.hidden_size,
-                                   intermediate_size=config.intermediate_size,
-                                   hidden_act=config.hidden_act,
-                                   quant_config=quant_config,
-                                   prefix=f"{prefix}.mlp")
-        self.input_layernorm = RMSNorm(config.hidden_size,
-                                       eps=config.rms_norm_eps)
-        self.post_attention_layernorm = RMSNorm(config.hidden_size,
-                                                eps=config.rms_norm_eps)
-
-        self.enable_sequence_parallelism = (
-            vllm_config.compilation_config.pass_config.
-            enable_sequence_parallelism if vllm_config is not None else False)
-
-    def forward(
-        self,
-        positions: torch.Tensor,
-        hidden_states: torch.Tensor,
-        residual: Optional[torch.Tensor],
-        _metadata_for_padding: Optional[MetadataForPadding] = None,
-    ) -> torch.Tensor:
-
-        # To prevent precision issues during the decoder phase when only prefilling enables SP
-        if not self.enable_sequence_parallelism:
-            self.self_attn.o_proj.reduce_results = True
-        else:
-            self.self_attn.o_proj.reduce_results = not _metadata_for_padding.not_dummy_and_is_prefill if _metadata_for_padding is not None else True
-
-        # Self Attention
-        if residual is None:
-            residual = hidden_states
-            if _metadata_for_padding and _metadata_for_padding.not_dummy_and_is_prefill:
-                residual = _metadata_for_padding.padding_slice(residual)
-
-            hidden_states = self.input_layernorm(hidden_states)
-        else:
-            hidden_states, residual = self.input_layernorm(
-                hidden_states, residual)
-
-        if _metadata_for_padding and _metadata_for_padding.not_dummy_and_is_prefill:
-            hidden_states = _metadata_for_padding.allgather_unpadding_aligned(
-                hidden_states)
-
-        hidden_states = self.self_attn(
-            positions=positions,
-            hidden_states=hidden_states,
-        )
-
-        if _metadata_for_padding and _metadata_for_padding.not_dummy_and_is_prefill:
-            hidden_states = _metadata_for_padding.padding_aligned_reduce_scatter(
-                hidden_states)
-
-        # Fully Connected
-        hidden_states, residual = self.post_attention_layernorm(
-            hidden_states, residual)
-
-        if not self.use_aclgraph:
-            hidden_states = self.mlp(
-                hidden_states, _metadata_for_padding=_metadata_for_padding)
-        else:
-            hidden_states = self.mlp(hidden_states)
-
-        return hidden_states, residual
-
-
-@support_torch_compile
-class CustomQwen3MoeModel(Qwen3MoeModel):
-
-    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
-        nn.Module.__init__(self)
-        config = vllm_config.model_config.hf_config
-        cache_config = vllm_config.cache_config
-        quant_config = vllm_config.quant_config
-
-        parallel_config = vllm_config.parallel_config
-        eplb_config = parallel_config.eplb_config
-        self.num_redundant_experts = eplb_config.num_redundant_experts
-        self.padding_idx = config.pad_token_id
-        self.vocab_size = config.vocab_size
-        self.config = config
-        self.embed_tokens = VocabParallelEmbedding(
-            config.vocab_size,
-            config.hidden_size,
-            prefix=f"{prefix}.embed_tokens")
-        self.start_layer, self.end_layer, self.layers = make_layers(
-            config.num_hidden_layers,
-            lambda prefix: CustomQwen3MoeDecoderLayer(
-                config=config,
-                cache_config=cache_config,
-                quant_config=quant_config,
-                vllm_config=vllm_config,
-                prefix=prefix),
-            prefix=f"{prefix}.layers",
-        )
-        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
-        self.make_empty_intermediate_tensors = (
-            make_empty_intermediate_tensors_factory(
-                ["hidden_states", "residual"], config.hidden_size))
-
-    def forward(
-        self,
-        input_ids: torch.Tensor,
-        positions: torch.Tensor,
-        intermediate_tensors: Optional[IntermediateTensors] = None,
-        inputs_embeds: Optional[torch.Tensor] = None,
-        _metadata_for_padding: Optional[MetadataForPadding] = None,
-    ) -> Union[torch.Tensor, IntermediateTensors]:
-        if get_pp_group().is_first_rank:
-            if inputs_embeds is not None:
-                hidden_states = inputs_embeds
-            else:
-                hidden_states = self.get_input_embeddings(input_ids)
-            residual = None
-        else:
-            assert intermediate_tensors is not None
-            hidden_states = intermediate_tensors["hidden_states"]
-            residual = intermediate_tensors["residual"]
-        for i in range(self.start_layer, self.end_layer):
-            layer = self.layers[i]
-            hidden_states, residual = layer(
-                positions,
-                hidden_states,
-                residual,
-                _metadata_for_padding=_metadata_for_padding)
-        if not get_pp_group().is_last_rank:
-            return IntermediateTensors({
-                "hidden_states": hidden_states,
-                "residual": residual
-            })
-
-        hidden_states, _ = self.norm(hidden_states, residual)
-
-        if _metadata_for_padding and _metadata_for_padding.not_dummy_and_is_prefill:
-            hidden_states = _metadata_for_padding.allgather_unpadding_aligned(
-                hidden_states)
-
-        return hidden_states
-
 
 class CustomQwen3MoeForCausalLM(Qwen3MoeForCausalLM):
     packed_modules_mapping = {
@@ -341,8 +59,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         quant_config = vllm_config.quant_config
         self.config = config
         self.quant_config = quant_config
-        self.model = CustomQwen3MoeModel(vllm_config=vllm_config,
-                                         prefix=maybe_prefix(prefix, "model"))
+        self.model = Qwen3MoeModel(vllm_config=vllm_config,
+                                   prefix=maybe_prefix(prefix, "model"))
         self.lm_head = ParallelLMHead(config.vocab_size,
                                       config.hidden_size,
                                       quant_config=quant_config,
@@ -352,8 +70,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         self.logits_processor = LogitsProcessor(config.vocab_size)
         self.make_empty_intermediate_tensors = (
             self.model.make_empty_intermediate_tensors)
-
-        self.enable_sequence_parallelism = vllm_config.compilation_config.pass_config.enable_sequence_parallelism
         # Set MoE hyperparameters
         self.expert_weights: list[torch.Tensor] = []
 
@@ -382,8 +98,6 @@ def forward(
         intermediate_tensors: Optional[IntermediateTensors] = None,
         inputs_embeds: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, IntermediateTensors]:
-        _metadata_for_padding = init_metadata_for_sp(
-            input_ids, self.enable_sequence_parallelism)
         hidden_states = self.model(input_ids, positions, intermediate_tensors,
-                                   inputs_embeds, _metadata_for_padding)
+                                   inputs_embeds)
         return hidden_states