@@ -1000,6 +1000,10 @@ def __init__(
1000
1000
self .w_kc = None
1001
1001
self .w_vc = None
1002
1002
1003
+ self .cos = None
1004
+ self .sin = None
1005
+ self .debug_layer_idx = extra_impl_args .get ('debug_layer_idx' , 0 )
1006
+
1003
1007
self .enable_graph_mode = False
1004
1008
additional_config = get_current_vllm_config ().additional_config
1005
1009
if additional_config :
@@ -1128,17 +1132,18 @@ def forward(
1128
1132
q_nope , q_pe = q .split ([self .qk_nope_head_dim , self .qk_rope_head_dim ],
1129
1133
dim = - 1 )
1130
1134
if k_pe is None and attn_metadata .decode_metadata :
1131
- seq_len = self .rotary_emb .max_position_embeddings
1132
-
1133
- cos = self .rotary_emb .cos_cached [:seq_len ].to (dtype = q_pe .dtype )
1134
- sin = self .rotary_emb .sin_cached [:seq_len ].to (dtype = q_pe .dtype )
1135
- cos = cos [attn_metadata .input_positions ]
1136
- sin = sin [attn_metadata .input_positions ]
1137
- cos = cos [:, None , None , :]
1138
- sin = sin [:, None , None , :]
1139
-
1140
- q_pe = self .rope_single (q_pe , cos , sin )
1141
- k_pe , k_nope = self .exec_kv (hidden_states_or_kv_c_normed , cos , sin ,
1135
+ if self .debug_layer_idx == 0 or self .cos is None or self .sin is None :
1136
+ seq_len = self .rotary_emb .max_position_embeddings
1137
+
1138
+ self .cos = self .rotary_emb .cos_cached [:seq_len ].to (dtype = q_pe .dtype )
1139
+ self .sin = self .rotary_emb .sin_cached [:seq_len ].to (dtype = q_pe .dtype )
1140
+ self .cos = self .cos [attn_metadata .input_positions ]
1141
+ self .sin = self .sin [attn_metadata .input_positions ]
1142
+ self .cos = self .cos [:, None , None , :]
1143
+ self .sin = self .sin [:, None , None , :]
1144
+
1145
+ q_pe = self .rope_single (q_pe , self .cos , self .sin )
1146
+ k_pe , k_nope = self .exec_kv (hidden_states_or_kv_c_normed , self .cos , self .sin ,
1142
1147
kv_cache , attn_metadata .slot_mapping )
1143
1148
else :
1144
1149
if k_pe is None :
0 commit comments