diff --git a/tests/singlecard/ops/test_rotary_embedding.py b/tests/singlecard/ops/test_rotary_embedding.py
index a3504a88b2..eb2bbecaab 100644
--- a/tests/singlecard/ops/test_rotary_embedding.py
+++ b/tests/singlecard/ops/test_rotary_embedding.py
@@ -9,6 +9,7 @@
 import pytest
 import torch
 import torch.nn as nn
+import torch_npu
 
 from vllm_ascend.utils import enable_custom_op
 
@@ -198,3 +199,69 @@ def test_rotary_embedding_quant_with_leading_dim(
                                ref_key,
                                atol=DEFAULT_ATOL,
                                rtol=DEFAULT_RTOL)
+
+
+# test npu_apply_rotary_pos_emb with head_size=128 and rotary_dim=128 and is_neox_style=True
+@pytest.mark.parametrize("is_neox_style", [True])
+@pytest.mark.parametrize("batch_size", BATCH_SIZES)
+@pytest.mark.parametrize("seq_len", SEQ_LENS)
+@pytest.mark.parametrize("num_heads", NUM_HEADS)
+@pytest.mark.parametrize("head_size", [128])
+@pytest.mark.parametrize("rotary_dim", [128])
+@pytest.mark.parametrize("dtype", DTYPES)
+@pytest.mark.parametrize("seed", SEEDS)
+@pytest.mark.parametrize("device", DEVICES)
+@torch.inference_mode()
+def test_npu_apply_rotary_pos_emb_with_head_size_equals_rotary_dim(
+    is_neox_style: bool,
+    batch_size: int,
+    seq_len: int,
+    num_heads: int,
+    head_size: int,
+    rotary_dim: Optional[int],
+    dtype: torch.dtype,
+    seed: int,
+    device: str,
+    max_position: int = 8192,
+    base: int = 10000,
+) -> None:
+    if rotary_dim is None:
+        rotary_dim = head_size
+
+    torch.set_default_device(device)
+    if rotary_dim is None:
+        rotary_dim = head_size
+    rope = RotaryEmbedding(head_size, rotary_dim, max_position, base,
+                           is_neox_style, dtype)
+    rope = rope.to(dtype=dtype)
+    num_tokens = batch_size * seq_len
+    positions = torch.randint(0, max_position, (batch_size * seq_len, ))
+    qkv_tensor = torch.randn(1,
+                             num_tokens,
+                             num_heads,
+                             head_size * 3,
+                             dtype=dtype)
+    query, key, _ = qkv_tensor.split(
+        [head_size, head_size, head_size],
+        dim=-1,
+    )
+
+    ref_query, ref_key = rope.forward_native(positions, query, key)
+    cos_sin = rope.cos_sin_cache.index_select(0, positions)
+    last_dim = cos_sin.size()[-1]
+    cos, sin = cos_sin.reshape(-1, 2, last_dim // 2).repeat(1, 1,
+                                                            2).chunk(2, dim=-2)
+    # BSNH
+    cos, sin = cos.view(1, -1, 1, last_dim).contiguous(), sin.view(
+        1, -1, 1, last_dim).contiguous()
+    torch_npu.npu_apply_rotary_pos_emb(query, key, cos, sin)
+
+    # Compare the results.
+    torch.testing.assert_close(query.view(ref_query.size()),
+                               ref_query,
+                               atol=DEFAULT_ATOL,
+                               rtol=DEFAULT_RTOL)
+    torch.testing.assert_close(key.view(ref_key.size()),
+                               ref_key,
+                               atol=DEFAULT_ATOL,
+                               rtol=DEFAULT_RTOL)
diff --git a/vllm_ascend/models/qwen3.py b/vllm_ascend/models/qwen3.py
index a05106f228..b82287e4a8 100644
--- a/vllm_ascend/models/qwen3.py
+++ b/vllm_ascend/models/qwen3.py
@@ -4,15 +4,17 @@
 import torch
 from torch import nn
 from transformers import Qwen3Config
 
+from vllm.attention import AttentionType
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, VllmConfig
 from vllm.distributed import get_pp_group
+from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
 from vllm.model_executor.models.interfaces import SupportsLoRA, SupportsPP
 from vllm.model_executor.models.qwen2 import Qwen2Model
-from vllm.model_executor.models.qwen3 import Qwen3DecoderLayer
+from vllm.model_executor.models.qwen3 import Qwen3Attention, Qwen3MLP
 from vllm.model_executor.models.utils import (AutoWeightsLoader,
                                               PPMissingLayer, maybe_prefix)
 from vllm.model_executor.sampling_metadata import SamplingMetadata
@@ -21,7 +23,66 @@
 from vllm_ascend.ops.layernorm import AddRMSNormW8A8Quant
 
 
-class CustomQwen3DecoderLayer(Qwen3DecoderLayer):
+class CustomQwen3Attention(Qwen3Attention):
+
+    def __init__(self,
+                 hidden_size: int,
+                 num_heads: int,
+                 num_kv_heads: int,
+                 max_position: int = 4096 * 32,
+                 head_dim: Optional[int] = None,
+                 rms_norm_eps: float = 1e-06,
+                 qkv_bias: bool = False,
+                 rope_theta: float = 10000,
+                 cache_config: Optional[CacheConfig] = None,
+                 quant_config: Optional[QuantizationConfig] = None,
+                 rope_scaling: Optional[tuple] = None,
+                 prefix: str = "",
+                 attn_type: str = AttentionType.DECODER) -> None:
+        super().__init__(hidden_size=hidden_size,
+                         num_heads=num_heads,
+                         num_kv_heads=num_kv_heads,
+                         max_position=max_position,
+                         head_dim=head_dim,
+                         rms_norm_eps=rms_norm_eps,
+                         qkv_bias=qkv_bias,
+                         rope_theta=rope_theta,
+                         cache_config=cache_config,
+                         quant_config=quant_config,
+                         rope_scaling=rope_scaling,
+                         prefix=prefix,
+                         attn_type=attn_type)
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        cos: torch.Tensor,
+        sin: torch.Tensor,
+        hidden_states: torch.Tensor,
+    ) -> torch.Tensor:
+        qkv, _ = self.qkv_proj(hidden_states)
+        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+        # Add qk-norm
+        q_by_head = q.view(*q.shape[:-1], q.shape[-1] // self.head_dim,
+                           self.head_dim)
+        q_by_head = self.q_norm(q_by_head)
+        q = q_by_head.view(q.shape)
+        k_by_head = k.view(*k.shape[:-1], k.shape[-1] // self.head_dim,
+                           self.head_dim)
+        k_by_head = self.k_norm(k_by_head)
+        k = k_by_head.view(k.shape)
+        q, k = self.rotary_emb(positions,
+                               q,
+                               k,
+                               cos=cos,
+                               sin=sin,
+                               skip_index_select=True)
+        attn_output = self.attn(q, k, v)
+        output, _ = self.o_proj(attn_output)
+        return output
+
+
+class CustomQwen3DecoderLayer(nn.Module):
 
     def __init__(
         self,
@@ -30,31 +91,99 @@ def __init__(
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
     ) -> None:
-        super().__init__(config=config,
-                         cache_config=cache_config,
-                         quant_config=quant_config,
-                         prefix=prefix)
+        super().__init__()
+        self.hidden_size = config.hidden_size
+        # Requires transformers > 4.32.0
+        rope_theta = getattr(config, "rope_theta", 1000000)
+        rope_scaling = getattr(config, "rope_scaling", None)
+
+        # By default, Qwen3 uses causal attention as it is a decoder-only model.
+        # You can override the HF config with `is_causal=False` to enable
+        # bidirectional attention, which is used in some embedding models
+        # (e.g. Alibaba-NLP/gte-Qwen3-7B-instruct)
+        if getattr(config, "is_causal", True):
+            attn_type = AttentionType.DECODER
+        else:
+            attn_type = AttentionType.ENCODER_ONLY
+
+        self.self_attn = CustomQwen3Attention(
+            hidden_size=self.hidden_size,
+            num_heads=config.num_attention_heads,
+            max_position=config.max_position_embeddings,
+            num_kv_heads=config.num_key_value_heads,
+            rope_theta=rope_theta,
+            rms_norm_eps=config.rms_norm_eps,
+            qkv_bias=getattr(config, 'attention_bias', False),
+            head_dim=getattr(config, 'head_dim', None),
+            cache_config=cache_config,
+            quant_config=quant_config,
+            rope_scaling=rope_scaling,
+            prefix=f"{prefix}.self_attn",
+            attn_type=attn_type,
+        )
+        self.mlp = Qwen3MLP(
+            hidden_size=self.hidden_size,
+            intermediate_size=config.intermediate_size,
+            hidden_act=config.hidden_act,
+            quant_config=quant_config,
+            prefix=f"{prefix}.mlp",
+        )
         if quant_config is None:
-            return
+            self.input_layernorm = RMSNorm(config.hidden_size,
+                                           eps=config.rms_norm_eps)
+            self.post_attention_layernorm = RMSNorm(config.hidden_size,
+                                                    eps=config.rms_norm_eps)
+        else:
+            from vllm_ascend.quantization.quant_config import AscendQuantConfig
+            from vllm_ascend.quantization.w8a8 import AscendW8A8LinearMethod
 
-        from vllm_ascend.quantization.quant_config import AscendQuantConfig
-        from vllm_ascend.quantization.w8a8 import AscendW8A8LinearMethod
+            assert isinstance(quant_config, AscendQuantConfig), \
+                "Expected quant_config to be an instance of AscendQuantConfig"
 
-        assert isinstance(quant_config, AscendQuantConfig), \
-            "Expected quant_config to be an instance of AscendQuantConfig"
+            if isinstance(self.self_attn.qkv_proj.quant_method.quant_method,
+                          AscendW8A8LinearMethod):
+                self.input_layernorm = AddRMSNormW8A8Quant(
+                    config.hidden_size,
+                    layer=self.self_attn.qkv_proj,
+                    eps=config.rms_norm_eps)
+            else:
+                self.input_layernorm = RMSNorm(config.hidden_size,
+                                               eps=config.rms_norm_eps)
+            if isinstance(self.mlp.gate_up_proj.quant_method.quant_method,
+                          AscendW8A8LinearMethod):
+                self.post_attention_layernorm = AddRMSNormW8A8Quant(
+                    config.hidden_size,
+                    layer=self.mlp.gate_up_proj,
+                    eps=config.rms_norm_eps)
+            else:
+                self.post_attention_layernorm = RMSNorm(
+                    config.hidden_size, eps=config.rms_norm_eps)
 
-        if isinstance(self.self_attn.qkv_proj.quant_method.quant_method,
-                      AscendW8A8LinearMethod):
-            self.input_layernorm = AddRMSNormW8A8Quant(
-                config.hidden_size,
-                layer=self.self_attn.qkv_proj,
-                eps=config.rms_norm_eps)
-        if isinstance(self.mlp.gate_up_proj.quant_method.quant_method,
-                      AscendW8A8LinearMethod):
-            self.post_attention_layernorm = AddRMSNormW8A8Quant(
-                config.hidden_size,
-                layer=self.mlp.gate_up_proj,
-                eps=config.rms_norm_eps)
+    def forward(
+        self,
+        positions: torch.Tensor,
+        cos: torch.Tensor,
+        sin: torch.Tensor,
+        hidden_states: torch.Tensor,
+        residual: Optional[torch.Tensor],
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        # Self Attention
+        if residual is None:
+            residual = hidden_states
+            hidden_states = self.input_layernorm(hidden_states)
+        else:
+            hidden_states, residual = self.input_layernorm(
+                hidden_states, residual)
+        hidden_states = self.self_attn(
+            positions=positions,
+            cos=cos,
+            sin=sin,
+            hidden_states=hidden_states,
+        )
+        hidden_states, residual = self.post_attention_layernorm(
+            hidden_states, residual)
+        hidden_states = self.mlp(hidden_states)
+        return hidden_states, residual
 
 
 ALL_DECODER_LAYER_TYPES = {
@@ -77,6 +206,50 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__(vllm_config=vllm_config,
                          prefix=prefix,
                          decoder_layer_type=CustomQwen3DecoderLayer)
+        self.cos_sin_cache = self.layers[0].self_attn.rotary_emb.cos_sin_cache
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        intermediate_tensors: Optional[IntermediateTensors] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+    ) -> Union[torch.Tensor, IntermediateTensors]:
+        if get_pp_group().is_first_rank:
+            if inputs_embeds is not None:
+                hidden_states = inputs_embeds
+            else:
+                hidden_states = self.get_input_embeddings(input_ids)
+            residual = None
+        else:
+            assert intermediate_tensors is not None
+            hidden_states = intermediate_tensors["hidden_states"]
+            residual = intermediate_tensors["residual"]
+
+        cos_sin = self.cos_sin_cache.index_select(0, positions)
+        last_dim = cos_sin.size()[-1]
+        cos, sin = cos_sin.reshape(-1, 2,
+                                   last_dim // 2).repeat(1, 1, 2).chunk(2,
+                                                                        dim=-2)
+        # BSNH
+        cos, sin = cos.view(1, -1, 1, last_dim).contiguous(), sin.view(
+            1, -1, 1, last_dim).contiguous()
+
+        for layer in self.layers[self.start_layer:self.end_layer]:
+            hidden_states, residual = layer(
+                positions,
+                cos,
+                sin,
+                hidden_states,
+                residual,
+            )
+        if not get_pp_group().is_last_rank:
+            return IntermediateTensors({
+                "hidden_states": hidden_states,
+                "residual": residual
+            })
+        hidden_states, _ = self.norm(hidden_states, residual)
+        return hidden_states
 
 
 class CustomQwen3ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
diff --git a/vllm_ascend/ops/rotary_embedding.py b/vllm_ascend/ops/rotary_embedding.py
index f55ab8e0cb..c4c568d4d0 100644
--- a/vllm_ascend/ops/rotary_embedding.py
+++ b/vllm_ascend/ops/rotary_embedding.py
@@ -31,13 +31,15 @@ def custom_rotary_embedding_enabled(query, neox_style, head_size):
 
 
 def rope_forward_oot(
-    self,
-    positions: torch.Tensor,
-    query: torch.Tensor,
-    key: torch.Tensor,
-    offsets: Optional[torch.Tensor] = None,
-    is_neox_style_override: Optional[bool] = None
-) -> Tuple[torch.Tensor, torch.Tensor]:
+        self,
+        positions: torch.Tensor,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        offsets: Optional[torch.Tensor] = None,
+        cos: torch.Tensor = None,
+        sin: torch.Tensor = None,
+        is_neox_style_override: Optional[bool] = None,
+        skip_index_select: bool = False) -> Tuple[torch.Tensor, torch.Tensor]:
     import torch_npu
     query_shape, key_shape = query.shape, key.shape
     if self.cos_sin_cache.device != query.device:
@@ -62,17 +64,26 @@ def rope_forward_oot(
         raise NotImplementedError(
             "Batched rotary embedding is currently not supported on NPU.")
     else:
-        # TODO: Remove the contiguous in the future.
-        query = query.contiguous().view(query.shape[0], -1)
-        key = key.contiguous().view(key.shape[0], -1)
-        torch_npu._npu_rotary_embedding(
-            positions,
-            query,
-            key,
-            self.head_size,
-            self.cos_sin_cache,
-            neox_style,
-        )
+        if skip_index_select and neox_style and self.head_size == self.rotary_dim:
+            # TODO: Remove the contiguous in the future.
+            # BSNH
+            query = query.contiguous().view(1, query.shape[0], -1,
+                                            self.head_size)
+            key = key.contiguous().view(1, key.shape[0], -1, self.head_size)
+            # requires head_size=128 and neox_style=True
+            torch_npu.npu_apply_rotary_pos_emb(query, key, cos, sin)
+        else:
+            # TODO: Remove the contiguous in the future.
+            query = query.contiguous().view(query.shape[0], -1)
+            key = key.contiguous().view(key.shape[0], -1)
+            torch_npu._npu_rotary_embedding(
+                positions,
+                query,
+                key,
+                self.head_size,
+                self.cos_sin_cache,
+                neox_style,
+            )
 
     return query.view(query_shape), key.view(key_shape)
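
For context: the cos/sin preparation used with torch_npu.npu_apply_rotary_pos_emb appears twice in this patch (in the new test and in the CustomQwen3Model forward). The sketch below isolates that layout transformation in plain PyTorch so the intended BSNH shapes are easy to check. It is a minimal illustration, not part of the patch; the helper name build_cos_sin_bsnh and the toy rotary_dim/max_position values are assumptions made for the example.

import torch

def build_cos_sin_bsnh(cos_sin_cache: torch.Tensor, positions: torch.Tensor):
    # cos_sin_cache rows are packed as [cos_0..cos_{d/2-1}, sin_0..sin_{d/2-1}],
    # as in vLLM's RotaryEmbedding.cos_sin_cache.
    cos_sin = cos_sin_cache.index_select(0, positions)  # (num_tokens, d)
    last_dim = cos_sin.size(-1)
    # Split into the cos/sin halves and duplicate each half back to full width,
    # mirroring the neox-style layout the patch builds before calling the kernel.
    cos, sin = cos_sin.reshape(-1, 2,
                               last_dim // 2).repeat(1, 1, 2).chunk(2, dim=-2)
    # BSNH: batch=1, seq=num_tokens, heads=1 (broadcast over heads), head_dim=d
    cos = cos.view(1, -1, 1, last_dim).contiguous()
    sin = sin.view(1, -1, 1, last_dim).contiguous()
    return cos, sin

# Toy example: rotary_dim=8 and 4 positions; the result has shape (1, 4, 1, 8).
rotary_dim, max_position, base = 8, 16, 10000
inv_freq = 1.0 / (base**(torch.arange(0, rotary_dim, 2).float() / rotary_dim))
freqs = torch.einsum("i,j->ij", torch.arange(max_position).float(), inv_freq)
cos_sin_cache = torch.cat((freqs.cos(), freqs.sin()), dim=-1)
positions = torch.tensor([0, 3, 7, 2])
cos, sin = build_cos_sin_bsnh(cos_sin_cache, positions)
assert cos.shape == (1, 4, 1, rotary_dim) and sin.shape == (1, 4, 1, rotary_dim)

The resulting cos/sin tensors are what the patch passes to torch_npu.npu_apply_rotary_pos_emb alongside query/key views of shape (1, num_tokens, num_heads, head_size); per the comments above, that fast path is only taken when skip_index_select is set, the layout is neox style, and head_size equals rotary_dim.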