[V0.9.1] optimize rope in qwen3 #1719

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits on Jul 14, 2025
67 changes: 67 additions & 0 deletions tests/singlecard/ops/test_rotary_embedding.py
@@ -9,6 +9,7 @@
import pytest
import torch
import torch.nn as nn
import torch_npu

from vllm_ascend.utils import enable_custom_op

@@ -198,3 +199,69 @@ def test_rotary_embedding_quant_with_leading_dim(
ref_key,
atol=DEFAULT_ATOL,
rtol=DEFAULT_RTOL)


# Test npu_apply_rotary_pos_emb with head_size=128, rotary_dim=128, and is_neox_style=True.
@pytest.mark.parametrize("is_neox_style", [True])
@pytest.mark.parametrize("batch_size", BATCH_SIZES)
@pytest.mark.parametrize("seq_len", SEQ_LENS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", [128])
@pytest.mark.parametrize("rotary_dim", [128])
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", DEVICES)
@torch.inference_mode()
def test_npu_apply_rotary_pos_emb_with_head_size_equals_rotary_dim(
is_neox_style: bool,
batch_size: int,
seq_len: int,
num_heads: int,
head_size: int,
rotary_dim: Optional[int],
dtype: torch.dtype,
seed: int,
device: str,
max_position: int = 8192,
base: int = 10000,
) -> None:
if rotary_dim is None:
rotary_dim = head_size

torch.set_default_device(device)
rope = RotaryEmbedding(head_size, rotary_dim, max_position, base,
is_neox_style, dtype)
rope = rope.to(dtype=dtype)
num_tokens = batch_size * seq_len
positions = torch.randint(0, max_position, (batch_size * seq_len, ))
qkv_tensor = torch.randn(1,
num_tokens,
num_heads,
head_size * 3,
dtype=dtype)
query, key, _ = qkv_tensor.split(
[head_size, head_size, head_size],
dim=-1,
)

ref_query, ref_key = rope.forward_native(positions, query, key)
cos_sin = rope.cos_sin_cache.index_select(0, positions)
last_dim = cos_sin.size()[-1]
cos, sin = cos_sin.reshape(-1, 2, last_dim // 2).repeat(1, 1,
2).chunk(2, dim=-2)
    # BSNH layout: (1, num_tokens, 1, last_dim)
cos, sin = cos.view(1, -1, 1, last_dim).contiguous(), sin.view(
1, -1, 1, last_dim).contiguous()
torch_npu.npu_apply_rotary_pos_emb(query, key, cos, sin)

# Compare the results.
torch.testing.assert_close(query.view(ref_query.size()),
ref_query,
atol=DEFAULT_ATOL,
rtol=DEFAULT_RTOL)
torch.testing.assert_close(key.view(ref_key.size()),
ref_key,
atol=DEFAULT_ATOL,
rtol=DEFAULT_RTOL)
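
A shape-only sketch of the cos/sin reshape used in this test, runnable with plain torch on CPU; the sizes and tensor names here are illustrative, not taken from the test itself:

import torch

# Illustrative sizes only; the real rows come from RotaryEmbedding.cos_sin_cache.
num_tokens, rotary_dim = 4, 128
cos_sin = torch.randn(num_tokens, rotary_dim)  # each row: [cos half | sin half]

# Split each row into its cos and sin halves, duplicate each half across the
# full rotary dimension, then lay the result out as BSNH with batch=1, heads=1.
cos, sin = cos_sin.reshape(-1, 2, rotary_dim // 2).repeat(1, 1, 2).chunk(2, dim=-2)
cos = cos.view(1, num_tokens, 1, rotary_dim).contiguous()
sin = sin.view(1, num_tokens, 1, rotary_dim).contiguous()
assert cos.shape == sin.shape == (1, num_tokens, 1, rotary_dim)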
219 changes: 196 additions & 23 deletions vllm_ascend/models/qwen3.py
@@ -4,15 +4,17 @@
import torch
from torch import nn
from transformers import Qwen3Config
from vllm.attention import AttentionType
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.models.interfaces import SupportsLoRA, SupportsPP
from vllm.model_executor.models.qwen2 import Qwen2Model
from vllm.model_executor.models.qwen3 import Qwen3Attention, Qwen3MLP
from vllm.model_executor.models.utils import (AutoWeightsLoader,
PPMissingLayer, maybe_prefix)
from vllm.model_executor.sampling_metadata import SamplingMetadata
@@ -21,7 +23,66 @@
from vllm_ascend.ops.layernorm import AddRMSNormW8A8Quant


class CustomQwen3Attention(Qwen3Attention):

def __init__(self,
hidden_size: int,
num_heads: int,
num_kv_heads: int,
max_position: int = 4096 * 32,
head_dim: Optional[int] = None,
rms_norm_eps: float = 1e-06,
qkv_bias: bool = False,
rope_theta: float = 10000,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
rope_scaling: Optional[tuple] = None,
prefix: str = "",
attn_type: str = AttentionType.DECODER) -> None:
super().__init__(hidden_size=hidden_size,
num_heads=num_heads,
num_kv_heads=num_kv_heads,
max_position=max_position,
head_dim=head_dim,
rms_norm_eps=rms_norm_eps,
qkv_bias=qkv_bias,
rope_theta=rope_theta,
cache_config=cache_config,
quant_config=quant_config,
rope_scaling=rope_scaling,
prefix=prefix,
attn_type=attn_type)

def forward(
self,
positions: torch.Tensor,
cos: torch.Tensor,
sin: torch.Tensor,
hidden_states: torch.Tensor,
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
# Add qk-norm
q_by_head = q.view(*q.shape[:-1], q.shape[-1] // self.head_dim,
self.head_dim)
q_by_head = self.q_norm(q_by_head)
q = q_by_head.view(q.shape)
k_by_head = k.view(*k.shape[:-1], k.shape[-1] // self.head_dim,
self.head_dim)
k_by_head = self.k_norm(k_by_head)
k = k_by_head.view(k.shape)
q, k = self.rotary_emb(positions,
q,
k,
cos=cos,
sin=sin,
skip_index_select=True)
attn_output = self.attn(q, k, v)
output, _ = self.o_proj(attn_output)
return output


class CustomQwen3DecoderLayer(nn.Module):

def __init__(
self,
@@ -30,31 +91,99 @@ def __init__(
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
        super().__init__()
self.hidden_size = config.hidden_size
# Requires transformers > 4.32.0
rope_theta = getattr(config, "rope_theta", 1000000)
rope_scaling = getattr(config, "rope_scaling", None)

# By default, Qwen3 uses causal attention as it is a decoder-only model.
# You can override the HF config with `is_causal=False` to enable
# bidirectional attention, which is used in some embedding models
# (e.g. Alibaba-NLP/gte-Qwen3-7B-instruct)
if getattr(config, "is_causal", True):
attn_type = AttentionType.DECODER
else:
attn_type = AttentionType.ENCODER_ONLY

self.self_attn = CustomQwen3Attention(
hidden_size=self.hidden_size,
num_heads=config.num_attention_heads,
max_position=config.max_position_embeddings,
num_kv_heads=config.num_key_value_heads,
rope_theta=rope_theta,
rms_norm_eps=config.rms_norm_eps,
qkv_bias=getattr(config, 'attention_bias', False),
head_dim=getattr(config, 'head_dim', None),
cache_config=cache_config,
quant_config=quant_config,
rope_scaling=rope_scaling,
prefix=f"{prefix}.self_attn",
attn_type=attn_type,
)
self.mlp = Qwen3MLP(
hidden_size=self.hidden_size,
intermediate_size=config.intermediate_size,
hidden_act=config.hidden_act,
quant_config=quant_config,
prefix=f"{prefix}.mlp",
)
        if quant_config is None:
            self.input_layernorm = RMSNorm(config.hidden_size,
                                           eps=config.rms_norm_eps)
            self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                    eps=config.rms_norm_eps)
        else:
            from vllm_ascend.quantization.quant_config import AscendQuantConfig
            from vllm_ascend.quantization.w8a8 import AscendW8A8LinearMethod

            assert isinstance(quant_config, AscendQuantConfig), \
                "Expected quant_config to be an instance of AscendQuantConfig"

            if isinstance(self.self_attn.qkv_proj.quant_method.quant_method,
                          AscendW8A8LinearMethod):
                self.input_layernorm = AddRMSNormW8A8Quant(
                    config.hidden_size,
                    layer=self.self_attn.qkv_proj,
                    eps=config.rms_norm_eps)
            else:
                self.input_layernorm = RMSNorm(config.hidden_size,
                                               eps=config.rms_norm_eps)
            if isinstance(self.mlp.gate_up_proj.quant_method.quant_method,
                          AscendW8A8LinearMethod):
                self.post_attention_layernorm = AddRMSNormW8A8Quant(
                    config.hidden_size,
                    layer=self.mlp.gate_up_proj,
                    eps=config.rms_norm_eps)
            else:
                self.post_attention_layernorm = RMSNorm(
                    config.hidden_size, eps=config.rms_norm_eps)

def forward(
self,
positions: torch.Tensor,
cos: torch.Tensor,
sin: torch.Tensor,
hidden_states: torch.Tensor,
residual: Optional[torch.Tensor],
) -> tuple[torch.Tensor, torch.Tensor]:
# Self Attention
if residual is None:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
else:
hidden_states, residual = self.input_layernorm(
hidden_states, residual)
hidden_states = self.self_attn(
positions=positions,
cos=cos,
sin=sin,
hidden_states=hidden_states,
)
hidden_states, residual = self.post_attention_layernorm(
hidden_states, residual)
hidden_states = self.mlp(hidden_states)
return hidden_states, residual


ALL_DECODER_LAYER_TYPES = {
@@ -77,6 +206,50 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__(vllm_config=vllm_config,
prefix=prefix,
decoder_layer_type=CustomQwen3DecoderLayer)
self.cos_sin_cache = self.layers[0].self_attn.rotary_emb.cos_sin_cache

def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
if get_pp_group().is_first_rank:
if inputs_embeds is not None:
hidden_states = inputs_embeds
else:
hidden_states = self.get_input_embeddings(input_ids)
residual = None
else:
assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"]

cos_sin = self.cos_sin_cache.index_select(0, positions)
last_dim = cos_sin.size()[-1]
cos, sin = cos_sin.reshape(-1, 2,
last_dim // 2).repeat(1, 1, 2).chunk(2,
dim=-2)
        # BSNH layout: (1, num_tokens, 1, last_dim)
cos, sin = cos.view(1, -1, 1, last_dim).contiguous(), sin.view(
1, -1, 1, last_dim).contiguous()

for layer in self.layers[self.start_layer:self.end_layer]:
hidden_states, residual = layer(
positions,
cos,
sin,
hidden_states,
residual,
)
if not get_pp_group().is_last_rank:
return IntermediateTensors({
"hidden_states": hidden_states,
"residual": residual
})
hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states


class CustomQwen3ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
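Taken together, the model-side change hoists the cos/sin gather out of the per-layer rotary call: the cache is indexed and reshaped once per forward pass, and the same pair of tensors is threaded through every decoder layer. A minimal sketch of that control flow, using simplified stand-in functions rather than the actual classes:

import torch


def precompute_cos_sin(cos_sin_cache: torch.Tensor,
                       positions: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    # One gather + reshape per model forward, shared by all decoder layers.
    cos_sin = cos_sin_cache.index_select(0, positions)
    last_dim = cos_sin.size(-1)
    cos, sin = cos_sin.reshape(-1, 2, last_dim // 2).repeat(1, 1, 2).chunk(2, dim=-2)
    return (cos.view(1, -1, 1, last_dim).contiguous(),
            sin.view(1, -1, 1, last_dim).contiguous())


def model_forward(layers, cos_sin_cache, positions, hidden_states):
    # Hoisted out of the layer loop: every layer reuses the same cos/sin pair.
    cos, sin = precompute_cos_sin(cos_sin_cache, positions)
    residual = None
    for layer in layers:
        hidden_states, residual = layer(positions, cos, sin, hidden_states, residual)
    return hidden_states, residual
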
47 changes: 29 additions & 18 deletions vllm_ascend/ops/rotary_embedding.py
@@ -31,13 +31,15 @@ def custom_rotary_embedding_enabled(query, neox_style, head_size):


def rope_forward_oot(
        self,
        positions: torch.Tensor,
        query: torch.Tensor,
        key: torch.Tensor,
        offsets: Optional[torch.Tensor] = None,
        cos: torch.Tensor = None,
        sin: torch.Tensor = None,
        is_neox_style_override: Optional[bool] = None,
        skip_index_select: bool = False) -> Tuple[torch.Tensor, torch.Tensor]:
import torch_npu
query_shape, key_shape = query.shape, key.shape
if self.cos_sin_cache.device != query.device:
@@ -62,17 +64,26 @@
raise NotImplementedError(
"Batched rotary embedding is currently not supported on NPU.")
else:
if skip_index_select and neox_style and self.head_size == self.rotary_dim:
# TODO: Remove the contiguous in the future.
# BSNH
query = query.contiguous().view(1, query.shape[0], -1,
self.head_size)
key = key.contiguous().view(1, key.shape[0], -1, self.head_size)
# requires head_size=128 and neox_style=True
torch_npu.npu_apply_rotary_pos_emb(query, key, cos, sin)
else:
# TODO: Remove the contiguous in the future.
query = query.contiguous().view(query.shape[0], -1)
key = key.contiguous().view(key.shape[0], -1)
torch_npu._npu_rotary_embedding(
positions,
query,
key,
self.head_size,
self.cos_sin_cache,
neox_style,
)
return query.view(query_shape), key.view(key_shape)


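For callers, the new keyword arguments select between the two kernels. A hedged usage sketch, assuming rope is a RotaryEmbedding instance whose forward is patched with the rope_forward_oot shown above and that cos/sin are the precomputed BSNH tensors from the model forward:

# Sketch only: `rope`, `cos`, and `sin` are assumed inputs, not defined in this diff.
q, k = rope(positions, q, k)  # default path: gathers rows from cos_sin_cache on every call
q, k = rope(positions, q, k, cos=cos, sin=sin, skip_index_select=True)
# fused path: taken only when neox_style is True and head_size == rotary_dim;
# otherwise the call falls back to torch_npu._npu_rotary_embedding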