 from typing import Optional, Union

 import torch
+import torch.distributed as dist
+import torch.nn.functional as F
 from torch import nn
 from transformers import Qwen3Config
 from vllm.attention import AttentionType
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, VllmConfig
-from vllm.distributed import get_pp_group
 from vllm.model_executor.layers.layernorm import RMSNorm
+from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
+                              get_tensor_model_parallel_world_size,
+                              tensor_model_parallel_all_gather)
+from vllm.model_executor.layers.linear import (ReplicatedLinear,
+                                               RowParallelLinear)
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
 from vllm.model_executor.models.interfaces import SupportsLoRA, SupportsPP
+from vllm.model_executor.models.qwen2 import Qwen2MLP as Qwen3MLP
 from vllm.model_executor.models.qwen2 import Qwen2Model
 from vllm.model_executor.models.qwen3 import Qwen3Attention, Qwen3MLP
 from vllm.model_executor.models.utils import (AutoWeightsLoader,
                                               PPMissingLayer, maybe_prefix)
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors

+from vllm_ascend import envs
 from vllm_ascend.ops.layernorm import AddRMSNormW8A8Quant

 
+def pad(tensor, x):
+    length = tensor.size(0)
+    pad_size = (x - (length % x)) % x
+    if pad_size > 0:
+        return F.pad(tensor, (0, 0, 0, pad_size)), pad_size
+    return tensor, pad_size
+
+
+def unpad(tensor, pad_size):
+    if pad_size > 0:
+        return tensor[:-pad_size, :]
+    return tensor
+
+
+class CustomQwen3MLP(Qwen3MLP):
+
+    def __init__(
+        self,
+        hidden_size: int,
+        intermediate_size: int,
+        hidden_act: str,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__(hidden_size=hidden_size,
+                         intermediate_size=intermediate_size,
+                         hidden_act=hidden_act,
+                         quant_config=quant_config,
+                         prefix=prefix)
+        self.tp_size = get_tensor_model_parallel_world_size()
+        self.tp_rank = get_tensor_model_parallel_rank()
+        self.enable_fc = envs.VLLM_ENABLE_FC
+        if self.enable_fc:
+            # if flashcomm2 enabled, replace Linear+AllReduce with All2All+Linear
+            self.down_proj = ReplicatedLinear(
+                intermediate_size,
+                hidden_size,
+                bias=False,
+                quant_config=quant_config,
+                prefix=f"{prefix}.down_proj",
+            )
+        else:
+            self.down_proj = RowParallelLinear(
+                intermediate_size,
+                hidden_size,
+                bias=False,
+                quant_config=quant_config,
+                prefix=f"{prefix}.down_proj",
+            )
+
+    def forward(self, x):
+        gate_up, _ = self.gate_up_proj(x)
+        x = self.act_fn(gate_up)
+        pad_size = 0
+        if self.enable_fc:
+            # pad input because AllGather requires token_num to be divisible by tp_size
+            x, pad_size = pad(x, self.tp_size)
+            output = torch.empty(x.shape, dtype=x.dtype, device=x.device)
+            dist.all_to_all_single(output, x)
+            x = output.reshape(self.tp_size, -1, output.size(-1)) \
+                .transpose(0, 1) \
+                .reshape(-1, output.size(-1) * self.tp_size)
+        x, _ = self.down_proj(x)
+        return x, pad_size
+
+
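The reshape chain after dist.all_to_all_single is the heart of the flashcomm2 layout change: the All2All turns a tensor that is sharded over the hidden dimension (each TP rank holding its own heads or intermediate channels) into one that is sharded over the token dimension but carries the full hidden width, which is why down_proj and o_proj can be a ReplicatedLinear with no trailing AllReduce. A minimal single-process sketch of that equivalence (toy shapes, illustrative only, not part of the patch):

    import torch

    tp, T, D = 4, 8, 16          # TP world size, token count, hidden width
    d_local = D // tp
    A = torch.randn(T, D)        # the "full" activation no single rank materializes
    # Rank i's column-parallel shard, e.g. its slice of attention heads.
    shards = [A[:, i * d_local:(i + 1) * d_local] for i in range(tp)]

    for rank in range(tp):
        # Emulate dist.all_to_all_single on `rank`: it receives row-chunk `rank`
        # of every peer's shard, concatenated along dim 0.
        recv = torch.cat([s[rank * T // tp:(rank + 1) * T // tp] for s in shards], dim=0)
        # The same reshape/transpose used in CustomQwen3MLP.forward and the attention path.
        x = recv.reshape(tp, -1, d_local).transpose(0, 1).reshape(-1, d_local * tp)
        # `rank` now holds the full hidden width for its 1/tp slice of tokens,
        # so a replicated down_proj / o_proj needs no AllReduce afterwards.
        assert torch.equal(x, A[rank * T // tp:(rank + 1) * T // tp])

The price is the padding above plus a later AllGather (pre_attention_process / pre_mlp_process below) to restore the full token dimension before the next column-parallel projection.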
 class CustomQwen3Attention(Qwen3Attention):

     def __init__(self,
@@ -52,6 +126,25 @@ def __init__(self,
                          rope_scaling=rope_scaling,
                          prefix=prefix,
                          attn_type=attn_type)
+        self.tp_size = get_tensor_model_parallel_world_size()
+        self.tp_rank = get_tensor_model_parallel_rank()
+        self.enable_fc = envs.VLLM_ENABLE_FC
+        if self.enable_fc:
+            self.o_proj = ReplicatedLinear(
+                self.total_num_heads * self.head_dim,
+                hidden_size,
+                bias=False,
+                quant_config=quant_config,
+                prefix=f"{prefix}.o_proj",
+            )
+        else:
+            self.o_proj = RowParallelLinear(
+                self.total_num_heads * self.head_dim,
+                hidden_size,
+                bias=False,
+                quant_config=quant_config,
+                prefix=f"{prefix}.o_proj",
+            )

     def forward(
         self,
@@ -78,8 +171,19 @@ def forward(
                                sin=sin,
                                skip_index_select=True)
         attn_output = self.attn(q, k, v)
+        pad_size = 0
+        if self.enable_fc:
+            # pad input because AllGather requires token_num to be divisible by tp_size
+            attn_output, pad_size = pad(attn_output, self.tp_size)
+            output = torch.empty(attn_output.shape,
+                                 dtype=attn_output.dtype,
+                                 device=attn_output.device)
+            dist.all_to_all_single(output, attn_output)
+            attn_output = output.reshape(self.tp_size, -1, output.size(-1)) \
+                .transpose(0, 1) \
+                .reshape(-1, output.size(-1) * self.tp_size)
         output, _ = self.o_proj(attn_output)
-        return output
+        return output, pad_size

 
 class CustomQwen3DecoderLayer(nn.Module):
@@ -93,6 +197,9 @@ def __init__(
     ) -> None:
         super().__init__()
         self.hidden_size = config.hidden_size
+        self.tp_size = get_tensor_model_parallel_world_size()
+        self.tp_rank = get_tensor_model_parallel_rank()
+        self.enable_fc = envs.VLLM_ENABLE_FC
         # Requires transformers > 4.32.0
         rope_theta = getattr(config, "rope_theta", 1000000)
         rope_scaling = getattr(config, "rope_scaling", None)
@@ -121,7 +228,7 @@ def __init__(
             prefix=f"{prefix}.self_attn",
             attn_type=attn_type,
         )
-        self.mlp = Qwen3MLP(
+        self.mlp = CustomQwen3MLP(
             hidden_size=self.hidden_size,
             intermediate_size=config.intermediate_size,
             hidden_act=config.hidden_act,
@@ -185,6 +292,57 @@ def forward(
         hidden_states = self.mlp(hidden_states)
         return hidden_states, residual

+    def pre_attention_process(self, hidden_states, residual, pad_size=0):
+        hidden_states, residual = self.input_layernorm(hidden_states, residual)
+        hidden_states = tensor_model_parallel_all_gather(hidden_states, 0)
+        hidden_states = unpad(hidden_states, pad_size)
+        return hidden_states, residual
+
+    def pre_mlp_process(self, hidden_states, residual, pad_size=0):
+        token_num = hidden_states.size(0)
+        if token_num != residual.size(0):
+            if pad_size > 0:
+                residual = F.pad(residual, (0, 0, 0, pad_size))
+            split_size_list = [token_num] * self.tp_size
+            residual = torch.split(residual, split_size_list)[self.tp_rank]
+
+        hidden_states, residual = self.post_attention_layernorm(
+            hidden_states, residual)
+        hidden_states = tensor_model_parallel_all_gather(hidden_states, 0)
+        hidden_states = unpad(hidden_states, pad_size)
+        return hidden_states, residual
+
+    def forward(self,
+                positions: torch.Tensor,
+                hidden_states: torch.Tensor,
+                residual: Optional[torch.Tensor],
+                pad_size: int = 0) -> tuple[torch.Tensor, torch.Tensor, int]:
+        # Self Attention
+        if residual is None:
+            residual = hidden_states
+            hidden_states = self.input_layernorm(hidden_states)
+        else:
+            if self.enable_fc:
+                hidden_states, residual = self.pre_attention_process(
+                    hidden_states, residual, pad_size)
+            else:
+                hidden_states, residual = self.input_layernorm(
+                    hidden_states, residual)
+        hidden_states, pad_size = self.self_attn(
+            positions=positions,
+            hidden_states=hidden_states,
+        )
+
+        # Fully Connected
+        if self.enable_fc:
+            hidden_states, residual = self.pre_mlp_process(
+                hidden_states, residual, pad_size)
+        else:
+            hidden_states, residual = self.post_attention_layernorm(
+                hidden_states, residual)
+        hidden_states, pad_size = self.mlp(hidden_states)
+        return hidden_states, residual, pad_size
+
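pre_mlp_process reconciles two layouts: after the flashcomm2 attention path, hidden_states holds only this rank's 1/tp slice of the padded tokens, while residual may still hold the whole sequence, so the residual is padded and sliced to the matching token range before the fused add + RMSNorm. A small shape-level sketch of that alignment (toy values, illustrative only, not part of the patch):

    import torch
    import torch.nn.functional as F

    tp, tp_rank, T, D = 4, 2, 10, 8
    residual = torch.randn(T, D)                 # full sequence
    pad_size = (tp - T % tp) % tp                # 2, so the padded length is 12
    token_num = (T + pad_size) // tp             # 3 tokens kept per rank
    hidden_states = torch.randn(token_num, D)    # this rank's slice after the All2All

    residual = F.pad(residual, (0, 0, 0, pad_size))
    residual = torch.split(residual, [token_num] * tp)[tp_rank]
    # Row k of residual now refers to the same (padded) token as row k of hidden_states.
    assert residual.shape == hidden_states.shape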
 
 ALL_DECODER_LAYER_TYPES = {
     "attention": CustomQwen3DecoderLayer,
@@ -207,6 +365,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
                          prefix=prefix,
                          decoder_layer_type=CustomQwen3DecoderLayer)
         self.cos_sin_cache = self.layers[0].self_attn.rotary_emb.cos_sin_cache
+        self.tp_size = get_tensor_model_parallel_world_size()
+        self.tp_rank = get_tensor_model_parallel_rank()
+        self.enable_fc = envs.VLLM_ENABLE_FC

     def forward(
         self,
@@ -235,14 +396,17 @@ def forward(
         cos, sin = cos.view(1, -1, 1, last_dim).contiguous(), sin.view(
             1, -1, 1, last_dim).contiguous()

+        pad_size = 0
         for layer in self.layers[self.start_layer:self.end_layer]:
-            hidden_states, residual = layer(
-                positions,
-                cos,
-                sin,
-                hidden_states,
-                residual,
-            )
+            hidden_states, residual, pad_size = layer(positions, hidden_states,
+                                                      residual, pad_size)
+        if self.enable_fc:
+            hidden_states = tensor_model_parallel_all_gather(hidden_states, 0)
+            residual = tensor_model_parallel_all_gather(residual, 0)
+            if pad_size > 0:
+                hidden_states = hidden_states[:-pad_size]
+                residual = residual[:-pad_size]
+
         if not get_pp_group().is_last_rank:
             return IntermediateTensors({
                 "hidden_states": hidden_states,