 import torch.nn.functional as F
 from transformers import Qwen2Config
 
+from vllm.attention import AttentionType
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, VllmConfig
 from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
+from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
 from vllm.model_executor.models.interfaces import SupportsLoRA, SupportsPP
 from vllm.model_executor.models.utils import (AutoWeightsLoader, PPMissingLayer, maybe_prefix)
 
-from vllm.model_executor.models.qwen2 import Qwen2Model, Qwen2DecoderLayer
+from vllm.model_executor.models.qwen2 import Qwen2Model, Qwen2Attention, Qwen2MLP
 from vllm.distributed import (
     get_pp_group,
     get_tensor_model_parallel_world_size,
@@ -49,19 +51,102 @@ def maybe_pad_and_reduce_scatter(
     hidden_states = tensor_model_parallel_reduce_scatter(hidden_states, 0)
     return hidden_states
 
-class CustomQwen2DecoderLayer(Qwen2DecoderLayer):
+
+class CustomQwen2Attention(Qwen2Attention):
 
     def __init__(
         self,
-        config: Qwen2Config,
+        hidden_size: int,
+        num_heads: int,
+        num_kv_heads: int,
+        max_position: int = 4096 * 32,
+        rope_theta: float = 10000,
         cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
+        rope_scaling: Optional[tuple] = None,
         prefix: str = "",
+        attn_type: str = AttentionType.DECODER,
+        dual_chunk_attention_config: Optional[dict[str, Any]] = None,
     ) -> None:
-        super().__init__(config=config,
+        super().__init__(hidden_size=hidden_size,
+                         num_heads=num_heads,
+                         num_kv_heads=num_kv_heads,
+                         max_position=max_position,
+                         rope_theta=rope_theta,
                          cache_config=cache_config,
                          quant_config=quant_config,
-                         prefix=prefix)
+                         rope_scaling=rope_scaling,
+                         prefix=prefix,
+                         attn_type=attn_type,
+                         dual_chunk_attention_config=dual_chunk_attention_config)
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        cos: torch.Tensor,
+        sin: torch.Tensor
+    ) -> torch.Tensor:
+        qkv, _ = self.qkv_proj(hidden_states)
+        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+        q, k = self.rotary_emb(positions, q, k, cos=cos, sin=sin, skip_index_select=True)
+        attn_output = self.attn(q, k, v)
+        output, _ = self.o_proj(attn_output)
+        return output
+
+
+class CustomQwen2DecoderLayer(nn.Module):
+
+    def __init__(
+        self,
+        config: Qwen2Config,
+        cache_config: Optional[CacheConfig] = None,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+        self.hidden_size = config.hidden_size
+        # Requires transformers > 4.32.0
+        rope_theta = getattr(config, "rope_theta", 1000000)
+        rope_scaling = getattr(config, "rope_scaling", None)
+        dual_chunk_attention_config = getattr(config,
+                                              "dual_chunk_attention_config",
+                                              None)
+
+        # By default, Qwen2 uses causal attention as it is a decoder-only model.
+        # You can override the HF config with `is_causal=False` to enable
+        # bidirectional attention, which is used in some embedding models
+        # (e.g. Alibaba-NLP/gte-Qwen2-7B-instruct)
+        if getattr(config, "is_causal", True):
+            attn_type = AttentionType.DECODER
+        else:
+            attn_type = AttentionType.ENCODER_ONLY
+
+        self.self_attn = CustomQwen2Attention(
+            hidden_size=self.hidden_size,
+            num_heads=config.num_attention_heads,
+            max_position=config.max_position_embeddings,
+            num_kv_heads=config.num_key_value_heads,
+            rope_theta=rope_theta,
+            cache_config=cache_config,
+            quant_config=quant_config,
+            rope_scaling=rope_scaling,
+            prefix=f"{prefix}.self_attn",
+            attn_type=attn_type,
+            dual_chunk_attention_config=dual_chunk_attention_config,
+        )
+        self.mlp = Qwen2MLP(
+            hidden_size=self.hidden_size,
+            intermediate_size=config.intermediate_size,
+            hidden_act=config.hidden_act,
+            quant_config=quant_config,
+            prefix=f"{prefix}.mlp",
+        )
+        self.input_layernorm = RMSNorm(config.hidden_size,
+                                       eps=config.rms_norm_eps)
+        self.post_attention_layernorm = RMSNorm(config.hidden_size,
+                                                eps=config.rms_norm_eps)
+
         self.tp_rank = get_tensor_model_parallel_rank()
         self.tp_size = get_tensor_model_parallel_world_size()
         self.self_attn.reduce_results = False
@@ -72,6 +157,8 @@ def forward(
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
         residual: Optional[torch.Tensor],
+        cos,
+        sin,
         fc_enabled: bool,
         pad_size: int,
     ) -> tuple[torch.Tensor, torch.Tensor]:
@@ -91,6 +178,8 @@ def forward(
         hidden_states = self.self_attn(
             positions=positions,
             hidden_states=hidden_states,
+            cos=cos,
+            sin=sin
         )
         if fc_enabled:
             hidden_states = maybe_pad_and_reduce_scatter(hidden_states, pad_size)
@@ -129,6 +218,7 @@ def __init__(self,
                          prefix=prefix,
                          decoder_layer_type=decoder_layer_type)
         self.tp_size = get_tensor_model_parallel_world_size()
+        self.cos_sin_cache = self.layers[0].self_attn.rotary_emb.cos_sin_cache
 
     def forward(
         self,
@@ -157,11 +247,19 @@ def forward(
         if fc_enabled:
             num_tokens = hidden_states.size(0)
             pad_size = (self.tp_size - (num_tokens % self.tp_size)) % self.tp_size
+
+        cos_sin = self.cos_sin_cache.index_select(0, positions)
+        last_dim = cos_sin.size()[-1]
+        cos, sin = cos_sin.reshape(-1, 2, last_dim // 2).repeat(1, 1, 2).chunk(2, dim=-2)
+        cos, sin = cos.view(1, -1, 1, last_dim).contiguous(), sin.view(1, -1, 1, last_dim).contiguous()
+
         for layer in self.layers[self.start_layer:self.end_layer]:
             hidden_states, residual = layer(
                 positions,
                 hidden_states,
                 residual,
+                cos,
+                sin,
                 fc_enabled,
                 pad_size,
             )
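
Note on the last hunk: the model forward now gathers the rotary cos/sin values once per step and hands them to every decoder layer, so each layer's rotary embedding no longer indexes the cache itself (hence the skip_index_select=True keyword passed to rotary_emb in CustomQwen2Attention.forward, which looks like a platform-specific extension of the rotary embedding rather than stock vLLM API). As a reading aid only, the sketch below reproduces that precomputation in isolation. It assumes vLLM's usual cos_sin_cache layout of shape [max_position, rotary_dim], with the cos half stored before the sin half along the last dimension; the helper name precompute_cos_sin is invented here and does not appear in the diff.

import torch


def precompute_cos_sin(cos_sin_cache: torch.Tensor,
                       positions: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    # Gather the per-token rows once for the whole forward pass:
    # shape [num_tokens, rotary_dim].
    cos_sin = cos_sin_cache.index_select(0, positions)
    last_dim = cos_sin.size()[-1]
    # Split into the cos and sin halves, then duplicate each half so it
    # covers the full rotary dimension (NEOX-style rotation applies cos/sin
    # to both halves of the head dimension).
    cos, sin = cos_sin.reshape(-1, 2, last_dim // 2).repeat(1, 1, 2).chunk(2, dim=-2)
    # Reshape to [1, num_tokens, 1, rotary_dim] so the tensors broadcast
    # over the head dimension inside every layer.
    cos = cos.view(1, -1, 1, last_dim).contiguous()
    sin = sin.view(1, -1, 1, last_dim).contiguous()
    return cos, sin

Hoisting this gather and reshape out of the per-layer path trades a small amount of extra activation memory for removing one redundant index_select per decoder layer per step, which is presumably the motivation for the change.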