
Commit efc563b

optimize rope in qwen3
use npu_apply_rotary_pos_emb when head_size is 128 and the rope is neox-style

Signed-off-by: David9857 <985700846@qq.com>
1 parent 9ca9c6f commit efc563b
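In short, this commit hoists the rotary cos/sin gather out of the per-layer rope call: the Qwen3 model forward now indexes the cos/sin cache once per step, reshapes the result to a BSNH layout, and hands it to every decoder layer, so attention can call the fused torch_npu.npu_apply_rotary_pos_emb kernel when head_size is 128 and the rope is neox-style, falling back to the existing per-call path otherwise. Below is a minimal plain-PyTorch sketch of that one-time cos/sin preparation; the cache construction and the sizes are assumptions for illustration, not code from this repo.

# Sketch of the once-per-forward cos/sin preparation done in the new
# model forward (illustrative sizes, assumed vLLM-style cache layout).
import torch

num_tokens, rot_dim, max_pos = 5, 128, 1024   # assumed sizes

# Assumed cache layout: each row is [cos(half) | sin(half)] for one position.
inv_freq = 1.0 / (10000**(torch.arange(0, rot_dim, 2).float() / rot_dim))
freqs = torch.outer(torch.arange(max_pos).float(), inv_freq)   # [max_pos, rot_dim // 2]
cos_sin_cache = torch.cat((freqs.cos(), freqs.sin()), dim=-1)  # [max_pos, rot_dim]

positions = torch.randint(0, max_pos, (num_tokens,))

# Same steps as the diff: gather once, split cos/sin, duplicate each half
# across the full rotary dim, reshape to BSNH so every layer can reuse it.
cos_sin = cos_sin_cache.index_select(0, positions)             # [T, rot_dim]
last_dim = cos_sin.size()[-1]
cos, sin = cos_sin.reshape(-1, 2, last_dim // 2).repeat(1, 1, 2).chunk(2, dim=-2)
cos = cos.view(1, -1, 1, last_dim).contiguous()                # [1, T, 1, rot_dim]
sin = sin.view(1, -1, 1, last_dim).contiguous()
print(cos.shape, sin.shape)   # torch.Size([1, 5, 1, 128]) twice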

File tree

2 files changed: +203 -24 lines

vllm_ascend/models/qwen3.py

Lines changed: 174 additions & 6 deletions
@@ -4,15 +4,17 @@
 import torch
 from torch import nn
 from transformers import Qwen3Config
+from vllm.attention import AttentionType
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, VllmConfig
 from vllm.distributed import get_pp_group
+from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
 from vllm.model_executor.models.interfaces import SupportsLoRA, SupportsPP
 from vllm.model_executor.models.qwen2 import Qwen2Model
-from vllm.model_executor.models.qwen3 import Qwen3DecoderLayer
+from vllm.model_executor.models.qwen3 import Qwen3Attention, Qwen3MLP
 from vllm.model_executor.models.utils import (AutoWeightsLoader,
                                               PPMissingLayer, maybe_prefix)
 from vllm.model_executor.sampling_metadata import SamplingMetadata
@@ -21,7 +23,66 @@
 from vllm_ascend.ops.layernorm import AddRMSNormW8A8Quant
 
 
-class CustomQwen3DecoderLayer(Qwen3DecoderLayer):
+class CustomQwen3Attention(Qwen3Attention):
+
+    def __init__(self,
+                 hidden_size: int,
+                 num_heads: int,
+                 num_kv_heads: int,
+                 max_position: int = 4096 * 32,
+                 head_dim: Optional[int] = None,
+                 rms_norm_eps: float = 1e-06,
+                 qkv_bias: bool = False,
+                 rope_theta: float = 10000,
+                 cache_config: Optional[CacheConfig] = None,
+                 quant_config: Optional[QuantizationConfig] = None,
+                 rope_scaling: Optional[tuple] = None,
+                 prefix: str = "",
+                 attn_type: str = AttentionType.DECODER) -> None:
+        super().__init__(hidden_size=hidden_size,
+                         num_heads=num_heads,
+                         num_kv_heads=num_kv_heads,
+                         max_position=max_position,
+                         head_dim=head_dim,
+                         rms_norm_eps=rms_norm_eps,
+                         qkv_bias=qkv_bias,
+                         rope_theta=rope_theta,
+                         cache_config=cache_config,
+                         quant_config=quant_config,
+                         rope_scaling=rope_scaling,
+                         prefix=prefix,
+                         attn_type=attn_type)
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        cos: torch.Tensor,
+        sin: torch.Tensor,
+        hidden_states: torch.Tensor,
+    ) -> torch.Tensor:
+        qkv, _ = self.qkv_proj(hidden_states)
+        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+        # Add qk-norm
+        q_by_head = q.view(*q.shape[:-1], q.shape[-1] // self.head_dim,
+                           self.head_dim)
+        q_by_head = self.q_norm(q_by_head)
+        q = q_by_head.view(q.shape)
+        k_by_head = k.view(*k.shape[:-1], k.shape[-1] // self.head_dim,
+                           self.head_dim)
+        k_by_head = self.k_norm(k_by_head)
+        k = k_by_head.view(k.shape)
+        q, k = self.rotary_emb(positions,
+                               q,
+                               k,
+                               cos=cos,
+                               sin=sin,
+                               skip_index_select=True)
+        attn_output = self.attn(q, k, v)
+        output, _ = self.o_proj(attn_output)
+        return output
+
+
+class CustomQwen3DecoderLayer(nn.Module):
 
     def __init__(
         self,
@@ -30,11 +91,48 @@ def __init__(
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
     ) -> None:
-        super().__init__(config=config,
-                         cache_config=cache_config,
-                         quant_config=quant_config,
-                         prefix=prefix)
+        super().__init__()
+        self.hidden_size = config.hidden_size
+        # Requires transformers > 4.32.0
+        rope_theta = getattr(config, "rope_theta", 1000000)
+        rope_scaling = getattr(config, "rope_scaling", None)
+
+        # By default, Qwen3 uses causal attention as it is a decoder-only model.
+        # You can override the HF config with `is_causal=False` to enable
+        # bidirectional attention, which is used in some embedding models
+        # (e.g. Alibaba-NLP/gte-Qwen3-7B-instruct)
+        if getattr(config, "is_causal", True):
+            attn_type = AttentionType.DECODER
+        else:
+            attn_type = AttentionType.ENCODER_ONLY
+
+        self.self_attn = CustomQwen3Attention(
+            hidden_size=self.hidden_size,
+            num_heads=config.num_attention_heads,
+            max_position=config.max_position_embeddings,
+            num_kv_heads=config.num_key_value_heads,
+            rope_theta=rope_theta,
+            rms_norm_eps=config.rms_norm_eps,
+            qkv_bias=getattr(config, 'attention_bias', False),
+            head_dim=getattr(config, 'head_dim', None),
+            cache_config=cache_config,
+            quant_config=quant_config,
+            rope_scaling=rope_scaling,
+            prefix=f"{prefix}.self_attn",
+            attn_type=attn_type,
+        )
+        self.mlp = Qwen3MLP(
+            hidden_size=self.hidden_size,
+            intermediate_size=config.intermediate_size,
+            hidden_act=config.hidden_act,
+            quant_config=quant_config,
+            prefix=f"{prefix}.mlp",
+        )
         if quant_config is None:
+            self.input_layernorm = RMSNorm(config.hidden_size,
+                                           eps=config.rms_norm_eps)
+            self.post_attention_layernorm = RMSNorm(config.hidden_size,
+                                                    eps=config.rms_norm_eps)
             return
 
         from vllm_ascend.quantization.quant_config import AscendQuantConfig
@@ -56,6 +154,32 @@ def __init__(
                 layer=self.mlp.gate_up_proj,
                 eps=config.rms_norm_eps)
 
+    def forward(
+        self,
+        positions: torch.Tensor,
+        cos: torch.Tensor,
+        sin: torch.Tensor,
+        hidden_states: torch.Tensor,
+        residual: Optional[torch.Tensor],
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        # Self Attention
+        if residual is None:
+            residual = hidden_states
+            hidden_states = self.input_layernorm(hidden_states)
+        else:
+            hidden_states, residual = self.input_layernorm(
+                hidden_states, residual)
+        hidden_states = self.self_attn(
+            positions=positions,
+            cos=cos,
+            sin=sin,
+            hidden_states=hidden_states,
+        )
+        hidden_states, residual = self.post_attention_layernorm(
+            hidden_states, residual)
+        hidden_states = self.mlp(hidden_states)
+        return hidden_states, residual
+
 
 ALL_DECODER_LAYER_TYPES = {
     "attention": CustomQwen3DecoderLayer,
@@ -77,6 +201,50 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__(vllm_config=vllm_config,
                          prefix=prefix,
                          decoder_layer_type=CustomQwen3DecoderLayer)
+        self.cos_sin_cache = self.layers[0].self_attn.rotary_emb.cos_sin_cache
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        intermediate_tensors: Optional[IntermediateTensors] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+    ) -> Union[torch.Tensor, IntermediateTensors]:
+        if get_pp_group().is_first_rank:
+            if inputs_embeds is not None:
+                hidden_states = inputs_embeds
+            else:
+                hidden_states = self.get_input_embeddings(input_ids)
+            residual = None
+        else:
+            assert intermediate_tensors is not None
+            hidden_states = intermediate_tensors["hidden_states"]
+            residual = intermediate_tensors["residual"]
+
+        cos_sin = self.cos_sin_cache.index_select(0, positions)
+        last_dim = cos_sin.size()[-1]
+        cos, sin = cos_sin.reshape(-1, 2,
+                                   last_dim // 2).repeat(1, 1, 2).chunk(2,
+                                                                        dim=-2)
+        # BSNH
+        cos, sin = cos.view(1, -1, 1, last_dim).contiguous(), sin.view(
+            1, -1, 1, last_dim).contiguous()
+
+        for layer in self.layers[self.start_layer:self.end_layer]:
+            hidden_states, residual = layer(
+                positions,
+                cos,
+                sin,
+                hidden_states,
+                residual,
+            )
+        if not get_pp_group().is_last_rank:
+            return IntermediateTensors({
+                "hidden_states": hidden_states,
+                "residual": residual
+            })
+        hidden_states, _ = self.norm(hidden_states, residual)
+        return hidden_states
 
 
 class CustomQwen3ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
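For context on the `.repeat(1, 1, 2)` in the model forward above: duplicating each cos/sin half across the full head dimension is what lets a neox-style rotation be written as q * cos + rotate_half(q) * sin. The snippet below is a CPU-only sanity sketch of that identity with made-up sizes; it is not the NPU kernel, only a restatement of the layout the BSNH cos/sin tensors are built for.

# Neox-style RoPE on a BSNH tensor, stated in plain PyTorch (illustrative only).
import torch

def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # Standard neox helper: swap the two halves of the last dim, negate the second.
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

num_tokens, num_heads, head_dim = 4, 2, 128            # assumed sizes
q = torch.randn(1, num_tokens, num_heads, head_dim)    # BSNH layout

angles = torch.rand(num_tokens, head_dim // 2)         # one angle per rotated pair
cos = angles.cos().repeat(1, 2).view(1, num_tokens, 1, head_dim)   # duplicated halves
sin = angles.sin().repeat(1, 2).view(1, num_tokens, 1, head_dim)

q_rot = q * cos + rotate_half(q) * sin                 # neox-style rotation
print(q_rot.shape)                                     # torch.Size([1, 4, 2, 128])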

vllm_ascend/ops/rotary_embedding.py

Lines changed: 29 additions & 18 deletions
@@ -31,13 +31,15 @@ def custom_rotary_embedding_enabled(query, neox_style, head_size):
 
 
 def rope_forward_oot(
-    self,
-    positions: torch.Tensor,
-    query: torch.Tensor,
-    key: torch.Tensor,
-    offsets: Optional[torch.Tensor] = None,
-    is_neox_style_override: Optional[bool] = None
-) -> Tuple[torch.Tensor, torch.Tensor]:
+        self,
+        positions: torch.Tensor,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        offsets: Optional[torch.Tensor] = None,
+        cos: torch.Tensor = None,
+        sin: torch.Tensor = None,
+        is_neox_style_override: Optional[bool] = None,
+        skip_index_select: bool = False) -> Tuple[torch.Tensor, torch.Tensor]:
     import torch_npu
     query_shape, key_shape = query.shape, key.shape
     if self.cos_sin_cache.device != query.device:
@@ -62,17 +64,26 @@ def rope_forward_oot(
         raise NotImplementedError(
             "Batched rotary embedding is currently not supported on NPU.")
     else:
-        # TODO: Remove the contiguous in the future.
-        query = query.contiguous().view(query.shape[0], -1)
-        key = key.contiguous().view(key.shape[0], -1)
-        torch_npu._npu_rotary_embedding(
-            positions,
-            query,
-            key,
-            self.head_size,
-            self.cos_sin_cache,
-            neox_style,
-        )
+        if skip_index_select and neox_style:
+            # TODO: Remove the contiguous in the future.
+            # BSNH
+            query = query.contiguous().view(1, query.shape[0], -1,
+                                            self.head_size)
+            key = key.contiguous().view(1, key.shape[0], -1, self.head_size)
+            # requires head_size=128 and neox_style=True
+            torch_npu.npu_apply_rotary_pos_emb(query, key, cos, sin)
+        else:
+            # TODO: Remove the contiguous in the future.
+            query = query.contiguous().view(query.shape[0], -1)
+            key = key.contiguous().view(key.shape[0], -1)
+            torch_npu._npu_rotary_embedding(
+                positions,
+                query,
+                key,
+                self.head_size,
+                self.cos_sin_cache,
+                neox_style,
+            )
     return query.view(query_shape), key.view(key_shape)
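A hedged restatement of the dispatch added to rope_forward_oot, with the two torch_npu kernels replaced by stand-in callables so the control flow can be run anywhere: the fused path is taken only when the caller has already gathered cos/sin (skip_index_select=True) and the rope is neox-style (per the in-code comment it also expects head_size == 128); every other case keeps the existing kernel. The helper name and the stand-ins are hypothetical, not part of the repo.

# Dispatch sketch; apply_fused / apply_legacy stand in for the NPU kernels.
from typing import Callable, Optional, Tuple
import torch

def rope_dispatch(query: torch.Tensor,      # [num_tokens, num_heads * head_size]
                  key: torch.Tensor,        # [num_tokens, num_kv_heads * head_size]
                  head_size: int,
                  neox_style: bool,
                  cos: Optional[torch.Tensor],
                  sin: Optional[torch.Tensor],
                  skip_index_select: bool,
                  apply_fused: Callable,    # stand-in for npu_apply_rotary_pos_emb
                  apply_legacy: Callable    # stand-in for _npu_rotary_embedding
                  ) -> Tuple[torch.Tensor, torch.Tensor]:
    query_shape, key_shape = query.shape, key.shape
    if skip_index_select and neox_style:
        # BSNH views, as in the diff; the fused op expects head_size == 128.
        q = query.contiguous().view(1, query.shape[0], -1, head_size)
        k = key.contiguous().view(1, key.shape[0], -1, head_size)
        apply_fused(q, k, cos, sin)          # treated as in-place, as the diff uses it
    else:
        # Legacy path: flatten heads and let the kernel index the cos/sin cache.
        q = query.contiguous().view(query.shape[0], -1)
        k = key.contiguous().view(key.shape[0], -1)
        apply_legacy(q, k)
    return q.view(query_shape), k.view(key_shape)

# Smoke test with no-op stand-ins (the real kernels rotate q/k in place).
q = torch.randn(4, 2 * 128)
k = torch.randn(4, 1 * 128)
cos = torch.randn(1, 4, 1, 128)
sin = torch.randn(1, 4, 1, 128)
q_out, k_out = rope_dispatch(q, k, 128, True, cos, sin, True,
                             apply_fused=lambda *a: None,
                             apply_legacy=lambda *a: None)
print(q_out.shape, k_out.shape)  # torch.Size([4, 256]) torch.Size([4, 128])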
