
Commit 7c4cad1

qwen2 optimize rope
Signed-off-by: David9857 <985700846@qq.com>
1 parent 137810d commit 7c4cad1

File tree

1 file changed: +103 −5 lines changed

vllm_ascend/models/qwen2.py

Lines changed: 103 additions & 5 deletions
@@ -6,9 +6,11 @@
 import torch.nn.functional as F
 from transformers import Qwen2Config

+from vllm.attention import AttentionType
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, VllmConfig
 from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
+from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
@@ -18,7 +20,7 @@
 from vllm.model_executor.models.interfaces import SupportsLoRA, SupportsPP
 from vllm.model_executor.models.utils import (AutoWeightsLoader, PPMissingLayer, maybe_prefix)

-from vllm.model_executor.models.qwen2 import Qwen2Model, Qwen2DecoderLayer
+from vllm.model_executor.models.qwen2 import Qwen2Model, Qwen2Attention, Qwen2MLP
 from vllm.distributed import (
     get_pp_group,
     get_tensor_model_parallel_world_size,
@@ -49,19 +51,102 @@ def maybe_pad_and_reduce_scatter(
         hidden_states = tensor_model_parallel_reduce_scatter(hidden_states, 0)
     return hidden_states

-class CustomQwen2DecoderLayer(Qwen2DecoderLayer):
+
+class CustomQwen2Attention(Qwen2Attention):

     def __init__(
         self,
-        config: Qwen2Config,
+        hidden_size: int,
+        num_heads: int,
+        num_kv_heads: int,
+        max_position: int = 4096 * 32,
+        rope_theta: float = 10000,
         cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
+        rope_scaling: Optional[tuple] = None,
         prefix: str = "",
+        attn_type: str = AttentionType.DECODER,
+        dual_chunk_attention_config: Optional[dict[str, Any]] = None,
     ) -> None:
-        super().__init__(config=config,
+        super().__init__(hidden_size=hidden_size,
+                         num_heads=num_heads,
+                         num_kv_heads=num_kv_heads,
+                         max_position=max_position,
+                         rope_theta=rope_theta,
                          cache_config=cache_config,
                          quant_config=quant_config,
-                         prefix=prefix)
+                         rope_scaling=rope_scaling,
+                         prefix=prefix,
+                         attn_type=attn_type,
+                         dual_chunk_attention_config=dual_chunk_attention_config)
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        cos: torch.Tensor,
+        sin: torch.Tensor
+    ) -> torch.Tensor:
+        qkv, _ = self.qkv_proj(hidden_states)
+        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+        q, k = self.rotary_emb(positions, q, k, cos=cos, sin=sin, skip_index_select=True)
+        attn_output = self.attn(q, k, v)
+        output, _ = self.o_proj(attn_output)
+        return output
+
+
+class CustomQwen2DecoderLayer(nn.Module):
+
+    def __init__(
+        self,
+        config: Qwen2Config,
+        cache_config: Optional[CacheConfig] = None,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+        self.hidden_size = config.hidden_size
+        # Requires transformers > 4.32.0
+        rope_theta = getattr(config, "rope_theta", 1000000)
+        rope_scaling = getattr(config, "rope_scaling", None)
+        dual_chunk_attention_config = getattr(config,
+                                              "dual_chunk_attention_config",
+                                              None)
+
+        # By default, Qwen2 uses causal attention as it is a decoder-only model.
+        # You can override the HF config with `is_causal=False` to enable
+        # bidirectional attention, which is used in some embedding models
+        # (e.g. Alibaba-NLP/gte-Qwen2-7B-instruct)
+        if getattr(config, "is_causal", True):
+            attn_type = AttentionType.DECODER
+        else:
+            attn_type = AttentionType.ENCODER_ONLY
+
+        self.self_attn = CustomQwen2Attention(
+            hidden_size=self.hidden_size,
+            num_heads=config.num_attention_heads,
+            max_position=config.max_position_embeddings,
+            num_kv_heads=config.num_key_value_heads,
+            rope_theta=rope_theta,
+            cache_config=cache_config,
+            quant_config=quant_config,
+            rope_scaling=rope_scaling,
+            prefix=f"{prefix}.self_attn",
+            attn_type=attn_type,
+            dual_chunk_attention_config=dual_chunk_attention_config,
+        )
+        self.mlp = Qwen2MLP(
+            hidden_size=self.hidden_size,
+            intermediate_size=config.intermediate_size,
+            hidden_act=config.hidden_act,
+            quant_config=quant_config,
+            prefix=f"{prefix}.mlp",
+        )
+        self.input_layernorm = RMSNorm(config.hidden_size,
+                                       eps=config.rms_norm_eps)
+        self.post_attention_layernorm = RMSNorm(config.hidden_size,
+                                                eps=config.rms_norm_eps)
+
         self.tp_rank = get_tensor_model_parallel_rank()
         self.tp_size = get_tensor_model_parallel_world_size()
         self.self_attn.reduce_results=False
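
The rewritten forward above hands precomputed cos/sin tensors to rotary_emb and asks it to skip its own per-layer gather from the cache. The snippet below is a minimal sketch, not the actual Ascend/vLLM kernel path, of what applying RoPE with externally supplied cos/sin of shape [1, num_tokens, 1, head_dim] looks like under the rotate-half (non-interleaved) convention implied by the duplicated-half layout prepared in CustomQwen2Model.forward further down; both helper names are illustrative only.

import torch


def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # (x1, x2) -> (-x2, x1) on the last dimension.
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)


def apply_rope_precomputed(q: torch.Tensor, k: torch.Tensor,
                           cos: torch.Tensor, sin: torch.Tensor,
                           head_dim: int) -> tuple[torch.Tensor, torch.Tensor]:
    # q: [num_tokens, num_heads * head_dim], k: [num_tokens, num_kv_heads * head_dim]
    # cos/sin: [1, num_tokens, 1, head_dim], already gathered for these positions.
    num_tokens = q.shape[0]
    q4 = q.view(1, num_tokens, -1, head_dim)
    k4 = k.view(1, num_tokens, -1, head_dim)
    q4 = q4 * cos + rotate_half(q4) * sin
    k4 = k4 * cos + rotate_half(k4) * sin
    return q4.view(num_tokens, -1), k4.view(num_tokens, -1)

The cos/sin tensors broadcast over the head dimension, so the same gathered values serve every query and key head of the token.
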
@@ -72,6 +157,8 @@ def forward(
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
         residual: Optional[torch.Tensor],
+        cos,
+        sin,
         fc_enabled: bool,
         pad_size: int,
     ) -> tuple[torch.Tensor, torch.Tensor]:
@@ -91,6 +178,8 @@ def forward(
         hidden_states = self.self_attn(
             positions=positions,
             hidden_states=hidden_states,
+            cos=cos,
+            sin=sin
         )
         if fc_enabled:
             hidden_states = maybe_pad_and_reduce_scatter(hidden_states, pad_size)
@@ -129,6 +218,7 @@ def __init__(self,
                          prefix=prefix,
                          decoder_layer_type=decoder_layer_type)
         self.tp_size = get_tensor_model_parallel_world_size()
+        self.cos_sin_cache = self.layers[0].self_attn.rotary_emb.cos_sin_cache

     def forward(
         self,
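
Caching cos_sin_cache from layer 0 relies on every decoder layer building its rotary embedding from the same rope_theta and rope_scaling in the model config, so a single gather per forward step can serve all layers. A sanity-check sketch of that assumption follows; the function and the `model` variable are hypothetical, and all layers are assumed to be materialized on this pipeline rank rather than PPMissingLayer placeholders.

import torch

def rotary_caches_identical(model) -> bool:
    # Compare every layer's cos/sin cache against layer 0's; identical caches
    # mean the shared self.cos_sin_cache gather is valid for all of them.
    reference = model.layers[0].self_attn.rotary_emb.cos_sin_cache
    return all(
        torch.equal(reference, layer.self_attn.rotary_emb.cos_sin_cache)
        for layer in model.layers
    )
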
@@ -157,11 +247,19 @@ def forward(
         if fc_enabled:
             num_tokens = hidden_states.size(0)
             pad_size = (self.tp_size - (num_tokens % self.tp_size)) % self.tp_size
+
+        cos_sin = self.cos_sin_cache.index_select(0, positions)
+        last_dim = cos_sin.size()[-1]
+        cos, sin = cos_sin.reshape(-1, 2, last_dim // 2).repeat(1, 1, 2).chunk(2, dim=-2)
+        cos, sin = cos.view(1, -1, 1, last_dim).contiguous(), sin.view(1, -1, 1, last_dim).contiguous()
+
         for layer in self.layers[self.start_layer:self.end_layer]:
             hidden_states, residual = layer(
                 positions,
                 hidden_states,
                 residual,
+                cos,
+                sin,
                 fc_enabled,
                 pad_size,
             )
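
A shape walkthrough of the cos/sin precomputation added above, assuming (as in vLLM's RotaryEmbedding) that cos_sin_cache has shape [max_position, rotary_dim] with the cos half of each row followed by the sin half; the toy sizes are illustrative.

import torch

max_position, rotary_dim, num_tokens = 64, 8, 3
cos_sin_cache = torch.randn(max_position, rotary_dim)
positions = torch.tensor([0, 5, 7])

cos_sin = cos_sin_cache.index_select(0, positions)   # [3, 8]: cos half | sin half
last_dim = cos_sin.size()[-1]                         # 8
# Split each row into (cos half, sin half), duplicate each half along the last
# dim, then separate cos from sin.
cos, sin = cos_sin.reshape(-1, 2, last_dim // 2).repeat(1, 1, 2).chunk(2, dim=-2)
cos = cos.view(1, -1, 1, last_dim).contiguous()       # [1, 3, 1, 8]
sin = sin.view(1, -1, 1, last_dim).contiguous()       # [1, 3, 1, 8]

assert cos.shape == sin.shape == (1, num_tokens, 1, rotary_dim)
# Each token's cos row is its cos half repeated twice (the duplicated-half
# layout expected by a rotate-half style application), and likewise for sin.
assert torch.equal(cos[0, 0, 0, :last_dim // 2], cos_sin_cache[0, :last_dim // 2])
assert torch.equal(sin[0, 1, 0, :last_dim // 2], cos_sin_cache[5, last_dim // 2:])

With this layout the model gathers and reshapes cos/sin once per forward step and broadcasts them into every decoder layer, instead of each layer indexing the rotary cache on its own.
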
