import torch
from torch import nn
from transformers import Qwen3Config
+from vllm.attention import AttentionType
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group
+from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.models.interfaces import SupportsLoRA, SupportsPP
from vllm.model_executor.models.qwen2 import Qwen2Model
-from vllm.model_executor.models.qwen3 import Qwen3DecoderLayer
+from vllm.model_executor.models.qwen3 import Qwen3Attention, Qwen3MLP
from vllm.model_executor.models.utils import (AutoWeightsLoader,
                                              PPMissingLayer, maybe_prefix)
from vllm.model_executor.sampling_metadata import SamplingMetadata

from vllm_ascend.ops.layernorm import AddRMSNormW8A8Quant


-class CustomQwen3DecoderLayer(Qwen3DecoderLayer):
+class CustomQwen3Attention(Qwen3Attention):
+
+    def __init__(self,
+                 hidden_size: int,
+                 num_heads: int,
+                 num_kv_heads: int,
+                 max_position: int = 4096 * 32,
+                 head_dim: Optional[int] = None,
+                 rms_norm_eps: float = 1e-06,
+                 qkv_bias: bool = False,
+                 rope_theta: float = 10000,
+                 cache_config: Optional[CacheConfig] = None,
+                 quant_config: Optional[QuantizationConfig] = None,
+                 rope_scaling: Optional[tuple] = None,
+                 prefix: str = "",
+                 attn_type: str = AttentionType.DECODER) -> None:
+        super().__init__(hidden_size=hidden_size,
+                         num_heads=num_heads,
+                         num_kv_heads=num_kv_heads,
+                         max_position=max_position,
+                         head_dim=head_dim,
+                         rms_norm_eps=rms_norm_eps,
+                         qkv_bias=qkv_bias,
+                         rope_theta=rope_theta,
+                         cache_config=cache_config,
+                         quant_config=quant_config,
+                         rope_scaling=rope_scaling,
+                         prefix=prefix,
+                         attn_type=attn_type)
+
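+    # This override takes precomputed rotary cos/sin tensors: they are
+    # gathered once per batch in CustomQwen3Model.forward below and shared
+    # by every decoder layer.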
+    def forward(
+        self,
+        positions: torch.Tensor,
+        cos: torch.Tensor,
+        sin: torch.Tensor,
+        hidden_states: torch.Tensor,
+    ) -> torch.Tensor:
+        qkv, _ = self.qkv_proj(hidden_states)
+        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+        # Add qk-norm: normalize q and k per head before applying RoPE.
+        q_by_head = q.view(*q.shape[:-1], q.shape[-1] // self.head_dim,
+                           self.head_dim)
+        q_by_head = self.q_norm(q_by_head)
+        q = q_by_head.view(q.shape)
+        k_by_head = k.view(*k.shape[:-1], k.shape[-1] // self.head_dim,
+                           self.head_dim)
+        k_by_head = self.k_norm(k_by_head)
+        k = k_by_head.view(k.shape)
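+        # cos/sin were already gathered for these positions, so the rotary
+        # op can skip its internal index_select.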
+        q, k = self.rotary_emb(positions,
+                               q,
+                               k,
+                               cos=cos,
+                               sin=sin,
+                               skip_index_select=True)
+        attn_output = self.attn(q, k, v)
+        output, _ = self.o_proj(attn_output)
+        return output
+
+
+class CustomQwen3DecoderLayer(nn.Module):

    def __init__(
        self,
@@ -30,11 +91,48 @@ def __init__(
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
-        super().__init__(config=config,
-                         cache_config=cache_config,
-                         quant_config=quant_config,
-                         prefix=prefix)
+        super().__init__()
+        self.hidden_size = config.hidden_size
+        # Requires transformers > 4.32.0
+        rope_theta = getattr(config, "rope_theta", 1000000)
+        rope_scaling = getattr(config, "rope_scaling", None)
+
+        # By default, Qwen3 uses causal attention as it is a decoder-only model.
+        # You can override the HF config with `is_causal=False` to enable
+        # bidirectional attention, which is used in some embedding models
+        # (e.g. Alibaba-NLP/gte-Qwen3-7B-instruct)
+        if getattr(config, "is_causal", True):
+            attn_type = AttentionType.DECODER
+        else:
+            attn_type = AttentionType.ENCODER_ONLY
+
+        self.self_attn = CustomQwen3Attention(
+            hidden_size=self.hidden_size,
+            num_heads=config.num_attention_heads,
+            max_position=config.max_position_embeddings,
+            num_kv_heads=config.num_key_value_heads,
+            rope_theta=rope_theta,
+            rms_norm_eps=config.rms_norm_eps,
+            qkv_bias=getattr(config, 'attention_bias', False),
+            head_dim=getattr(config, 'head_dim', None),
+            cache_config=cache_config,
+            quant_config=quant_config,
+            rope_scaling=rope_scaling,
+            prefix=f"{prefix}.self_attn",
+            attn_type=attn_type,
+        )
+        self.mlp = Qwen3MLP(
+            hidden_size=self.hidden_size,
+            intermediate_size=config.intermediate_size,
+            hidden_act=config.hidden_act,
+            quant_config=quant_config,
+            prefix=f"{prefix}.mlp",
+        )
        if quant_config is None:
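+            # Unquantized checkpoints keep the stock vLLM RMSNorm layers; the
+            # quantized path below swaps in AddRMSNormW8A8Quant from
+            # vllm_ascend (fused residual add + RMSNorm + W8A8 quantization).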
+            self.input_layernorm = RMSNorm(config.hidden_size,
+                                           eps=config.rms_norm_eps)
+            self.post_attention_layernorm = RMSNorm(config.hidden_size,
+                                                    eps=config.rms_norm_eps)
            return

        from vllm_ascend.quantization.quant_config import AscendQuantConfig
@@ -56,6 +154,32 @@ def __init__(
                layer=self.mlp.gate_up_proj,
                eps=config.rms_norm_eps)

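+    # Pre-norm transformer block: input_layernorm -> self_attn ->
+    # post_attention_layernorm (fused add & norm) -> mlp. cos/sin come from
+    # CustomQwen3Model.forward and are passed through to the attention module.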
+    def forward(
+        self,
+        positions: torch.Tensor,
+        cos: torch.Tensor,
+        sin: torch.Tensor,
+        hidden_states: torch.Tensor,
+        residual: Optional[torch.Tensor],
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        # Self Attention
+        if residual is None:
+            residual = hidden_states
+            hidden_states = self.input_layernorm(hidden_states)
+        else:
+            hidden_states, residual = self.input_layernorm(
+                hidden_states, residual)
+        hidden_states = self.self_attn(
+            positions=positions,
+            cos=cos,
+            sin=sin,
+            hidden_states=hidden_states,
+        )
+        hidden_states, residual = self.post_attention_layernorm(
+            hidden_states, residual)
+        hidden_states = self.mlp(hidden_states)
+        return hidden_states, residual
+

ALL_DECODER_LAYER_TYPES = {
    "attention": CustomQwen3DecoderLayer,
@@ -77,6 +201,50 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__(vllm_config=vllm_config,
                         prefix=prefix,
                         decoder_layer_type=CustomQwen3DecoderLayer)
+        self.cos_sin_cache = self.layers[0].self_attn.rotary_emb.cos_sin_cache
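+        # Every layer is built with the same rotary parameters, so layer 0's
+        # cos/sin table can be reused for the whole model.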
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        intermediate_tensors: Optional[IntermediateTensors] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+    ) -> Union[torch.Tensor, IntermediateTensors]:
+        if get_pp_group().is_first_rank:
+            if inputs_embeds is not None:
+                hidden_states = inputs_embeds
+            else:
+                hidden_states = self.get_input_embeddings(input_ids)
+            residual = None
+        else:
+            assert intermediate_tensors is not None
+            hidden_states = intermediate_tensors["hidden_states"]
+            residual = intermediate_tensors["residual"]
+
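+        # Gather cos/sin for every position in the batch once and reshape to
+        # BSNH layout so all decoder layers reuse the same tensors.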
+        cos_sin = self.cos_sin_cache.index_select(0, positions)
+        last_dim = cos_sin.size()[-1]
+        cos, sin = cos_sin.reshape(-1, 2,
+                                   last_dim // 2).repeat(1, 1, 2).chunk(2,
+                                                                        dim=-2)
+        # BSNH
+        cos, sin = cos.view(1, -1, 1, last_dim).contiguous(), sin.view(
+            1, -1, 1, last_dim).contiguous()
+
+        for layer in self.layers[self.start_layer:self.end_layer]:
+            hidden_states, residual = layer(
+                positions,
+                cos,
+                sin,
+                hidden_states,
+                residual,
+            )
+        if not get_pp_group().is_last_rank:
+            return IntermediateTensors({
+                "hidden_states": hidden_states,
+                "residual": residual
+            })
+        hidden_states, _ = self.norm(hidden_states, residual)
+        return hidden_states


class CustomQwen3ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):