Skip to content

Commit e3ede4c

Browse files
author
weijinqian_v1
committed
[Feature] qwen3_moe: support multi-stream execution for qk norm.
Signed-off-by: weijinqian_v1 <weijinqian@huawei.com>
1 parent 2f1dbe5 commit e3ede4c

File tree

2 files changed

+22
-5
lines changed

2 files changed

+22
-5
lines changed

vllm_ascend/models/qwen3.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828

2929
from vllm_ascend import envs
3030
from vllm_ascend.ops.layernorm import AddRMSNormW8A8Quant
31+
from vllm_ascend.utils import npu_stream_switch
3132

3233

3334
def pad(tensor, x):
@@ -130,6 +131,7 @@ def __init__(self,
130131
self.tp_size = get_tensor_model_parallel_world_size()
131132
self.tp_rank = get_tensor_model_parallel_rank()
132133
self.enable_fc = envs.VLLM_ASCEND_ENABLE_FLASHCOMM
134+
self.alt_stream = torch.npu.Stream()
133135
if self.enable_fc == 2:
134136
self.o_proj = ReplicatedLinear(
135137
self.total_num_heads * self.head_dim,
@@ -156,15 +158,17 @@ def forward(
156158
) -> torch.Tensor:
157159
qkv, _ = self.qkv_proj(hidden_states)
158160
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
159-
# Add qk-norm
160-
q_by_head = q.view(*q.shape[:-1], q.shape[-1] // self.head_dim,
161-
self.head_dim)
162-
q_by_head = self.q_norm(q_by_head)
163-
q = q_by_head.view(q.shape)
161+
with npu_stream_switch(self.alt_stream):
162+
# Add qk-norm
163+
q_by_head = q.view(*q.shape[:-1], q.shape[-1] // self.head_dim,
164+
self.head_dim)
165+
q_by_head = self.q_norm(q_by_head)
166+
q = q_by_head.view(q.shape)
164167
k_by_head = k.view(*k.shape[:-1], k.shape[-1] // self.head_dim,
165168
self.head_dim)
166169
k_by_head = self.k_norm(k_by_head)
167170
k = k_by_head.view(k.shape)
171+
torch.npu.current_stream().wait_stream(self.alt_stream)
168172
q, k = self.rotary_emb(positions,
169173
q,
170174
k,

vllm_ascend/utils.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -454,3 +454,16 @@ def delete_torchair_cache_file():
454454
shutil.rmtree(torch_air_abs_path)
455455
except FileNotFoundError:
456456
pass
def npu_stream_switch(target_stream: torch.npu.Stream,
                      *,
                      enabled: bool = True):
    """Return a context manager that switches execution to *target_stream*.

    Args:
        target_stream: NPU stream to make current inside the ``with`` block.
            Must not be ``None`` when ``enabled`` is True.
        enabled: When False, no stream switch happens and a no-op context
            manager is returned instead.

    Returns:
        ``torch.npu.stream(target_stream)`` when enabled, otherwise
        ``contextlib.nullcontext()``.

    Raises:
        ValueError: If ``enabled`` is True and ``target_stream`` is None.
    """
    if not enabled:
        return nullcontext()
    # Explicit raise instead of `assert`: asserts are stripped under
    # `python -O`, which would silently pass None into torch.npu.stream.
    if target_stream is None:
        raise ValueError(
            "target_stream must not be None when enabled=True")
    return torch.npu.stream(target_stream)

0 commit comments

Comments
 (0)