Commit 80564c7

Author: weijinqian_v1 (committed)

[Feature] qwen3_moe qk norm support multi_stream.
Signed-off-by: weijinqian_v1 <weijinqian@huawei.com>

Parent: 53ecd89

File tree: 1 file changed, +55 −2 lines

vllm_ascend/models/qwen3_moe.py

Lines changed: 55 additions & 2 deletions
@@ -17,7 +17,7 @@
 # Adapted from vllm/model_executor/models/qwen3_moe.py
 # This file is a part of the vllm-ascend project.
 
-from typing import Optional, Union
+from typing import Any, Optional, Union
 
 import torch
 from torch import nn
@@ -50,6 +50,7 @@
 from vllm_ascend.ops.fused_moe import AscendFusedMoE
 from vllm_ascend.ops.sequence_parallel import (MetadataForPadding,
                                                init_metadata_for_sp)
+from vllm_ascend.utils import npu_stream_switch
 
 
 class CustomSparseMoeBlock(Qwen3MoeSparseMoeBlock):
@@ -125,6 +126,58 @@ def forward(
         return hidden_states
 
 
+class CustomQwen3MoeAttention(Qwen3MoeAttention):
+
+    def __init__(
+        self,
+        hidden_size: int,
+        num_heads: int,
+        num_kv_heads: int,
+        rope_theta: float = 10000,
+        rope_scaling: Optional[dict[str, Any]] = None,
+        max_position_embeddings: int = 8192,
+        head_dim: Optional[int] = None,
+        rms_norm_eps: float = 1e-06,
+        qkv_bias: bool = False,
+        cache_config: Optional[CacheConfig] = None,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+        dual_chunk_attention_config: Optional[dict[str, Any]] = None,
+    ) -> None:
+        super().__init__(hidden_size, num_heads, num_kv_heads, rope_theta,
+                         rope_scaling, max_position_embeddings, head_dim,
+                         rms_norm_eps, qkv_bias, cache_config, quant_config,
+                         prefix, dual_chunk_attention_config)
+        self.alt_stream = torch.npu.Stream()
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+    ) -> torch.Tensor:
+        qkv, _ = self.qkv_proj(hidden_states)
+        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+
+        self.alt_stream.wait_stream(torch.npu.current_stream())
+        with npu_stream_switch(self.alt_stream):
+            # Add qk-norm
+            q_by_head = q.view(*q.shape[:-1], q.shape[-1] // self.head_dim,
+                               self.head_dim)
+            q_by_head = self.q_norm(q_by_head)
+            q = q_by_head.view(q.shape)
+
+        k_by_head = k.view(*k.shape[:-1], k.shape[-1] // self.head_dim,
+                           self.head_dim)
+        k_by_head = self.k_norm(k_by_head)
+        k = k_by_head.view(k.shape)
+
+        torch.npu.current_stream().wait_stream(self.alt_stream)
+        q, k = self.rotary_emb(positions, q, k)
+        attn_output = self.attn(q, k, v)
+        output, _ = self.o_proj(attn_output)
+        return output
+
+
 class CustomQwen3MoeDecoderLayer(Qwen3MoeDecoderLayer):
 
     def __init__(
@@ -142,7 +195,7 @@ def __init__(
         rope_scaling = getattr(config, "rope_scaling", None)
         max_position_embeddings = getattr(config, "max_position_embeddings",
                                           8192)
-        self.self_attn = Qwen3MoeAttention(
+        self.self_attn = CustomQwen3MoeAttention(
             hidden_size=self.hidden_size,
             num_heads=config.num_attention_heads,
             num_kv_heads=config.num_key_value_heads,
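
For context, the new forward follows the standard two-stream overlap idiom: the side stream waits for the QKV projection, runs the q-norm under npu_stream_switch, the default stream runs the k-norm in parallel, and a final wait_stream resynchronizes the default stream before RoPE consumes q. The sketch below illustrates that idiom only; it is not part of the diff, it uses torch.cuda streams as a stand-in for torch.npu (torch_npu mirrors the torch.cuda stream interface), and every name in it is illustrative rather than taken from vllm-ascend.

import torch
from torch import nn


def overlapped_qk_norm(
    q: torch.Tensor,
    k: torch.Tensor,
    q_norm: nn.Module,
    k_norm: nn.Module,
    alt_stream: torch.cuda.Stream,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Illustrative sketch of the stream-overlap pattern; not the commit's code."""
    # The side stream must not start until the work that produced q and k
    # (here, the QKV projection) has been enqueued on the default stream.
    alt_stream.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(alt_stream):
        q = q_norm(q)  # q-norm runs on the side stream
    k = k_norm(k)      # k-norm runs concurrently on the default stream
    # The default stream must not consume q (e.g. in RoPE) until the side
    # stream has finished the q-norm.
    torch.cuda.current_stream().wait_stream(alt_stream)
    return q, k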
