import torch
from torch import nn
from transformers import Qwen3Config
+ from vllm.attention import AttentionType
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group
+ from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.models.interfaces import SupportsLoRA, SupportsPP
from vllm.model_executor.models.qwen2 import Qwen2Model
- from vllm.model_executor.models.qwen3 import Qwen3DecoderLayer
+ from vllm.model_executor.models.qwen3 import Qwen3Attention, Qwen3MLP
from vllm.model_executor.models.utils import (AutoWeightsLoader,
                                              PPMissingLayer, maybe_prefix)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm_ascend.ops.layernorm import AddRMSNormW8A8Quant


- class CustomQwen3DecoderLayer(Qwen3DecoderLayer):
+ class CustomQwen3Attention(Qwen3Attention):
+
+     def __init__(self,
+                  hidden_size: int,
+                  num_heads: int,
+                  num_kv_heads: int,
+                  max_position: int = 4096 * 32,
+                  head_dim: Optional[int] = None,
+                  rms_norm_eps: float = 1e-06,
+                  qkv_bias: bool = False,
+                  rope_theta: float = 10000,
+                  cache_config: Optional[CacheConfig] = None,
+                  quant_config: Optional[QuantizationConfig] = None,
+                  rope_scaling: Optional[tuple] = None,
+                  prefix: str = "",
+                  attn_type: str = AttentionType.DECODER) -> None:
+         super().__init__(hidden_size=hidden_size,
+                          num_heads=num_heads,
+                          num_kv_heads=num_kv_heads,
+                          max_position=max_position,
+                          head_dim=head_dim,
+                          rms_norm_eps=rms_norm_eps,
+                          qkv_bias=qkv_bias,
+                          rope_theta=rope_theta,
+                          cache_config=cache_config,
+                          quant_config=quant_config,
+                          rope_scaling=rope_scaling,
+                          prefix=prefix,
+                          attn_type=attn_type)
+
+     def forward(
+         self,
+         positions: torch.Tensor,
+         cos: torch.Tensor,
+         sin: torch.Tensor,
+         hidden_states: torch.Tensor,
+     ) -> torch.Tensor:
+         qkv, _ = self.qkv_proj(hidden_states)
+         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+         # Add qk-norm
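+         # q/k are viewed as (..., num_heads, head_dim) so q_norm/k_norm apply
+         # RMSNorm per attention head, then flattened back to the packed shape.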
+         q_by_head = q.view(*q.shape[:-1], q.shape[-1] // self.head_dim,
+                            self.head_dim)
+         q_by_head = self.q_norm(q_by_head)
+         q = q_by_head.view(q.shape)
+         k_by_head = k.view(*k.shape[:-1], k.shape[-1] // self.head_dim,
+                            self.head_dim)
+         k_by_head = self.k_norm(k_by_head)
+         k = k_by_head.view(k.shape)
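+         # cos/sin are computed once per step in the model forward and shared by
+         # all layers; skip_index_select=True is assumed to let the Ascend rotary
+         # op consume them directly instead of re-gathering from its own cache.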
+         q, k = self.rotary_emb(positions,
+                                q,
+                                k,
+                                cos=cos,
+                                sin=sin,
+                                skip_index_select=True)
+         attn_output = self.attn(q, k, v)
+         output, _ = self.o_proj(attn_output)
+         return output
+
+
+ class CustomQwen3DecoderLayer(nn.Module):

    def __init__(
        self,
@@ -30,31 +91,99 @@ def __init__(
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
-         super().__init__(config=config,
-                          cache_config=cache_config,
-                          quant_config=quant_config,
-                          prefix=prefix)
+         super().__init__()
+         self.hidden_size = config.hidden_size
+         # Requires transformers > 4.32.0
+         rope_theta = getattr(config, "rope_theta", 1000000)
+         rope_scaling = getattr(config, "rope_scaling", None)
+
+         # By default, Qwen3 uses causal attention as it is a decoder-only model.
+         # You can override the HF config with `is_causal=False` to enable
+         # bidirectional attention, which is used in some embedding models
+         # (e.g. Alibaba-NLP/gte-Qwen3-7B-instruct)
+         if getattr(config, "is_causal", True):
+             attn_type = AttentionType.DECODER
+         else:
+             attn_type = AttentionType.ENCODER_ONLY
+
+         self.self_attn = CustomQwen3Attention(
+             hidden_size=self.hidden_size,
+             num_heads=config.num_attention_heads,
+             max_position=config.max_position_embeddings,
+             num_kv_heads=config.num_key_value_heads,
+             rope_theta=rope_theta,
+             rms_norm_eps=config.rms_norm_eps,
+             qkv_bias=getattr(config, 'attention_bias', False),
+             head_dim=getattr(config, 'head_dim', None),
+             cache_config=cache_config,
+             quant_config=quant_config,
+             rope_scaling=rope_scaling,
+             prefix=f"{prefix}.self_attn",
+             attn_type=attn_type,
+         )
+         self.mlp = Qwen3MLP(
+             hidden_size=self.hidden_size,
+             intermediate_size=config.intermediate_size,
+             hidden_act=config.hidden_act,
+             quant_config=quant_config,
+             prefix=f"{prefix}.mlp",
+         )
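+         # Norm selection (assumption): AddRMSNormW8A8Quant fuses the residual
+         # add, RMSNorm and the W8A8 input quantization of the following linear
+         # layer (hence the `layer=` argument below); plain vLLM RMSNorm is used
+         # when no quantization config is present.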
        if quant_config is None:
-             return
+             self.input_layernorm = RMSNorm(config.hidden_size,
+                                            eps=config.rms_norm_eps)
+             self.post_attention_layernorm = RMSNorm(config.hidden_size,
+                                                     eps=config.rms_norm_eps)
+         else:
+             from vllm_ascend.quantization.quant_config import AscendQuantConfig
+             from vllm_ascend.quantization.w8a8 import AscendW8A8LinearMethod

-         from vllm_ascend.quantization.quant_config import AscendQuantConfig
-         from vllm_ascend.quantization.w8a8 import AscendW8A8LinearMethod
+             assert isinstance(quant_config, AscendQuantConfig), \
+                 "Expected quant_config to be an instance of AscendQuantConfig"

-         assert isinstance(quant_config, AscendQuantConfig), \
-             "Expected quant_config to be an instance of AscendQuantConfig"
+             if isinstance(self.self_attn.qkv_proj.quant_method.quant_method,
+                           AscendW8A8LinearMethod):
+                 self.input_layernorm = AddRMSNormW8A8Quant(
+                     config.hidden_size,
+                     layer=self.self_attn.qkv_proj,
+                     eps=config.rms_norm_eps)
+             else:
+                 self.input_layernorm = RMSNorm(config.hidden_size,
+                                                eps=config.rms_norm_eps)
+             if isinstance(self.mlp.gate_up_proj.quant_method.quant_method,
+                           AscendW8A8LinearMethod):
+                 self.post_attention_layernorm = AddRMSNormW8A8Quant(
+                     config.hidden_size,
+                     layer=self.mlp.gate_up_proj,
+                     eps=config.rms_norm_eps)
+             else:
+                 self.post_attention_layernorm = RMSNorm(config.hidden_size,
+                                                         eps=config.rms_norm_eps)

-         if isinstance(self.self_attn.qkv_proj.quant_method.quant_method,
-                       AscendW8A8LinearMethod):
-             self.input_layernorm = AddRMSNormW8A8Quant(
-                 config.hidden_size,
-                 layer=self.self_attn.qkv_proj,
-                 eps=config.rms_norm_eps)
-         if isinstance(self.mlp.gate_up_proj.quant_method.quant_method,
-                       AscendW8A8LinearMethod):
-             self.post_attention_layernorm = AddRMSNormW8A8Quant(
-                 config.hidden_size,
-                 layer=self.mlp.gate_up_proj,
-                 eps=config.rms_norm_eps)
+     def forward(
+         self,
+         positions: torch.Tensor,
+         cos: torch.Tensor,
+         sin: torch.Tensor,
+         hidden_states: torch.Tensor,
+         residual: Optional[torch.Tensor],
+     ) -> tuple[torch.Tensor, torch.Tensor]:
+         # Self Attention
+         if residual is None:
+             residual = hidden_states
+             hidden_states = self.input_layernorm(hidden_states)
+         else:
+             hidden_states, residual = self.input_layernorm(
+                 hidden_states, residual)
+         hidden_states = self.self_attn(
+             positions=positions,
+             cos=cos,
+             sin=sin,
+             hidden_states=hidden_states,
+         )
+         hidden_states, residual = self.post_attention_layernorm(
+             hidden_states, residual)
+         hidden_states = self.mlp(hidden_states)
+         return hidden_states, residual


ALL_DECODER_LAYER_TYPES = {
@@ -77,6 +206,50 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__(vllm_config=vllm_config,
                         prefix=prefix,
                         decoder_layer_type=CustomQwen3DecoderLayer)
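+         # All decoder layers are built from the same config, so they are assumed
+         # to share identical rotary parameters; the first layer's cos/sin cache
+         # is therefore reused for the whole model.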
+         self.cos_sin_cache = self.layers[0].self_attn.rotary_emb.cos_sin_cache
+
+     def forward(
+         self,
+         input_ids: torch.Tensor,
+         positions: torch.Tensor,
+         intermediate_tensors: Optional[IntermediateTensors] = None,
+         inputs_embeds: Optional[torch.Tensor] = None,
+     ) -> Union[torch.Tensor, IntermediateTensors]:
+         if get_pp_group().is_first_rank:
+             if inputs_embeds is not None:
+                 hidden_states = inputs_embeds
+             else:
+                 hidden_states = self.get_input_embeddings(input_ids)
+             residual = None
+         else:
+             assert intermediate_tensors is not None
+             hidden_states = intermediate_tensors["hidden_states"]
+             residual = intermediate_tensors["residual"]
+
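+         # Precompute rotary cos/sin once per step: each cache row holds the cos
+         # and sin halves (rotary_dim // 2 values each), so split them, repeat to
+         # the full width, and view in BSNH layout (1, num_tokens, 1, rotary_dim)
+         # so every decoder layer can reuse the same tensors.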
+         cos_sin = self.cos_sin_cache.index_select(0, positions)
+         last_dim = cos_sin.size()[-1]
+         cos, sin = cos_sin.reshape(-1, 2,
+                                    last_dim // 2).repeat(1, 1, 2).chunk(2,
+                                                                         dim=-2)
+         # BSNH
+         cos, sin = cos.view(1, -1, 1, last_dim).contiguous(), sin.view(
+             1, -1, 1, last_dim).contiguous()
+
+         for layer in self.layers[self.start_layer:self.end_layer]:
+             hidden_states, residual = layer(
+                 positions,
+                 cos,
+                 sin,
+                 hidden_states,
+                 residual,
+             )
+         if not get_pp_group().is_last_rank:
+             return IntermediateTensors({
+                 "hidden_states": hidden_states,
+                 "residual": residual
+             })
+         hidden_states, _ = self.norm(hidden_states, residual)
+         return hidden_states


class CustomQwen3ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):