@@ -458,6 +458,10 @@ def __init__(
458
458
self .enable_graph_mode = additional_config .get (
459
459
"enable_graph_mode" , False )
460
460
461
+ self .cos = None
462
+ self .sin = None
463
+ self .debug_layer_idx = kwargs .get ('debug_layer_idx' , 0 )
464
+
461
465
def _v_up_proj_and_o_proj (self , x ):
462
466
# Convert from (B, N, L) to (N, B, L)
463
467
x = x .view (- 1 , self .num_heads , self .kv_lora_rank ).transpose (0 , 1 )
@@ -750,19 +754,20 @@ def forward(
750
754
decode_ql_nope , decode_q_pe = \
751
755
self ._q_proj_and_k_up_proj (decode_hs_or_q_c )
752
756
if self .running_in_graph :
753
- seq_len = self .rotary_emb .max_position_embeddings
754
- cos = self .rotary_emb .cos_cached [:seq_len ].to (
755
- dtype = decode_q_pe .dtype )
756
- sin = self .rotary_emb .sin_cached [:seq_len ].to (
757
- dtype = decode_q_pe .dtype )
758
- cos = cos [attn_metadata .decode .input_positions ]
759
- sin = sin [attn_metadata .decode .input_positions ]
760
- cos = cos [:, None , None , :]
761
- sin = sin [:, None , None , :]
762
-
763
- decode_q_pe = self .rope_single (decode_q_pe , cos , sin )
757
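+ # During autoregressive decoding the cos and sin values are identical for every layer, so they are recomputed only when debug_layer_idx == 0 (or when nothing is cached yet) and the cached self.cos / self.sin are reused otherwise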
758
+ if self .debug_layer_idx == 0 or self .cos is None or self .sin is None :
759
+ seq_len = self .rotary_emb .max_position_embeddings
760
+ self .cos = self .rotary_emb .cos_cached [:seq_len ].to (
761
+ dtype = decode_q_pe .dtype )
762
+ self .sin = self .rotary_emb .sin_cached [:seq_len ].to (
763
+ dtype = decode_q_pe .dtype )
764
+ self .cos = self .cos [attn_metadata .decode .input_positions ]
765
+ self .sin = self .sin [attn_metadata .decode .input_positions ]
766
+ self .cos = self .cos [:, None , None , :]
767
+ self .sin = self .sin [:, None , None , :]
768
+ decode_q_pe = self .rope_single (decode_q_pe , self .cos , self .sin )
764
769
decode_k_pe , decode_k_nope = self .exec_kv (
765
- hidden_states_or_kv_c_normed , cos , sin , kv_cache ,
770
+ hidden_states_or_kv_c_normed , self . cos , self . sin , kv_cache ,
766
771
attn_metadata .slot_mapping )
767
772
else :
768
773
decode_q_pe [...], decode_k_pe [...] = self .rotary_emb (
0 commit comments