@@ -408,6 +408,10 @@ def __init__(
408
408
self .enable_graph_mode = additional_config .get (
409
409
"enable_graph_mode" , False )
410
410
411
+ self .cos = None
412
+ self .sin = None
413
+ self .debug_layer_idx = extra_impl_args .get ('debug_layer_idx' , 0 )
414
+
411
415
def _v_up_proj_and_o_proj (self , x ):
412
416
# Convert from (B, N, L) to (N, B, L)
413
417
x = x .view (- 1 , self .num_heads , self .kv_lora_rank ).transpose (0 , 1 )
@@ -700,18 +704,20 @@ def forward(
700
704
decode_ql_nope , decode_q_pe = \
701
705
self ._q_proj_and_k_up_proj (decode_hs_or_q_c )
702
706
if self .running_in_graph :
703
- seq_len = self .rotary_emb .max_position_embeddings
704
- cos = self .rotary_emb .cos_cached [:seq_len ].to (
705
- dtype = decode_q_pe .dtype )
706
- sin = self .rotary_emb .sin_cached [:seq_len ].to (
707
- dtype = decode_q_pe .dtype )
708
- cos = cos [attn_metadata .decode .input_positions ]
709
- sin = sin [attn_metadata .decode .input_positions ]
710
- cos = cos [:, None , None , :]
711
- sin = sin [:, None , None , :]
712
- decode_q_pe = self .rope_single (decode_q_pe , cos , sin )
707
+ # Within a single autoregressive decode step, cos and sin are gathered by the decode input positions, so they are exactly the same for each layer (they still differ from one step to the next)
708
+ if self .debug_layer_idx == 0 or self .cos is None or self .sin is None :
709
+ seq_len = self .rotary_emb .max_position_embeddings
710
+ self .cos = self .rotary_emb .cos_cached [:seq_len ].to (
711
+ dtype = decode_q_pe .dtype )
712
+ self .sin = self .rotary_emb .sin_cached [:seq_len ].to (
713
+ dtype = decode_q_pe .dtype )
714
+ self .cos = self .cos [attn_metadata .decode .input_positions ]
715
+ self .sin = self .sin [attn_metadata .decode .input_positions ]
716
+ self .cos = self .cos [:, None , None , :]
717
+ self .sin = self .sin [:, None , None , :]
718
+ decode_q_pe = self .rope_single (decode_q_pe , self .cos , self .sin )
713
719
decode_k_pe , decode_k_nope = self .exec_kv (
714
- hidden_states_or_kv_c_normed , cos , sin , kv_cache ,
720
+ hidden_states_or_kv_c_normed , self . cos , self . sin , kv_cache ,
715
721
attn_metadata .slot_mapping )
716
722
else :
717
723
decode_q_pe [...], decode_k_pe [...] = self .rotary_emb (
0 commit comments