[Feature] qwen3_moe qk norm support multi_stream. #3059
base: main
Conversation
Signed-off-by: weijinqian_v1 <weijinqian@huawei.com>
👋 Hi! Thank you for contributing to the vLLM Ascend project. The following points will speed up your PR merge:
If CI fails, you can run linting and testing checks locally according to Contributing and Testing.
Code Review
This pull request introduces a `CustomQwen3MoeAttention` class to implement multi-stream execution for QK normalization, aiming for a performance optimization. While the multi-stream logic itself is sound, the implementation has a critical flaw in the `forward` method of the new attention class. It incorrectly changes the method signature by removing `kv_cache` and `attn_metadata`, which will cause a runtime crash and disables the essential KV caching mechanism. My review provides a fix for this critical issue to ensure correctness and prevent a severe performance regression.
```python
def forward(
    self,
    positions: torch.Tensor,
    hidden_states: torch.Tensor,
) -> torch.Tensor:
    qkv, _ = self.qkv_proj(hidden_states)
    q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)

    self.alt_stream.wait_stream(torch.npu.current_stream())
    with npu_stream_switch(self.alt_stream):
        # Add qk-norm
        q_by_head = q.view(*q.shape[:-1], q.shape[-1] // self.head_dim,
                           self.head_dim)
        q_by_head = self.q_norm(q_by_head)
        q = q_by_head.view(q.shape)

        k_by_head = k.view(*k.shape[:-1], k.shape[-1] // self.head_dim,
                           self.head_dim)
        k_by_head = self.k_norm(k_by_head)
        k = k_by_head.view(k.shape)

    torch.npu.current_stream().wait_stream(self.alt_stream)
    q, k = self.rotary_emb(positions, q, k)
    attn_output = self.attn(q, k, v)
    output, _ = self.o_proj(attn_output)
    return output
```
The `forward` method has an incorrect signature and implementation, which will lead to a runtime error and break KV caching.

- **Incorrect signature:** The method signature is missing the `kv_cache: torch.Tensor` and `attn_metadata: AttentionMetadata` parameters. These are required by the underlying attention mechanism for caching.
- **Incorrect `self.attn` call:** The call `self.attn(q, k, v)` on line 176 is missing arguments. The `vllm.model_executor.layers.attention.Attention` layer (`self.attn`) expects `kv_cache` and `attn_metadata`. This will cause a `TypeError` at runtime.
- **Disabled KV caching:** Removing these parameters disables KV caching, a critical performance feature for LLM serving. It would force re-computation of attention over the entire sequence for each token, leading to severe performance degradation during decoding (see the toy sketch after this list).
To fix this, the `forward` method signature must be corrected to accept `kv_cache` and `attn_metadata`, and these must be passed to the `self.attn` call. You will also need to ensure `AttentionMetadata` is imported, e.g., `from vllm.model_executor.layers.attention import AttentionMetadata`.
```python
def forward(
    self,
    positions: torch.Tensor,
    hidden_states: torch.Tensor,
    kv_cache: torch.Tensor,
    attn_metadata: "AttentionMetadata",
) -> torch.Tensor:
    qkv, _ = self.qkv_proj(hidden_states)
    q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
    self.alt_stream.wait_stream(torch.npu.current_stream())
    with npu_stream_switch(self.alt_stream):
        # Add qk-norm
        q_by_head = q.view(*q.shape[:-1], q.shape[-1] // self.head_dim,
                           self.head_dim)
        q_by_head = self.q_norm(q_by_head)
        q = q_by_head.view(q.shape)
        k_by_head = k.view(*k.shape[:-1], k.shape[-1] // self.head_dim,
                           self.head_dim)
        k_by_head = self.k_norm(k_by_head)
        k = k_by_head.view(k.shape)
    torch.npu.current_stream().wait_stream(self.alt_stream)
    q, k = self.rotary_emb(positions, q, k)
    attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
    output, _ = self.o_proj(attn_output)
    return output
```
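On the multi-stream side, the synchronization in both versions is the standard two-stream handoff: the side stream waits on the current stream before reading its inputs, and the current stream waits on the side stream before consuming its outputs. A minimal standalone sketch of that pattern follows; the helper names are illustrative, and it assumes the `torch.npu` stream API mirrors `torch.cuda`, as the PR's own code suggests.

```python
import torch
import torch_npu  # provides the torch.npu namespace on Ascend devices

def run_with_side_stream(x, main_branch, side_branch, alt_stream):
    # 1) The side stream must not start before x is ready on the current stream.
    alt_stream.wait_stream(torch.npu.current_stream())
    with torch.npu.stream(alt_stream):   # npu_stream_switch plays this role in the PR
        side_out = side_branch(x)        # e.g. the q/k norm work
    main_out = main_branch(x)            # issued on the default stream meanwhile
    # 2) The current stream must not consume side_out before alt_stream finishes.
    torch.npu.current_stream().wait_stream(alt_stream)
    return main_out, side_out
```

In the PR's `forward`, the second `wait_stream` is what guarantees that `rotary_emb` and the attention call only run once the normalized `q` and `k` produced on the alternate stream are complete.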
This pull request has conflicts; please resolve them before we can evaluate the pull request.