 from typing import Optional, Union

 import torch
+import torch.distributed as dist
+import torch.nn.functional as F
 from torch import nn
 from transformers import Qwen3Config
 from vllm.attention import AttentionType
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, VllmConfig
-from vllm.distributed import get_pp_group
+from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
+                              get_tensor_model_parallel_world_size,
+                              tensor_model_parallel_all_gather)
 from vllm.model_executor.layers.layernorm import RMSNorm
+from vllm.model_executor.layers.linear import (ReplicatedLinear,
+                                               RowParallelLinear)
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors

+from vllm_ascend import envs
 from vllm_ascend.ops.layernorm import AddRMSNormW8A8Quant


+def pad(tensor, x):
+    length = tensor.size(0)
+    pad_size = (x - (length % x)) % x
+    if pad_size > 0:
+        return F.pad(tensor, (0, 0, 0, pad_size)), pad_size
+    return tensor, pad_size
+
+
+def unpad(tensor, pad_size):
+    if pad_size > 0:
+        return tensor[:-pad_size, :]
+    return tensor
+
+
+class CustomQwen3MLP(Qwen3MLP):
+
+    def __init__(
+        self,
+        hidden_size: int,
+        intermediate_size: int,
+        hidden_act: str,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__(hidden_size=hidden_size,
+                         intermediate_size=intermediate_size,
+                         hidden_act=hidden_act,
+                         quant_config=quant_config,
+                         prefix=prefix)
+        self.tp_size = get_tensor_model_parallel_world_size()
+        self.tp_rank = get_tensor_model_parallel_rank()
+        self.enable_fc = envs.VLLM_ENABLE_FC
+        if self.enable_fc:
+            # if flashcomm2 enabled, replace Linear+AllReduce with All2All+Linear
+            self.down_proj = ReplicatedLinear(
+                intermediate_size,
+                hidden_size,
+                bias=False,
+                quant_config=quant_config,
+                prefix=f"{prefix}.down_proj",
+            )
+        else:
+            self.down_proj = RowParallelLinear(
+                intermediate_size,
+                hidden_size,
+                bias=False,
+                quant_config=quant_config,
+                prefix=f"{prefix}.down_proj",
+            )
+
+    def forward(self, x):
+        gate_up, _ = self.gate_up_proj(x)
+        x = self.act_fn(gate_up)
+        pad_size = 0
+        if self.enable_fc:
+            # pad input because AllGather requires token_num to be divisible by tp_size
+            x, pad_size = pad(x, self.tp_size)
+            output = torch.empty(x.shape, dtype=x.dtype, device=x.device)
+            dist.all_to_all_single(output, x)
+            x = output.reshape(self.tp_size, -1, output.size(-1)) \
+                .transpose(0, 1) \
+                .reshape(-1, output.size(-1) * self.tp_size)
+        x, _ = self.down_proj(x)
+        return x, pad_size
+
+
 class CustomQwen3Attention(Qwen3Attention):

     def __init__(self,
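For reference, a minimal single-process sketch of the layout transform that `pad` + `dist.all_to_all_single` + the reshape/transpose in `CustomQwen3MLP.forward` perform when flashcomm2 is enabled. The all-to-all is emulated with `torch.chunk`, and all sizes and variable names are illustrative rather than taken from the patch:

import torch
import torch.nn.functional as F

tp, T, inter = 4, 10, 32                       # illustrative: tp ranks, tokens, full intermediate dim
full_act = torch.randn(T, inter)               # the logical activation before any sharding
rank = 1                                       # view the exchange from one rank's perspective

pad_size = (tp - T % tp) % tp                  # same padding rule as pad() above
padded = F.pad(full_act, (0, 0, 0, pad_size))  # pad the token dim so it splits evenly across ranks

# after the column-parallel gate_up/act_fn this rank holds column slice `rank` of every token
local = padded.chunk(tp, dim=1)[rank]
assert local.shape == (padded.size(0), inter // tp)

# all_to_all_single sends row block i of `local` to rank i; in return this rank receives,
# from every peer r, the rows of *its own* token block taken from peer r's column slice
recv = torch.cat(
    [padded.chunk(tp, dim=1)[r].chunk(tp, dim=0)[rank] for r in range(tp)], dim=0)

# the reshape/transpose/reshape in CustomQwen3MLP.forward stitches the column slices back together
out = recv.reshape(tp, -1, recv.size(-1)).transpose(0, 1).reshape(-1, recv.size(-1) * tp)
assert torch.equal(out, padded.chunk(tp, dim=0)[rank])  # full rows of this rank's token block

The assertion shows that, after the exchange, each rank holds the full intermediate dimension for its own 1/tp slice of the padded tokens, which is what allows `down_proj` to run as a plain `ReplicatedLinear`.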
@@ -52,13 +125,32 @@ def __init__(self,
                          rope_scaling=rope_scaling,
                          prefix=prefix,
                          attn_type=attn_type)
+        self.tp_size = get_tensor_model_parallel_world_size()
+        self.tp_rank = get_tensor_model_parallel_rank()
+        self.enable_fc = envs.VLLM_ENABLE_FC
+        if self.enable_fc:
+            self.o_proj = ReplicatedLinear(
+                self.total_num_heads * self.head_dim,
+                hidden_size,
+                bias=False,
+                quant_config=quant_config,
+                prefix=f"{prefix}.o_proj",
+            )
+        else:
+            self.o_proj = RowParallelLinear(
+                self.total_num_heads * self.head_dim,
+                hidden_size,
+                bias=False,
+                quant_config=quant_config,
+                prefix=f"{prefix}.o_proj",
+            )

     def forward(
         self,
         positions: torch.Tensor,
+        hidden_states: torch.Tensor,
         cos: torch.Tensor,
         sin: torch.Tensor,
-        hidden_states: torch.Tensor,
     ) -> torch.Tensor:
         qkv, _ = self.qkv_proj(hidden_states)
         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
@@ -78,8 +170,19 @@ def forward(
                                sin=sin,
                                skip_index_select=True)
         attn_output = self.attn(q, k, v)
+        pad_size = 0
+        if self.enable_fc:
+            # pad input because AllGather requires token_num to be divisible by tp_size
+            attn_output, pad_size = pad(attn_output, self.tp_size)
+            output = torch.empty(attn_output.shape,
+                                 dtype=attn_output.dtype,
+                                 device=attn_output.device)
+            dist.all_to_all_single(output, attn_output)
+            attn_output = output.reshape(self.tp_size, -1, output.size(-1)) \
+                .transpose(0, 1) \
+                .reshape(-1, output.size(-1) * self.tp_size)
         output, _ = self.o_proj(attn_output)
-        return output
+        return output, pad_size


 class CustomQwen3DecoderLayer(nn.Module):
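A quick single-process sanity check of why the All2All lets `ReplicatedLinear` replace `RowParallelLinear` + AllReduce for `o_proj` and `down_proj`. `RowParallelLinear` is modeled here as a per-shard matmul followed by a sum over ranks, which is its usual semantics in vLLM; the sizes and names are illustrative:

import torch

tp, tokens, in_dim, out_dim = 4, 3, 32, 16     # illustrative sizes
x = torch.randn(tokens, in_dim)                # full input rows for one rank's token block
w = torch.randn(in_dim, out_dim)               # full o_proj / down_proj weight

# RowParallelLinear (modeled): every rank multiplies its input-dim shard, AllReduce sums the partials
shard = in_dim // tp
partials = [x[:, r * shard:(r + 1) * shard] @ w[r * shard:(r + 1) * shard] for r in range(tp)]
row_parallel = torch.stack(partials).sum(0)

# ReplicatedLinear after the All2All: the rank already owns the full input dim for these tokens,
# so one full-width matmul gives the same rows and no reduction is needed
replicated = x @ w
assert torch.allclose(row_parallel, replicated, atol=1e-4)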
@@ -93,6 +196,9 @@ def __init__(
     ) -> None:
         super().__init__()
         self.hidden_size = config.hidden_size
+        self.tp_size = get_tensor_model_parallel_world_size()
+        self.tp_rank = get_tensor_model_parallel_rank()
+        self.enable_fc = envs.VLLM_ENABLE_FC
         # Requires transformers > 4.32.0
         rope_theta = getattr(config, "rope_theta", 1000000)
         rope_scaling = getattr(config, "rope_scaling", None)
@@ -121,7 +227,7 @@ def __init__(
             prefix=f"{prefix}.self_attn",
             attn_type=attn_type,
         )
-        self.mlp = Qwen3MLP(
+        self.mlp = CustomQwen3MLP(
             hidden_size=self.hidden_size,
             intermediate_size=config.intermediate_size,
             hidden_act=config.hidden_act,
@@ -159,31 +265,58 @@ def __init__(
         self.post_attention_layernorm = RMSNorm(
             config.hidden_size, eps=config.rms_norm_eps)

-    def forward(
-        self,
-        positions: torch.Tensor,
-        cos: torch.Tensor,
-        sin: torch.Tensor,
-        hidden_states: torch.Tensor,
-        residual: Optional[torch.Tensor],
-    ) -> tuple[torch.Tensor, torch.Tensor]:
+    def pre_attention_process(self, hidden_states, residual, pad_size=0):
+        hidden_states, residual = self.input_layernorm(hidden_states, residual)
+        hidden_states = tensor_model_parallel_all_gather(hidden_states, 0)
+        hidden_states = unpad(hidden_states, pad_size)
+        return hidden_states, residual
+
+    def pre_mlp_process(self, hidden_states, residual, pad_size=0):
+        token_num = hidden_states.size(0)
+        if token_num != residual.size(0):
+            if pad_size > 0:
+                residual = F.pad(residual, (0, 0, 0, pad_size))
+            split_size_list = [token_num] * self.tp_size
+            residual = torch.split(residual, split_size_list)[self.tp_rank]
+
+        hidden_states, residual = self.post_attention_layernorm(
+            hidden_states, residual)
+        hidden_states = tensor_model_parallel_all_gather(hidden_states, 0)
+        hidden_states = unpad(hidden_states, pad_size)
+        return hidden_states, residual
+
+    def forward(self,
+                positions: torch.Tensor,
+                hidden_states: torch.Tensor,
+                residual: Optional[torch.Tensor],
+                cos: torch.Tensor,
+                sin: torch.Tensor,
+                pad_size: int = 0) -> tuple[torch.Tensor, torch.Tensor, int]:
         # Self Attention
         if residual is None:
             residual = hidden_states
             hidden_states = self.input_layernorm(hidden_states)
         else:
-            hidden_states, residual = self.input_layernorm(
+            if self.enable_fc:
+                hidden_states, residual = self.pre_attention_process(
+                    hidden_states, residual, pad_size)
+            else:
+                hidden_states, residual = self.input_layernorm(
+                    hidden_states, residual)
+        hidden_states, pad_size = self.self_attn(positions=positions,
+                                                 hidden_states=hidden_states,
+                                                 cos=cos,
+                                                 sin=sin)
+
+        # Fully Connected
+        if self.enable_fc:
+            hidden_states, residual = self.pre_mlp_process(
+                hidden_states, residual, pad_size)
+        else:
+            hidden_states, residual = self.post_attention_layernorm(
                 hidden_states, residual)
-        hidden_states = self.self_attn(
-            positions=positions,
-            cos=cos,
-            sin=sin,
-            hidden_states=hidden_states,
-        )
-        hidden_states, residual = self.post_attention_layernorm(
-            hidden_states, residual)
-        hidden_states = self.mlp(hidden_states)
-        return hidden_states, residual
+        hidden_states, pad_size = self.mlp(hidden_states)
+        return hidden_states, residual, pad_size


 ALL_DECODER_LAYER_TYPES = {
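The shape bookkeeping inside the decoder layer is the subtle part: with flashcomm2 enabled, `self_attn` and `mlp` return a scattered tensor (padded_tokens / tp_size rows per rank) while `residual` is still full-length, so `pre_mlp_process` pads and splits the residual before the fused add + RMSNorm, and the all-gather + `unpad` afterwards restores the full sequence. A small sketch of that residual alignment, with illustrative sizes and a hard-coded rank:

import torch
import torch.nn.functional as F

tp, T, hidden, rank = 4, 10, 8, 1              # illustrative sizes and rank
pad_size = (tp - T % tp) % tp                  # carried through the layer as `pad_size`

residual = torch.randn(T, hidden)              # still full-length when it reaches pre_mlp_process
hidden_states = torch.randn((T + pad_size) // tp, hidden)  # scattered: this rank's token block only

# pre_mlp_process: pad the residual, then keep only this rank's token block so both
# tensors line up for the fused add + RMSNorm
token_num = hidden_states.size(0)
if token_num != residual.size(0):
    residual = F.pad(residual, (0, 0, 0, pad_size))
    residual = torch.split(residual, [token_num] * tp)[rank]
assert residual.shape == hidden_states.shape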
@@ -207,6 +340,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
                          prefix=prefix,
                          decoder_layer_type=CustomQwen3DecoderLayer)
         self.cos_sin_cache = self.layers[0].self_attn.rotary_emb.cos_sin_cache
+        self.tp_size = get_tensor_model_parallel_world_size()
+        self.tp_rank = get_tensor_model_parallel_rank()
+        self.enable_fc = envs.VLLM_ENABLE_FC

     def forward(
         self,
@@ -235,14 +371,18 @@ def forward(
         cos, sin = cos.view(1, -1, 1, last_dim).contiguous(), sin.view(
             1, -1, 1, last_dim).contiguous()

+        pad_size = 0
         for layer in self.layers[self.start_layer:self.end_layer]:
-            hidden_states, residual = layer(
-                positions,
-                cos,
-                sin,
-                hidden_states,
-                residual,
-            )
+            hidden_states, residual, pad_size = layer(positions, hidden_states,
+                                                      residual, cos, sin,
+                                                      pad_size)
+        if self.enable_fc:
+            hidden_states = tensor_model_parallel_all_gather(hidden_states, 0)
+            residual = tensor_model_parallel_all_gather(residual, 0)
+            if pad_size > 0:
+                hidden_states = hidden_states[:-pad_size]
+                residual = residual[:-pad_size]
+
         if not get_pp_group().is_last_rank:
             return IntermediateTensors({
                 "hidden_states": hidden_states,