 from typing import Optional, Union

 import torch
+import torch.distributed as dist
+import torch.nn.functional as F
 from torch import nn
 from transformers import Qwen3Config
 from vllm.attention import AttentionType
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, VllmConfig
-from vllm.distributed import get_pp_group
+from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
+                              get_tensor_model_parallel_world_size,
+                              get_tp_group, tensor_model_parallel_all_gather)
 from vllm.model_executor.layers.layernorm import RMSNorm
+from vllm.model_executor.layers.linear import (ReplicatedLinear,
+                                               RowParallelLinear)
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors

+from vllm_ascend import envs
 from vllm_ascend.ops.layernorm import AddRMSNormW8A8Quant


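+# FlashComm2 helpers: pad/unpad the token (first) dimension so it splits evenly
+# across TP ranks, e.g. a [10, h] tensor with x=4 becomes [12, h] with
+# pad_size=2, and the padded rows are stripped again after the collectives.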
+def pad(tensor, x):
+    length = tensor.size(0)
+    pad_size = (x - (length % x)) % x
+    if pad_size > 0:
+        return F.pad(tensor, (0, 0, 0, pad_size)), pad_size
+    return tensor, pad_size
+
+
+def unpad(tensor, pad_size):
+    if pad_size > 0:
+        return tensor[:-pad_size, :]
+    return tensor
+
+
+class CustomQwen3MLP(Qwen3MLP):
+
+    def __init__(
+        self,
+        hidden_size: int,
+        intermediate_size: int,
+        hidden_act: str,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__(hidden_size=hidden_size,
+                         intermediate_size=intermediate_size,
+                         hidden_act=hidden_act,
+                         quant_config=quant_config,
+                         prefix=prefix)
+        self.tp_size = get_tensor_model_parallel_world_size()
+        self.tp_rank = get_tensor_model_parallel_rank()
+        self.enable_fc = envs.VLLM_ASCEND_ENABLE_FLASHCOMM
+        if self.enable_fc == 2:
+            # if flashcomm2 enabled, replace Linear+AllReduce with All2All+Linear
+            self.down_proj = ReplicatedLinear(
+                intermediate_size,
+                hidden_size,
+                bias=False,
+                quant_config=quant_config,
+                prefix=f"{prefix}.down_proj",
+            )
+        else:
+            self.down_proj = RowParallelLinear(
+                intermediate_size,
+                hidden_size,
+                bias=False,
+                quant_config=quant_config,
+                prefix=f"{prefix}.down_proj",
+            )
+
+    def forward(self, x):
+        gate_up, _ = self.gate_up_proj(x)
+        x = self.act_fn(gate_up)
+        pad_size = 0
+        if self.enable_fc == 2:
+            # pad input because AllGather requires token_num to be divisible by tp_size
+            x, pad_size = pad(x, self.tp_size)
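+            # All2All swaps the sharding: each rank sends one token chunk per
+            # peer and gets back its own tokens' slices of the intermediate
+            # dim, so the replicated down_proj below runs on
+            # [num_tokens/tp, intermediate] with no AllReduce afterwards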
+            output = torch.empty(x.shape, dtype=x.dtype, device=x.device)
+            dist.all_to_all_single(output,
+                                   x,
+                                   group=get_tp_group().device_group)
+            x = output.reshape(self.tp_size, -1, output.size(-1)) \
+                .transpose(0, 1) \
+                .reshape(-1, output.size(-1)*self.tp_size)
+        x, _ = self.down_proj(x)
+        return x, pad_size
+
+
 class CustomQwen3Attention(Qwen3Attention):

     def __init__(self,
@@ -52,13 +127,32 @@ def __init__(self,
                          rope_scaling=rope_scaling,
                          prefix=prefix,
                          attn_type=attn_type)
+        self.tp_size = get_tensor_model_parallel_world_size()
+        self.tp_rank = get_tensor_model_parallel_rank()
+        self.enable_fc = envs.VLLM_ASCEND_ENABLE_FLASHCOMM
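+        # as in CustomQwen3MLP: with flashcomm2 the output projection keeps the
+        # full weight on every rank and the TP AllReduce is replaced by
+        # All2All (below in forward) + AllGather (in the decoder layer)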
+        if self.enable_fc == 2:
+            self.o_proj = ReplicatedLinear(
+                self.total_num_heads * self.head_dim,
+                hidden_size,
+                bias=False,
+                quant_config=quant_config,
+                prefix=f"{prefix}.o_proj",
+            )
+        else:
+            self.o_proj = RowParallelLinear(
+                self.total_num_heads * self.head_dim,
+                hidden_size,
+                bias=False,
+                quant_config=quant_config,
+                prefix=f"{prefix}.o_proj",
+            )

     def forward(
         self,
         positions: torch.Tensor,
+        hidden_states: torch.Tensor,
         cos: torch.Tensor,
         sin: torch.Tensor,
-        hidden_states: torch.Tensor,
     ) -> torch.Tensor:
         qkv, _ = self.qkv_proj(hidden_states)
         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
@@ -78,8 +172,21 @@ def forward(
                                sin=sin,
                                skip_index_select=True)
         attn_output = self.attn(q, k, v)
+        pad_size = 0
+        if self.enable_fc == 2:
+            # pad input because AllGather requires token_num to be divisible by tp_size
+            attn_output, pad_size = pad(attn_output, self.tp_size)
+            output = torch.empty(attn_output.shape,
+                                 dtype=attn_output.dtype,
+                                 device=attn_output.device)
+            dist.all_to_all_single(output,
+                                   attn_output,
+                                   group=get_tp_group().device_group)
+            attn_output = output.reshape(self.tp_size, -1, output.size(-1)) \
+                .transpose(0, 1) \
+                .reshape(-1, output.size(-1)*self.tp_size)
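+        # under flashcomm2, attn_output is now this rank's token chunk with the
+        # full head dimension, so o_proj needs no AllReduce; pad_size is
+        # returned so the caller can drop the padded rows after AllGather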
         output, _ = self.o_proj(attn_output)
-        return output
+        return output, pad_size


 class CustomQwen3DecoderLayer(nn.Module):
@@ -93,6 +200,9 @@ def __init__(
     ) -> None:
         super().__init__()
         self.hidden_size = config.hidden_size
+        self.tp_size = get_tensor_model_parallel_world_size()
+        self.tp_rank = get_tensor_model_parallel_rank()
+        self.enable_fc = envs.VLLM_ASCEND_ENABLE_FLASHCOMM
         # Requires transformers > 4.32.0
         rope_theta = getattr(config, "rope_theta", 1000000)
         rope_scaling = getattr(config, "rope_scaling", None)
@@ -121,7 +231,7 @@ def __init__(
             prefix=f"{prefix}.self_attn",
             attn_type=attn_type,
         )
-        self.mlp = Qwen3MLP(
+        self.mlp = CustomQwen3MLP(
             hidden_size=self.hidden_size,
             intermediate_size=config.intermediate_size,
             hidden_act=config.hidden_act,
@@ -159,31 +269,56 @@ def __init__(
         self.post_attention_layernorm = RMSNorm(
             config.hidden_size, eps=config.rms_norm_eps)

-    def forward(
-        self,
-        positions: torch.Tensor,
-        cos: torch.Tensor,
-        sin: torch.Tensor,
-        hidden_states: torch.Tensor,
-        residual: Optional[torch.Tensor],
-    ) -> tuple[torch.Tensor, torch.Tensor]:
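+    # the two helpers below all-gather the token-sharded hidden states back to
+    # the full (padded) sequence and strip the padding before the next
+    # column-parallel projection (qkv_proj / gate_up_proj)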
+    def pre_attention_process(self, hidden_states, residual, pad_size=0):
+        hidden_states, residual = self.input_layernorm(hidden_states, residual)
+        hidden_states = tensor_model_parallel_all_gather(hidden_states, 0)
+        hidden_states = unpad(hidden_states, pad_size)
+        return hidden_states, residual
+
+    def pre_mlp_process(self, hidden_states, residual, pad_size=0):
+        hidden_states, residual = self.post_attention_layernorm(
+            hidden_states, residual)
+        hidden_states = tensor_model_parallel_all_gather(hidden_states, 0)
+        hidden_states = unpad(hidden_states, pad_size)
+        return hidden_states, residual
+
+    def forward(self,
+                positions: torch.Tensor,
+                hidden_states: torch.Tensor,
+                residual: Optional[torch.Tensor],
+                cos: torch.Tensor,
+                sin: torch.Tensor,
+                pad_size: int = 0) -> tuple[torch.Tensor, torch.Tensor, int]:
         # Self Attention
         if residual is None:
             residual = hidden_states
             hidden_states = self.input_layernorm(hidden_states)
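+            # first layer only: pad the residual and keep this rank's token
+            # chunk so it matches the token-sharded outputs of self_attn/mlp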
+            if self.enable_fc == 2:
+                residual, pad_size = pad(residual, self.tp_size)
+                chunk_size = residual.size(0) // self.tp_size
+                residual = residual[chunk_size * self.tp_rank:chunk_size *
+                                    (self.tp_rank + 1)]
+        else:
+            if self.enable_fc == 2:
+                hidden_states, residual = self.pre_attention_process(
+                    hidden_states, residual, pad_size)
+            else:
+                hidden_states, residual = self.input_layernorm(
+                    hidden_states, residual)
+        hidden_states, pad_size = self.self_attn(positions=positions,
+                                                 hidden_states=hidden_states,
+                                                 cos=cos,
+                                                 sin=sin)
+
+        # Fully Connected
+        if self.enable_fc == 2:
+            hidden_states, residual = self.pre_mlp_process(
+                hidden_states, residual, pad_size)
         else:
-            hidden_states, residual = self.input_layernorm(
+            hidden_states, residual = self.post_attention_layernorm(
                 hidden_states, residual)
-        hidden_states = self.self_attn(
-            positions=positions,
-            cos=cos,
-            sin=sin,
-            hidden_states=hidden_states,
-        )
-        hidden_states, residual = self.post_attention_layernorm(
-            hidden_states, residual)
-        hidden_states = self.mlp(hidden_states)
-        return hidden_states, residual
+        hidden_states, pad_size = self.mlp(hidden_states)
+        return hidden_states, residual, pad_size


 ALL_DECODER_LAYER_TYPES = {
@@ -207,6 +342,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
                          prefix=prefix,
                          decoder_layer_type=CustomQwen3DecoderLayer)
         self.cos_sin_cache = self.layers[0].self_attn.rotary_emb.cos_sin_cache
+        self.tp_size = get_tensor_model_parallel_world_size()
+        self.tp_rank = get_tensor_model_parallel_rank()
+        self.enable_fc = envs.VLLM_ASCEND_ENABLE_FLASHCOMM

     def forward(
         self,
@@ -235,20 +373,25 @@ def forward(
         cos, sin = cos.view(1, -1, 1, last_dim).contiguous(), sin.view(
             1, -1, 1, last_dim).contiguous()

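+        # rows appended by flashcomm2 padding inside the layers; threaded
+        # through so the final output can be un-padded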
+        pad_size = 0
         for layer in self.layers[self.start_layer:self.end_layer]:
-            hidden_states, residual = layer(
-                positions,
-                cos,
-                sin,
-                hidden_states,
-                residual,
-            )
+            hidden_states, residual, pad_size = layer(positions, hidden_states,
+                                                      residual, cos, sin,
+                                                      pad_size)
+
         if not get_pp_group().is_last_rank:
             return IntermediateTensors({
                 "hidden_states": hidden_states,
                 "residual": residual
             })
         hidden_states, _ = self.norm(hidden_states, residual)
+
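+        # flashcomm2 leaves hidden_states (and residual) token-sharded across
+        # TP ranks; gather the full sequence and strip the padding before
+        # returning to the LM head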
+        if self.enable_fc == 2:
+            hidden_states = tensor_model_parallel_all_gather(hidden_states, 0)
+            residual = tensor_model_parallel_all_gather(residual, 0)
+            if pad_size > 0:
+                hidden_states = hidden_states[:-pad_size]
+                residual = residual[:-pad_size]
         return hidden_states
