import torch
from torch import nn
+ import torch.nn.functional as F
+ import torch.distributed as dist
from transformers import Qwen3Config
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
- from vllm.distributed import get_pp_group
+ from vllm.attention import Attention, AttentionType
+ from vllm.distributed import (get_pp_group,
+                               get_tensor_model_parallel_world_size,
+                               get_tensor_model_parallel_rank,
+                               tensor_model_parallel_all_gather)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
+ from vllm.model_executor.layers.linear import RowParallelLinear, ReplicatedLinear
from vllm.model_executor.models.interfaces import SupportsLoRA, SupportsPP
from vllm.model_executor.models.qwen2 import Qwen2Model
- from vllm.model_executor.models.qwen3 import Qwen3DecoderLayer
+ from vllm.model_executor.models.qwen2 import Qwen2MLP as Qwen3MLP
+ from vllm.model_executor.models.qwen3 import Qwen3DecoderLayer, Qwen3Attention
from vllm.model_executor.models.utils import (AutoWeightsLoader,
                                              PPMissingLayer, maybe_prefix)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors

+ from vllm_ascend import envs
from vllm_ascend.ops.layernorm import AddRMSNormQuant

- class CustomQwen3DecoderLayer(Qwen3DecoderLayer):
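+ # Helper functions for the flashcomm2 path. The collectives used below split
+ # tensors into tp_size equal chunks along the token dimension, so pad() rounds
+ # the token count up to a multiple of tp_size (e.g. with tp_size=4, 10 tokens
+ # are padded to 12) and unpad() drops the padded rows again afterwards.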
+ def pad(tensor, x):
+     length = tensor.size(0)
+     pad_size = (x - (length % x)) % x
+     if pad_size > 0:
+         return F.pad(tensor, (0, 0, 0, pad_size)), pad_size
+     return tensor, pad_size
+
+ def unpad(tensor, pad_size):
+     if pad_size > 0:
+         return tensor[:-pad_size, :]
+     return tensor
+
+
+ class CustomQwen3MLP(Qwen3MLP):

    def __init__(
        self,
-         config: Qwen3Config,
-         cache_config: Optional[CacheConfig] = None,
+         hidden_size: int,
+         intermediate_size: int,
+         hidden_act: str,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
-         super().__init__(config=config,
-                          cache_config=cache_config,
+         super().__init__(hidden_size=hidden_size,
+                          intermediate_size=intermediate_size,
+                          hidden_act=hidden_act,
                         quant_config=quant_config,
                         prefix=prefix)
+         self.tp_size = get_tensor_model_parallel_world_size()
+         self.tp_rank = get_tensor_model_parallel_rank()
+         self.enable_fc = envs.VLLM_ENABLE_FC
+         if self.enable_fc:
+             # if flashcomm2 is enabled, replace Linear+AllReduce with All2All+Linear
+             self.down_proj = ReplicatedLinear(
+                 intermediate_size,
+                 hidden_size,
+                 bias=False,
+                 quant_config=quant_config,
+                 prefix=f"{prefix}.down_proj",
+             )
+         else:
+             self.down_proj = RowParallelLinear(
+                 intermediate_size,
+                 hidden_size,
+                 bias=False,
+                 quant_config=quant_config,
+                 prefix=f"{prefix}.down_proj",
+             )
+
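+     # With flashcomm2 enabled, forward() returns hidden states that are sharded
+     # across TP ranks along the token dimension (plus the pad size used), and the
+     # RowParallelLinear + AllReduce of the stock MLP is replaced by an all-to-all
+     # followed by a fully replicated down_proj.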
+     def forward(self, x):
+         gate_up, _ = self.gate_up_proj(x)
+         x = self.act_fn(gate_up)
+         pad_size = 0
+         if self.enable_fc:
+             # pad input because All2All/AllGather require token_num to be divisible by tp_size
+             x, pad_size = pad(x, self.tp_size)
+             output = torch.empty(x.shape, dtype=x.dtype, device=x.device)
+             dist.all_to_all_single(output, x)
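+             # The all-to-all output holds tp_size blocks of shape
+             # [token_num/tp_size, intermediate_size/tp_size], one from each rank;
+             # regroup them so this rank keeps the full intermediate dimension for
+             # its own 1/tp_size slice of the tokens.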
+             x = output.reshape(self.tp_size, -1, output.size(-1)) \
+                 .transpose(0, 1) \
+                 .reshape(-1, output.size(-1) * self.tp_size)
+         x, _ = self.down_proj(x)
+         return x, pad_size
+
+
+ class CustomQwen3Attention(Qwen3Attention):
+
+     def __init__(self,
+                  hidden_size: int,
+                  num_heads: int,
+                  num_kv_heads: int,
+                  max_position: int = 4096 * 32,
+                  head_dim: Optional[int] = None,
+                  rms_norm_eps: float = 1e-06,
+                  qkv_bias: bool = False,
+                  rope_theta: float = 10000,
+                  cache_config: Optional[CacheConfig] = None,
+                  quant_config: Optional[QuantizationConfig] = None,
+                  rope_scaling: Optional[tuple] = None,
+                  prefix: str = "",
+                  attn_type: str = AttentionType.DECODER) -> None:
+         super().__init__(hidden_size=hidden_size,
+                          num_heads=num_heads,
+                          num_kv_heads=num_kv_heads,
+                          max_position=max_position,
+                          head_dim=head_dim,
+                          rms_norm_eps=rms_norm_eps,
+                          qkv_bias=qkv_bias,
+                          rope_theta=rope_theta,
+                          cache_config=cache_config,
+                          quant_config=quant_config,
+                          rope_scaling=rope_scaling,
+                          prefix=prefix,
+                          attn_type=attn_type)
+         self.tp_size = get_tensor_model_parallel_world_size()
+         self.tp_rank = get_tensor_model_parallel_rank()
+         self.enable_fc = envs.VLLM_ENABLE_FC
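+         # Same flashcomm2 replacement as in CustomQwen3MLP: keep the full o_proj
+         # weight on every rank (ReplicatedLinear) and feed it token-sharded input,
+         # instead of using RowParallelLinear followed by an all-reduce.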
+         if self.enable_fc:
+             self.o_proj = ReplicatedLinear(
+                 self.total_num_heads * self.head_dim,
+                 hidden_size,
+                 bias=False,
+                 quant_config=quant_config,
+                 prefix=f"{prefix}.o_proj",
+             )
+         else:
+             self.o_proj = RowParallelLinear(
+                 self.total_num_heads * self.head_dim,
+                 hidden_size,
+                 bias=False,
+                 quant_config=quant_config,
+                 prefix=f"{prefix}.o_proj",
+             )
+
+     def forward(
+         self,
+         positions: torch.Tensor,
+         hidden_states: torch.Tensor,
+     ) -> tuple[torch.Tensor, int]:
+         qkv, _ = self.qkv_proj(hidden_states)
+         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+         # Add qk-norm
+         q_by_head = q.view(*q.shape[:-1], q.shape[-1] // self.head_dim,
+                            self.head_dim)
+         q_by_head = self.q_norm(q_by_head)
+         q = q_by_head.view(q.shape)
+         k_by_head = k.view(*k.shape[:-1], k.shape[-1] // self.head_dim,
+                            self.head_dim)
+         k_by_head = self.k_norm(k_by_head)
+         k = k_by_head.view(k.shape)
+         q, k = self.rotary_emb(positions, q, k)
+         attn_output = self.attn(q, k, v)
+         pad_size = 0
+         if self.enable_fc:
+             # pad input because All2All/AllGather require token_num to be divisible by tp_size
+             attn_output, pad_size = pad(attn_output, self.tp_size)
+             output = torch.empty(attn_output.shape, dtype=attn_output.dtype,
+                                  device=attn_output.device)
+             dist.all_to_all_single(output, attn_output)
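+             # Same regrouping as in CustomQwen3MLP.forward: after the all-to-all,
+             # this rank holds the full head dimension for its 1/tp_size slice of
+             # the tokens, ready for the replicated o_proj.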
+             attn_output = output.reshape(self.tp_size, -1, output.size(-1)) \
+                 .transpose(0, 1) \
+                 .reshape(-1, output.size(-1) * self.tp_size)
+         output, _ = self.o_proj(attn_output)
+         return output, pad_size
+
+
+ class CustomQwen3DecoderLayer(nn.Module):
+
+     def __init__(
+         self,
+         config: Qwen3Config,
+         cache_config: Optional[CacheConfig] = None,
+         quant_config: Optional[QuantizationConfig] = None,
+         prefix: str = "",
+     ) -> None:
+         super().__init__()
+         self.hidden_size = config.hidden_size
+         self.tp_size = get_tensor_model_parallel_world_size()
+         self.tp_rank = get_tensor_model_parallel_rank()
+         self.enable_fc = envs.VLLM_ENABLE_FC
+         # Requires transformers > 4.32.0
+         rope_theta = getattr(config, "rope_theta", 1000000)
+         rope_scaling = getattr(config, "rope_scaling", None)
+
+         # By default, Qwen3 uses causal attention as it is a decoder-only model.
+         # You can override the HF config with `is_causal=False` to enable
+         # bidirectional attention, which is used in some embedding models
+         # (e.g. Alibaba-NLP/gte-Qwen3-7B-instruct)
+         if getattr(config, "is_causal", True):
+             attn_type = AttentionType.DECODER
+         else:
+             attn_type = AttentionType.ENCODER_ONLY
+
+         self.self_attn = CustomQwen3Attention(
+             hidden_size=self.hidden_size,
+             num_heads=config.num_attention_heads,
+             max_position=config.max_position_embeddings,
+             num_kv_heads=config.num_key_value_heads,
+             rope_theta=rope_theta,
+             rms_norm_eps=config.rms_norm_eps,
+             qkv_bias=getattr(config, 'attention_bias', False),
+             head_dim=getattr(config, 'head_dim', None),
+             cache_config=cache_config,
+             quant_config=quant_config,
+             rope_scaling=rope_scaling,
+             prefix=f"{prefix}.self_attn",
+             attn_type=attn_type,
+         )
+         self.mlp = CustomQwen3MLP(
+             hidden_size=self.hidden_size,
+             intermediate_size=config.intermediate_size,
+             hidden_act=config.hidden_act,
+             quant_config=quant_config,
+             prefix=f"{prefix}.mlp",
+         )
        if quant_config is None:
            return

@@ -56,6 +246,58 @@ def __init__(
            layer=self.mlp.gate_up_proj,
            eps=config.rms_norm_eps)

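+     # With flashcomm2, hidden_states enters these helpers sharded along the token
+     # dimension (the output of the previous replicated down_proj/o_proj). After the
+     # fused residual-add + RMSNorm, an all-gather along dim 0 rebuilds the full
+     # token dimension and unpad() removes the padding before the next projection.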
+     def pre_attention_process(self, hidden_states, residual, pad_size=0):
+         hidden_states, residual = self.input_layernorm(
+             hidden_states, residual)
+         hidden_states = tensor_model_parallel_all_gather(hidden_states, 0)
+         hidden_states = unpad(hidden_states, pad_size)
+         return hidden_states, residual
+
+     def pre_mlp_process(self, hidden_states, residual, pad_size=0):
+         token_num = hidden_states.size(0)
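+         # After attention with flashcomm2, hidden_states is already token-sharded
+         # while residual may still be full-length (first decoder layer); pad the
+         # residual and keep only this rank's token slice so the fused add+norm
+         # sees matching shapes.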
+         if token_num != residual.size(0):
+             if pad_size > 0:
+                 residual = F.pad(residual, (0, 0, 0, pad_size))
+             split_size_list = [token_num] * self.tp_size
+             residual = torch.split(residual, split_size_list)[self.tp_rank]
+
+         hidden_states, residual = self.post_attention_layernorm(
+             hidden_states, residual)
+         hidden_states = tensor_model_parallel_all_gather(hidden_states, 0)
+         hidden_states = unpad(hidden_states, pad_size)
+         return hidden_states, residual
+
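+     # pad_size is threaded through the layer: attention and MLP return it together
+     # with their token-sharded outputs, and the pre_*_process helpers use it to
+     # strip the padding again after the all-gather.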
+     def forward(
+         self,
+         positions: torch.Tensor,
+         hidden_states: torch.Tensor,
+         residual: Optional[torch.Tensor],
+         pad_size: int = 0
+     ) -> tuple[torch.Tensor, torch.Tensor, int]:
+         # Self Attention
+         if residual is None:
+             residual = hidden_states
+             hidden_states = self.input_layernorm(hidden_states)
+         else:
+             if self.enable_fc:
+                 hidden_states, residual = self.pre_attention_process(
+                     hidden_states, residual, pad_size)
+             else:
+                 hidden_states, residual = self.input_layernorm(
+                     hidden_states, residual)
+         hidden_states, pad_size = self.self_attn(
+             positions=positions,
+             hidden_states=hidden_states,
+         )
+
+         # Fully Connected
+         if self.enable_fc:
+             hidden_states, residual = self.pre_mlp_process(
+                 hidden_states, residual, pad_size)
+         else:
+             hidden_states, residual = self.post_attention_layernorm(
+                 hidden_states, residual)
+         hidden_states, pad_size = self.mlp(hidden_states)
+         return hidden_states, residual, pad_size
+

ALL_DECODER_LAYER_TYPES = {
    "attention": CustomQwen3DecoderLayer,
@@ -77,6 +319,48 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__(vllm_config=vllm_config,
                         prefix=prefix,
                         decoder_layer_type=CustomQwen3DecoderLayer)
+         self.tp_size = get_tensor_model_parallel_world_size()
+         self.tp_rank = get_tensor_model_parallel_rank()
+         self.enable_fc = envs.VLLM_ENABLE_FC
+
+     def forward(
+         self,
+         input_ids: torch.Tensor,
+         positions: torch.Tensor,
+         intermediate_tensors: Optional[IntermediateTensors] = None,
+         inputs_embeds: Optional[torch.Tensor] = None,
+     ) -> Union[torch.Tensor, IntermediateTensors]:
+         if get_pp_group().is_first_rank:
+             if inputs_embeds is not None:
+                 hidden_states = inputs_embeds
+             else:
+                 hidden_states = self.get_input_embeddings(input_ids)
+             residual = None
+         else:
+             assert intermediate_tensors is not None
+             hidden_states = intermediate_tensors["hidden_states"]
+             residual = intermediate_tensors["residual"]
+         pad_size = 0
+         for layer in self.layers[self.start_layer:self.end_layer]:
+             hidden_states, residual, pad_size = layer(
+                 positions,
+                 hidden_states,
+                 residual,
+                 pad_size
+             )
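+         # After the last decoder layer, flashcomm2 leaves hidden_states and residual
+         # sharded along the token dimension (and possibly padded); gather them back
+         # and trim the padding before the final norm or the PP handoff.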
+         if self.enable_fc:
+             hidden_states = tensor_model_parallel_all_gather(hidden_states, 0)
+             residual = tensor_model_parallel_all_gather(residual, 0)
+             if pad_size > 0:
+                 hidden_states = hidden_states[:-pad_size]
+                 residual = residual[:-pad_size]
+         if not get_pp_group().is_last_rank:
+             return IntermediateTensors({
+                 "hidden_states": hidden_states,
+                 "residual": residual
+             })
+         hidden_states, _ = self.norm(hidden_states, residual)
+         return hidden_states


class CustomQwen3ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):