@@ -17,7 +17,7 @@
# Adapted from vllm/model_executor/models/qwen3_moe.py
# This file is a part of the vllm-ascend project.

- from typing import Optional, Union
+ from typing import Any, Optional, Union

import torch
from torch import nn
@@ -50,6 +50,7 @@
from vllm_ascend.ops.fused_moe import AscendFusedMoE
from vllm_ascend.ops.sequence_parallel import (MetadataForPadding,
                                               init_metadata_for_sp)
+ from vllm_ascend.utils import npu_stream_switch


class CustomSparseMoeBlock(Qwen3MoeSparseMoeBlock):
@@ -125,6 +126,58 @@ def forward(
        return hidden_states


+ class CustomQwen3MoeAttention(Qwen3MoeAttention):
+
+     def __init__(
+         self,
+         hidden_size: int,
+         num_heads: int,
+         num_kv_heads: int,
+         rope_theta: float = 10000,
+         rope_scaling: Optional[dict[str, Any]] = None,
+         max_position_embeddings: int = 8192,
+         head_dim: Optional[int] = None,
+         rms_norm_eps: float = 1e-06,
+         qkv_bias: bool = False,
+         cache_config: Optional[CacheConfig] = None,
+         quant_config: Optional[QuantizationConfig] = None,
+         prefix: str = "",
+         dual_chunk_attention_config: Optional[dict[str, Any]] = None,
+     ) -> None:
+         super().__init__(hidden_size, num_heads, num_kv_heads, rope_theta,
+                          rope_scaling, max_position_embeddings, head_dim,
+                          rms_norm_eps, qkv_bias, cache_config, quant_config,
+                          prefix, dual_chunk_attention_config)
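+         # Secondary NPU stream used in forward() to overlap the q-norm
+         # with the k-norm.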
+         self.alt_stream = torch.npu.Stream()
+
+     def forward(
+         self,
+         positions: torch.Tensor,
+         hidden_states: torch.Tensor,
+     ) -> torch.Tensor:
+         qkv, _ = self.qkv_proj(hidden_states)
+         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+
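+         # Compute the q-norm on alt_stream while the k-norm below runs on
+         # the default stream, so the two RMSNorm kernels can overlap.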
+         self.alt_stream.wait_stream(torch.npu.current_stream())
+         with npu_stream_switch(self.alt_stream):
+             # Add qk-norm
+             q_by_head = q.view(*q.shape[:-1], q.shape[-1] // self.head_dim,
+                                self.head_dim)
+             q_by_head = self.q_norm(q_by_head)
+             q = q_by_head.view(q.shape)
+
+         k_by_head = k.view(*k.shape[:-1], k.shape[-1] // self.head_dim,
+                            self.head_dim)
+         k_by_head = self.k_norm(k_by_head)
+         k = k_by_head.view(k.shape)
+
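+         # The default stream waits for the q-norm on alt_stream before q is
+         # consumed by RoPE and attention.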
+         torch.npu.current_stream().wait_stream(self.alt_stream)
+         q, k = self.rotary_emb(positions, q, k)
+         attn_output = self.attn(q, k, v)
+         output, _ = self.o_proj(attn_output)
+         return output
+
+

class CustomQwen3MoeDecoderLayer(Qwen3MoeDecoderLayer):

    def __init__(
@@ -142,7 +195,7 @@ def __init__(
        rope_scaling = getattr(config, "rope_scaling", None)
        max_position_embeddings = getattr(config, "max_position_embeddings",
                                          8192)
-         self.self_attn = Qwen3MoeAttention(
+         self.self_attn = CustomQwen3MoeAttention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=config.num_key_value_heads,