
Conversation

@weijinqian0 (Collaborator) commented on Sep 20, 2025

[Feature] qwen3_moe qk norm support multi_stream.

Signed-off-by: weijinqian_v1 <weijinqian@huawei.com>

👋 Hi! Thank you for contributing to the vLLM Ascend project. The following points will speed up your PR merge:

  • A PR should do only one thing; smaller PRs enable faster reviews.
  • Every PR should include unit tests and end-to-end tests to ensure it works and is not broken by future PRs.
  • Write the commit message by fulfilling the PR description to help reviewers and future developers understand.

If CI fails, you can run linting and testing checks locally according to Contributing and Testing.

@gemini-code-assist bot (Contributor) left a comment


Code Review

This pull request introduces a CustomQwen3MoeAttention class that implements multi-stream execution for QK normalization as a performance optimization. While the multi-stream logic itself is sound, the forward method of the new attention class has a critical flaw: it changes the method signature by removing kv_cache and attn_metadata, which will cause a runtime crash and disable the essential KV caching mechanism. My review provides a fix for this critical issue to ensure correctness and prevent a severe performance regression.

Comment on lines +153 to +178
def forward(
    self,
    positions: torch.Tensor,
    hidden_states: torch.Tensor,
) -> torch.Tensor:
    qkv, _ = self.qkv_proj(hidden_states)
    q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)

    self.alt_stream.wait_stream(torch.npu.current_stream())
    with npu_stream_switch(self.alt_stream):
        # Add qk-norm
        q_by_head = q.view(*q.shape[:-1], q.shape[-1] // self.head_dim,
                           self.head_dim)
        q_by_head = self.q_norm(q_by_head)
        q = q_by_head.view(q.shape)

    k_by_head = k.view(*k.shape[:-1], k.shape[-1] // self.head_dim,
                       self.head_dim)
    k_by_head = self.k_norm(k_by_head)
    k = k_by_head.view(k.shape)

    torch.npu.current_stream().wait_stream(self.alt_stream)
    q, k = self.rotary_emb(positions, q, k)
    attn_output = self.attn(q, k, v)
    output, _ = self.o_proj(attn_output)
    return output
@gemini-code-assist bot (Contributor)

critical

The forward method has an incorrect signature and implementation, which will lead to a runtime error and break KV caching.

  1. Incorrect Signature: The method signature is missing kv_cache: torch.Tensor and attn_metadata: AttentionMetadata parameters. These are required by the underlying attention mechanism for caching.
  2. Incorrect self.attn call: The call self.attn(q, k, v) on line 176 is missing arguments. The vllm.model_executor.layers.attention.Attention layer (self.attn) expects kv_cache and attn_metadata. This will cause a TypeError at runtime.
  3. Disabled KV Caching: Removing these parameters disables KV caching, which is a critical performance feature for LLM serving. It would force re-computation of attention over the entire sequence for each token, leading to a severe performance degradation during decoding.

To fix this, the forward method signature must be corrected to accept kv_cache and attn_metadata, and these must be passed to the self.attn call. You will also need to ensure AttentionMetadata is imported, e.g., from vllm.model_executor.layers.attention import AttentionMetadata.

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: "AttentionMetadata",
    ) -> torch.Tensor:
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)

        self.alt_stream.wait_stream(torch.npu.current_stream())
        with npu_stream_switch(self.alt_stream):
            # Add qk-norm
            q_by_head = q.view(*q.shape[:-1], q.shape[-1] // self.head_dim,
                               self.head_dim)
            q_by_head = self.q_norm(q_by_head)
            q = q_by_head.view(q.shape)

        k_by_head = k.view(*k.shape[:-1], k.shape[-1] // self.head_dim,
                           self.head_dim)
        k_by_head = self.k_norm(k_by_head)
        k = k_by_head.view(k.shape)

        torch.npu.current_stream().wait_stream(self.alt_stream)
        q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
        output, _ = self.o_proj(attn_output)
        return output
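
For readers less familiar with the synchronization pattern in the snippets above, here is a minimal standalone sketch of the same dual-stream overlap. It is not part of the PR: it uses torch.cuda streams as a stand-in for torch.npu (the PR itself relies on torch.npu.current_stream(), Stream.wait_stream(), and vllm-ascend's npu_stream_switch context manager, which plays the role of torch.cuda.stream here), and the function name and arguments are illustrative only.

# Minimal sketch of the dual-stream q/k-norm overlap, assuming torch_npu's
# stream API mirrors torch.cuda's. Illustrative only; not the PR's code.
import torch


def dual_stream_qk_norm(q, k, q_norm, k_norm, head_dim,
                        alt_stream: torch.cuda.Stream):
    # The alternate stream must see all work queued so far on the default
    # stream (e.g. the qkv projection that produced q and k).
    alt_stream.wait_stream(torch.cuda.current_stream())

    with torch.cuda.stream(alt_stream):
        # q-norm runs on the alternate stream ...
        q_by_head = q.view(*q.shape[:-1], q.shape[-1] // head_dim, head_dim)
        q = q_norm(q_by_head).view(q.shape)

    # ... while k-norm runs concurrently on the default stream.
    k_by_head = k.view(*k.shape[:-1], k.shape[-1] // head_dim, head_dim)
    k = k_norm(k_by_head).view(k.shape)

    # Downstream ops must not consume q until the alternate stream has
    # finished producing it.
    torch.cuda.current_stream().wait_stream(alt_stream)
    return q, k

The two wait_stream() calls are what make the overlap safe: the first orders the alternate stream after the projection that produced q and k, and the second prevents the downstream rotary embedding and attention from reading q before the alternate stream has finished.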


This pull request has conflicts; please resolve them before we can evaluate the pull request.
