57 changes: 55 additions & 2 deletions vllm_ascend/models/qwen3_moe.py
@@ -17,7 +17,7 @@
# Adapted from vllm/model_executor/models/qwen3_moe.py
# This file is a part of the vllm-ascend project.

-from typing import Optional, Union
+from typing import Any, Optional, Union

import torch
from torch import nn
@@ -50,6 +50,7 @@
from vllm_ascend.ops.fused_moe import AscendFusedMoE
from vllm_ascend.ops.sequence_parallel import (MetadataForPadding,
                                               init_metadata_for_sp)
+from vllm_ascend.utils import npu_stream_switch


class CustomSparseMoeBlock(Qwen3MoeSparseMoeBlock):
@@ -125,6 +126,58 @@ def forward(
        return hidden_states


+class CustomQwen3MoeAttention(Qwen3MoeAttention):
+
+    def __init__(
+        self,
+        hidden_size: int,
+        num_heads: int,
+        num_kv_heads: int,
+        rope_theta: float = 10000,
+        rope_scaling: Optional[dict[str, Any]] = None,
+        max_position_embeddings: int = 8192,
+        head_dim: Optional[int] = None,
+        rms_norm_eps: float = 1e-06,
+        qkv_bias: bool = False,
+        cache_config: Optional[CacheConfig] = None,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+        dual_chunk_attention_config: Optional[dict[str, Any]] = None,
+    ) -> None:
+        super().__init__(hidden_size, num_heads, num_kv_heads, rope_theta,
+                         rope_scaling, max_position_embeddings, head_dim,
+                         rms_norm_eps, qkv_bias, cache_config, quant_config,
+                         prefix, dual_chunk_attention_config)
+        self.alt_stream = torch.npu.Stream()
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+    ) -> torch.Tensor:
+        qkv, _ = self.qkv_proj(hidden_states)
+        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+
+        self.alt_stream.wait_stream(torch.npu.current_stream())
+        with npu_stream_switch(self.alt_stream):
+            # Add qk-norm
+            q_by_head = q.view(*q.shape[:-1], q.shape[-1] // self.head_dim,
+                               self.head_dim)
+            q_by_head = self.q_norm(q_by_head)
+            q = q_by_head.view(q.shape)
+
+        k_by_head = k.view(*k.shape[:-1], k.shape[-1] // self.head_dim,
+                           self.head_dim)
+        k_by_head = self.k_norm(k_by_head)
+        k = k_by_head.view(k.shape)
+
+        torch.npu.current_stream().wait_stream(self.alt_stream)
+        q, k = self.rotary_emb(positions, q, k)
+        attn_output = self.attn(q, k, v)
+        output, _ = self.o_proj(attn_output)
+        return output
Comment on lines +153 to +178
Contributor

critical

The forward method has an incorrect signature and implementation, which will lead to a runtime error and break KV caching.

  1. Incorrect Signature: The method signature is missing kv_cache: torch.Tensor and attn_metadata: AttentionMetadata parameters. These are required by the underlying attention mechanism for caching.
  2. Incorrect self.attn call: The call self.attn(q, k, v) on line 176 is missing arguments. The vllm.model_executor.layers.attention.Attention layer (self.attn) expects kv_cache and attn_metadata. This will cause a TypeError at runtime.
  3. Disabled KV Caching: Removing these parameters disables KV caching, which is a critical performance feature for LLM serving. It would force re-computation of attention over the entire sequence for each token, leading to a severe performance degradation during decoding.

To fix this, the forward method signature must be corrected to accept kv_cache and attn_metadata, and these must be passed to the self.attn call. You will also need to ensure AttentionMetadata is imported, e.g., from vllm.model_executor.layers.attention import AttentionMetadata.

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: "AttentionMetadata",
    ) -> torch.Tensor:
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)

        self.alt_stream.wait_stream(torch.npu.current_stream())
        with npu_stream_switch(self.alt_stream):
            # Add qk-norm
            q_by_head = q.view(*q.shape[:-1], q.shape[-1] // self.head_dim,
                               self.head_dim)
            q_by_head = self.q_norm(q_by_head)
            q = q_by_head.view(q.shape)

        k_by_head = k.view(*k.shape[:-1], k.shape[-1] // self.head_dim,
                           self.head_dim)
        k_by_head = self.k_norm(k_by_head)
        k = k_by_head.view(k.shape)

        torch.npu.current_stream().wait_stream(self.alt_stream)
        q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
        output, _ = self.o_proj(attn_output)
        return output
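
For context on the dual-stream pattern in the new forward: the q-norm is issued on self.alt_stream while the k-norm runs on the default stream, so the two normalizations can overlap, with a wait_stream barrier on each side to keep ordering correct. The sketch below isolates just that synchronization logic. It is a minimal illustration, not code from this PR: the helper name overlapped_qk_norm is made up, the per-head reshape around the norms is elided, and it assumes npu_stream_switch(stream) simply makes the given stream current inside the block (analogous to torch.cuda.stream).

import torch
import torch_npu  # noqa: F401  # assumed: importing torch_npu registers the torch.npu stream API
from vllm_ascend.utils import npu_stream_switch


def overlapped_qk_norm(q, k, q_norm, k_norm, alt_stream):
    # The side stream must not read q until the projection that produced it
    # on the default stream is ordered ahead of it.
    alt_stream.wait_stream(torch.npu.current_stream())
    with npu_stream_switch(alt_stream):
        # q-norm is enqueued on the side stream ...
        q = q_norm(q)
    # ... while k-norm is enqueued on the default stream, so the two
    # normalizations can run concurrently.
    k = k_norm(k)
    # The default stream must not consume q (rotary embedding, attention)
    # until the side-stream work is ordered ahead of it.
    torch.npu.current_stream().wait_stream(alt_stream)
    return q, k

Dropping either wait_stream would let one stream race ahead of the other: the first barrier protects the side stream's read of q, the second keeps the rotary embedding and attention from reading a q that has not been normalized yet.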



class CustomQwen3MoeDecoderLayer(Qwen3MoeDecoderLayer):

    def __init__(
@@ -142,7 +195,7 @@ def __init__(
        rope_scaling = getattr(config, "rope_scaling", None)
        max_position_embeddings = getattr(config, "max_position_embeddings",
                                          8192)
-        self.self_attn = Qwen3MoeAttention(
+        self.self_attn = CustomQwen3MoeAttention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=config.num_key_value_heads,
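
A note on the q_by_head / k_by_head reshape used in CustomQwen3MoeAttention.forward: the projected q and k arrive flattened as [num_tokens, num_heads * head_dim], and the view exposes a head axis so the RMSNorm statistics are computed over head_dim for each head independently before flattening back. A minimal stand-alone illustration, with made-up sizes and a plain-PyTorch, weightless RMSNorm standing in for the layer's q_norm:

import torch

# Hypothetical sizes, for illustration only.
num_tokens, num_heads, head_dim = 4, 8, 128
rms_norm_eps = 1e-6

q = torch.randn(num_tokens, num_heads * head_dim)

# Expose a head axis so normalization statistics are per head, not over the
# whole flattened hidden dimension.
q_by_head = q.view(*q.shape[:-1], q.shape[-1] // head_dim, head_dim)

# Weightless RMSNorm stand-in for self.q_norm (the real layer also applies a
# learned scale).
variance = q_by_head.pow(2).mean(dim=-1, keepdim=True)
q_by_head = q_by_head * torch.rsqrt(variance + rms_norm_eps)

# Flatten back to the layout expected by the rotary embedding and the
# attention kernel.
q = q_by_head.view(num_tokens, num_heads * head_dim)
assert q.shape == (num_tokens, num_heads * head_dim)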