 # Adapted from vllm/model_executor/models/qwen3_moe.py
 # This file is a part of the vllm-ascend project.

-from typing import Optional, Union
+from typing import Any, Optional, Union

 import torch
+import torch_npu
 from torch import nn
 from transformers import PretrainedConfig
+import torch.distributed as dist
+import vllm_ascend.envs as envs_ascend

 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, CompilationLevel, VllmConfig
 from vllm.distributed import (get_pp_group,
                               get_tensor_model_parallel_world_size,
                               get_tp_group)
 from vllm.forward_context import get_forward_context
 from vllm.model_executor.layers.fused_moe.layer import FusedMoE
 from vllm.model_executor.layers.layernorm import RMSNorm
-from vllm.model_executor.layers.linear import ReplicatedLinear
+from vllm.model_executor.layers.linear import (ReplicatedLinear,
+                                               RowParallelLinear)
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.vocab_parallel_embedding import (
@@ ... @@
     make_empty_intermediate_tensors_factory, make_layers, maybe_prefix)
 from vllm.sequence import IntermediateTensors

+from vllm.distributed.communication_op import tensor_model_parallel_all_gather
 from vllm_ascend.ops.fused_moe import AscendFusedMoE
-from vllm_ascend.ops.sequence_parallel import (MetadataForPadding,
-                                               init_metadata_for_sp)
+from vllm_ascend.ops.sequence_parallel import (MetadataForPadding,
+                                               init_metadata_for_sp,
+                                               init_metadata_for_flashcomm2)


 class CustomSparseMoeBlock(Qwen3MoeSparseMoeBlock):
@@ -125,6 +131,153 @@ def forward(
         return hidden_states


+class CustomQwen3MoeMLP(Qwen3MoeMLP):
+
+    def __init__(
+        self,
+        hidden_size: int,
+        intermediate_size: int,
+        hidden_act: str,
+        quant_config: Optional[QuantizationConfig] = None,
+        reduce_results: bool = True,
+        prefix: str = "",
+    ) -> None:
+        super().__init__(hidden_size=hidden_size,
+                         intermediate_size=intermediate_size,
+                         hidden_act=hidden_act,
+                         quant_config=quant_config,
+                         reduce_results=reduce_results,
+                         prefix=prefix)
+        self.tp_size = get_tensor_model_parallel_world_size()
+        self.enable_flashcomm2 = envs_ascend.VLLM_ASCEND_ENABLE_FLASHCOMM == 2
+        if self.enable_flashcomm2:
+            # With FlashComm2 enabled, replace Linear+AllReduce with All2All+Linear.
+            self.down_proj = ReplicatedLinear(
+                intermediate_size,
+                hidden_size,
+                bias=False,
+                quant_config=quant_config,
+                prefix=f"{prefix}.down_proj",
+            )
+        else:
+            self.down_proj = RowParallelLinear(
+                intermediate_size,
+                hidden_size,
+                bias=False,
+                quant_config=quant_config,
+                prefix=f"{prefix}.down_proj",
+            )
+
+    def forward(self, x, _metadata_for_padding=None):
+        # With FlashComm2 the MLP input is sequence-sharded (DP): all-gather it,
+        # run gate_up/act in TP, then switch back to DP via All2All for down_proj.
+        if self.enable_flashcomm2:
+            x = tensor_model_parallel_all_gather(x, 0)
+        gate_up, _ = self.gate_up_proj(x)
+        x = self.act_fn(gate_up)
+        if self.enable_flashcomm2:
+            # No need to pad here: the MLP input is the attention output, which is already padded.
+            output = torch.empty(x.shape, dtype=x.dtype, device=x.device)
+            dist.all_to_all_single(output,
+                                   x,
+                                   group=get_tp_group().device_group)
+            x = output.reshape(self.tp_size, -1, output.size(-1)) \
+                .transpose(0, 1) \
+                .reshape(-1, output.size(-1) * self.tp_size)
+        x, _ = self.down_proj(x)
+        return x
+
+
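For reference, and not part of the diff itself: the reshape/transpose chain after `dist.all_to_all_single` turns a tensor holding every token but only this rank's slice of the hidden width into one holding a contiguous slice of tokens at the full hidden width. A minimal single-process sketch (hypothetical sizes, no process group; the per-rank exchange is emulated with plain indexing):

import torch

tp_size = 2
num_tokens, full_width = 4, 6   # hypothetical; full_width = gathered hidden width
shard = full_width // tp_size
x_full = torch.arange(num_tokens * full_width,
                      dtype=torch.float32).reshape(num_tokens, full_width)

# Before the exchange, each rank holds all tokens but only its width shard.
rank_inputs = [x_full[:, r * shard:(r + 1) * shard] for r in range(tp_size)]

tokens_per_rank = num_tokens // tp_size
for rank in range(tp_size):
    # Emulate dist.all_to_all_single: each rank receives its token chunk
    # from every peer, concatenated along dim 0 in peer order.
    output = torch.cat(
        [src.chunk(tp_size, dim=0)[rank] for src in rank_inputs], dim=0)
    # The reshape/transpose used in forward() above.
    x = output.reshape(tp_size, -1, output.size(-1)) \
        .transpose(0, 1) \
        .reshape(-1, output.size(-1) * tp_size)
    # Each rank ends up with a contiguous token slice at full width.
    assert torch.equal(
        x, x_full[rank * tokens_per_rank:(rank + 1) * tokens_per_rank])
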
+class CustomQwen3MoeAttention(Qwen3MoeAttention):
+
+    def __init__(
+        self,
+        hidden_size: int,
+        num_heads: int,
+        num_kv_heads: int,
+        rope_theta: float = 10000,
+        rope_scaling: Optional[dict[str, Any]] = None,
+        max_position_embeddings: int = 8192,
+        head_dim: Optional[int] = None,
+        rms_norm_eps: float = 1e-06,
+        qkv_bias: bool = False,
+        cache_config: Optional[CacheConfig] = None,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+        dual_chunk_attention_config: Optional[dict[str, Any]] = None,
+    ) -> None:
+        super().__init__(hidden_size=hidden_size,
+                         num_heads=num_heads,
+                         num_kv_heads=num_kv_heads,
+                         rope_theta=rope_theta,
+                         rope_scaling=rope_scaling,
+                         max_position_embeddings=max_position_embeddings,
+                         head_dim=head_dim,
+                         rms_norm_eps=rms_norm_eps,
+                         qkv_bias=qkv_bias,
+                         cache_config=cache_config,
+                         quant_config=quant_config,
+                         prefix=prefix,
+                         dual_chunk_attention_config=dual_chunk_attention_config)
+        self.tp_size = get_tensor_model_parallel_world_size()
+        self.enable_flashcomm2 = envs_ascend.VLLM_ASCEND_ENABLE_FLASHCOMM == 2
+        if self.enable_flashcomm2:
+            self.o_proj = ReplicatedLinear(
+                self.total_num_heads * self.head_dim,
+                hidden_size,
+                bias=False,
+                quant_config=quant_config,
+                prefix=f"{prefix}.o_proj",
+            )
+        else:
+            self.o_proj = RowParallelLinear(self.total_num_heads * self.head_dim,
+                                            hidden_size,
+                                            bias=False,
+                                            quant_config=quant_config,
+                                            prefix=f"{prefix}.o_proj")
+
+    def attn_output_all_to_all(self,
+                               attn_output: torch.Tensor,
+                               _metadata_for_padding: Optional[MetadataForPadding] = None) -> torch.Tensor:
+        assert _metadata_for_padding is not None, "Metadata for padding is required for FlashComm2."
+        # Pad so the token count is divisible by tp_size, as All2All requires.
+        attn_output = _metadata_for_padding.padding_full(attn_output)
+        output = torch.empty(attn_output.shape,
+                             dtype=attn_output.dtype,
+                             device=attn_output.device)
+        dist.all_to_all_single(output,
+                               attn_output,
+                               group=get_tp_group().device_group)
+        attn_output = output.reshape(self.tp_size, -1, output.size(-1)) \
+            .transpose(0, 1) \
+            .reshape(-1, output.size(-1) * self.tp_size)
+        return attn_output
+
+    def forward(
+            self,
+            positions: torch.Tensor,
+            hidden_states: torch.Tensor,
+            _metadata_for_padding: Optional[MetadataForPadding] = None) -> torch.Tensor:
+        qkv, _ = self.qkv_proj(hidden_states)
+        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+        # Apply qk-norm per head (view as [..., num_heads, head_dim]).
+        q_by_head = q.view(*q.shape[:-1], q.shape[-1] // self.head_dim,
+                           self.head_dim)
+        q_by_head = self.q_norm(q_by_head)
+        q = q_by_head.view(q.shape)
+
+        k_by_head = k.view(*k.shape[:-1], k.shape[-1] // self.head_dim,
+                           self.head_dim)
+        k_by_head = self.k_norm(k_by_head)
+        k = k_by_head.view(k.shape)
+        q, k = self.rotary_emb(positions, q, k)
+        attn_output = self.attn(q, k, v)
+        if self.enable_flashcomm2:
+            attn_output = self.attn_output_all_to_all(attn_output, _metadata_for_padding)
+        output, _ = self.o_proj(attn_output)
+        return output
+
+
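Again for reference only: the qk-norm step above works by viewing the flat projection as [..., num_heads, head_dim] so that RMSNorm reduces over each head's head_dim slice rather than the whole projection width. A standalone sketch, using torch.nn.RMSNorm (available in recent PyTorch) as a stand-in for the model's q_norm:

import torch
from torch import nn

num_tokens, num_heads, head_dim = 3, 4, 8   # hypothetical sizes
q_norm = nn.RMSNorm(head_dim, eps=1e-6)     # stand-in for self.q_norm

q = torch.randn(num_tokens, num_heads * head_dim)
# View as [..., num_heads, head_dim]: the norm now runs per head,
# over head_dim, independently for each head.
q_by_head = q.view(*q.shape[:-1], q.shape[-1] // head_dim, head_dim)
q = q_norm(q_by_head).view(num_tokens, num_heads * head_dim)
assert q.shape == (num_tokens, num_heads * head_dim)
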

 class CustomQwen3MoeDecoderLayer(Qwen3MoeDecoderLayer):

     def __init__(
@@ -142,7 +295,7 @@ def __init__(
         rope_scaling = getattr(config, "rope_scaling", None)
         max_position_embeddings = getattr(config, "max_position_embeddings",
                                           8192)
-        self.self_attn = Qwen3MoeAttention(
+        self.self_attn = CustomQwen3MoeAttention(
             hidden_size=self.hidden_size,
             num_heads=config.num_attention_heads,
             num_kv_heads=config.num_key_value_heads,
@@ -178,7 +331,7 @@ def __init__(
                 quant_config=quant_config,
                 prefix=f"{prefix}.mlp")
         else:
-            self.mlp = Qwen3MoeMLP(hidden_size=config.hidden_size,
+            self.mlp = CustomQwen3MoeMLP(hidden_size=config.hidden_size,
                                    intermediate_size=config.intermediate_size,
                                    hidden_act=config.hidden_act,
                                    quant_config=quant_config,
@@ -191,6 +344,7 @@ def __init__(
         self.enable_sequence_parallelism = (
             vllm_config.compilation_config.pass_config.
             enable_sequence_parallelism if vllm_config is not None else False)
+        self.enable_flashcomm2 = envs_ascend.VLLM_ASCEND_ENABLE_FLASHCOMM == 2

     def forward(
         self,
@@ -201,34 +355,37 @@ def forward(
     ) -> torch.Tensor:

         # Prevent precision issues in the decode phase when SP is enabled only for prefill
-        if not self.enable_sequence_parallelism:
-            self.self_attn.o_proj.reduce_results = True
-        else:
-            self.self_attn.o_proj.reduce_results = not _metadata_for_padding.not_dummy_and_is_prefill if _metadata_for_padding is not None else True
+        if not self.enable_flashcomm2:
+            if not self.enable_sequence_parallelism:
+                self.self_attn.o_proj.reduce_results = True
+            else:
+                self.self_attn.o_proj.reduce_results = not _metadata_for_padding.not_dummy_and_is_prefill if _metadata_for_padding is not None else True

         # Self Attention
         if residual is None:
             residual = hidden_states
-            if _metadata_for_padding and _metadata_for_padding.not_dummy_and_is_prefill:
+            if _metadata_for_padding and (_metadata_for_padding.not_dummy_and_is_prefill or self.enable_flashcomm2):
                 residual = _metadata_for_padding.padding_slice(residual)

             hidden_states = self.input_layernorm(hidden_states)
         else:
             hidden_states, residual = self.input_layernorm(
                 hidden_states, residual)

-        if _metadata_for_padding and _metadata_for_padding.not_dummy_and_is_prefill:
+        if _metadata_for_padding and (_metadata_for_padding.not_dummy_and_is_prefill or self.enable_flashcomm2):
             hidden_states = _metadata_for_padding.allgather_unpadding_aligned(
                 hidden_states)

         hidden_states = self.self_attn(
             positions=positions,
             hidden_states=hidden_states,
+            _metadata_for_padding=_metadata_for_padding
         )

-        if _metadata_for_padding and _metadata_for_padding.not_dummy_and_is_prefill:
-            hidden_states = _metadata_for_padding.padding_aligned_reduce_scatter(
-                hidden_states)
+        if not self.enable_flashcomm2:
+            if _metadata_for_padding and _metadata_for_padding.not_dummy_and_is_prefill:
+                hidden_states = _metadata_for_padding.padding_aligned_reduce_scatter(
+                    hidden_states)

         # Fully Connected
         hidden_states, residual = self.post_attention_layernorm(
@@ -276,6 +433,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         self.make_empty_intermediate_tensors = (
             make_empty_intermediate_tensors_factory(
                 ["hidden_states", "residual"], config.hidden_size))
+        self.enable_flashcomm2 = envs_ascend.VLLM_ASCEND_ENABLE_FLASHCOMM == 2

     def forward(
         self,
@@ -310,7 +468,7 @@ def forward(

         hidden_states, _ = self.norm(hidden_states, residual)

-        if _metadata_for_padding and _metadata_for_padding.not_dummy_and_is_prefill:
+        if _metadata_for_padding and (_metadata_for_padding.not_dummy_and_is_prefill or self.enable_flashcomm2):
             hidden_states = _metadata_for_padding.allgather_unpadding_aligned(
                 hidden_states)

@@ -354,6 +512,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
             self.model.make_empty_intermediate_tensors)

         self.enable_sequence_parallelism = vllm_config.compilation_config.pass_config.enable_sequence_parallelism
+        self.enable_flashcomm2 = envs_ascend.VLLM_ASCEND_ENABLE_FLASHCOMM == 2
         # Set MoE hyperparameters
         self.expert_weights: list[torch.Tensor] = []

@@ -382,8 +541,13 @@ def forward(
         intermediate_tensors: Optional[IntermediateTensors] = None,
         inputs_embeds: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, IntermediateTensors]:
-        _metadata_for_padding = init_metadata_for_sp(
-            input_ids, self.enable_sequence_parallelism)
+        if self.enable_flashcomm2:
+            if self.enable_sequence_parallelism:
+                raise ValueError("Sequence parallelism and FlashComm2 cannot be enabled simultaneously.")
+            _metadata_for_padding = init_metadata_for_flashcomm2(input_ids)
+        else:
+            _metadata_for_padding = init_metadata_for_sp(
+                input_ids, self.enable_sequence_parallelism)
         hidden_states = self.model(input_ids, positions, intermediate_tensors,
                                    inputs_embeds, _metadata_for_padding)
         return hidden_states
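
Usage note (not part of the diff): FlashComm2 is gated on the VLLM_ASCEND_ENABLE_FLASHCOMM environment variable being set to 2, and, per the check above, it cannot be combined with sequence parallelism. A hypothetical launch sketch; the model name and TP size are placeholders:

import os

# Assumed: must be set before vllm_ascend reads its env config.
os.environ["VLLM_ASCEND_ENABLE_FLASHCOMM"] = "2"

from vllm import LLM

llm = LLM(model="Qwen/Qwen3-30B-A3B", tensor_parallel_size=4)
print(llm.generate(["Hello"])[0].outputs[0].text)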