from vllm_ascend.quantization.quant_config import AscendLinearMethod
from vllm_ascend.quantization.w8a8_dynamic import AscendW8A8DynamicLinearMethod
from vllm_ascend.utils import dispose_tensor
+ from vllm_ascend.attention.attention_v1 import AscendAttentionState

VLLM_ENABLE_MC2: bool = envs_ascend.VLLM_ENABLE_MC2

@@ -502,7 +503,9 @@ def forward(
            positions: torch.Tensor,
            hidden_states: torch.Tensor,
            kv_cache: Optional[torch.Tensor] = None,
-             attn_metadata: Optional[AttentionMetadata] = None) -> torch.Tensor:
+             attn_metadata: Optional[AttentionMetadata] = None,
+             rotary_cos: Optional[torch.Tensor] = None,
+             rotary_sin: Optional[torch.Tensor] = None) -> torch.Tensor:
        if self.q_lora_rank is not None:
            ckq = self.q_a_proj(hidden_states)[0]
            hidden_states_or_q_c = self.q_a_layernorm(ckq)
@@ -516,6 +519,8 @@ def forward(
                                 dtype=hidden_states_or_q_c.dtype,
                                 device=hidden_states_or_q_c.device)
            forward_kwargs['output'] = output
+             forward_kwargs['rotary_cos'] = rotary_cos
+             forward_kwargs['rotary_sin'] = rotary_sin

            output = self.mla_attn.impl.forward(self.mla_attn,
                                                 hidden_states_or_q_c,
@@ -607,6 +612,8 @@ def forward(
        residual: Optional[torch.Tensor],
        kv_cache: Optional[torch.Tensor] = None,
        attn_metadata: Optional[AttentionMetadata] = None,
+         rotary_cos: Optional[torch.Tensor] = None,
+         rotary_sin: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        # Self Attention
        if residual is None:
@@ -626,6 +633,8 @@ def forward(
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
+             rotary_cos=rotary_cos,
+             rotary_sin=rotary_sin,
        )

        if hidden_states.dtype == torch.float16:
@@ -703,9 +712,43 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
            make_empty_intermediate_tensors_factory(
                ["hidden_states", "residual"], config.hidden_size))

+         ascend_config = get_ascend_config()
+         self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
+
+         rope_theta = getattr(config, "rope_theta", 10000)
+         rope_scaling = getattr(config, "rope_scaling", None)
+         max_position_embeddings = getattr(config, "max_position_embeddings",
+                                           8192)
+         if rope_scaling:
+             rope_scaling["rope_type"] = 'deepseek_yarn'
+         self.rotary_emb = get_rope(config.qk_rope_head_dim,
+                                    rotary_dim=config.qk_rope_head_dim,
+                                    max_position=max_position_embeddings,
+                                    base=rope_theta,
+                                    rope_scaling=rope_scaling,
+                                    is_neox_style=False)
+
    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.embed_tokens(input_ids)

+     def prepare_decoder_rotary_cos_sin(
+             self, attn_metadata: Optional[AttentionMetadata] = None
+     ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
+         if (attn_metadata is not None and attn_metadata.num_decodes is not None and
+                 attn_metadata.attn_state):
+             has_decode = attn_metadata.num_decodes > 0
+             running_in_graph = self.torchair_graph_enabled and attn_metadata.attn_state in [
+                 AscendAttentionState.DecodeOnly, AscendAttentionState.SpecDecoding]
+             if has_decode and running_in_graph:
+                 cos = self.rotary_emb.cos_cached
+                 sin = self.rotary_emb.sin_cached
+                 cos = cos[attn_metadata.decode.input_positions]
+                 sin = sin[attn_metadata.decode.input_positions]
+                 cos = cos[:, None, None, :]
+                 sin = sin[:, None, None, :]
+                 return cos, sin
+         return None, None
+
    def forward(
        self,
        input_ids: torch.Tensor,
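A quick note on what the helper added in the hunk above is doing: it only gathers rows of the cached rotary tables at the decode positions and adds two broadcast dimensions, so every decoder layer can rotate its rope portion without touching the tables again. The sketch below is a minimal, self-contained illustration of that pattern, not the vllm-ascend implementation; the table layout (each frequency duplicated across its even/odd pair, as interleaved is_neox_style=False rotation expects), the tensor shapes, and the q_pe name are all assumptions, and the deepseek_yarn scaling applied by the real get_rope is ignored.

import torch

# Assumed sizes for illustration only.
max_position, rope_dim = 1024, 64
num_decode_tokens, num_heads = 4, 16

# Stand-ins for rotary_emb.cos_cached / rotary_emb.sin_cached: one row per
# position, each frequency duplicated across its even/odd pair.
inv_freq = 1.0 / (10000 ** (torch.arange(0, rope_dim, 2).float() / rope_dim))
angles = torch.outer(torch.arange(max_position).float(), inv_freq)
cos_cached = angles.cos().repeat_interleave(2, dim=-1)  # [max_position, rope_dim]
sin_cached = angles.sin().repeat_interleave(2, dim=-1)  # [max_position, rope_dim]

# What prepare_decoder_rotary_cos_sin does, schematically: one gather per step,
# plus broadcast dims, instead of one gather per decoder layer.
input_positions = torch.randint(0, max_position, (num_decode_tokens,))
cos = cos_cached[input_positions][:, None, None, :]  # [tokens, 1, 1, rope_dim]
sin = sin_cached[input_positions][:, None, None, :]  # [tokens, 1, 1, rope_dim]

def rotate_every_two(x: torch.Tensor) -> torch.Tensor:
    """(x0, x1, x2, x3, ...) -> (-x1, x0, -x3, x2, ...): interleaved rotation helper."""
    return torch.stack((-x[..., 1::2], x[..., ::2]), dim=-1).flatten(-2)

# A layer can then rotate the rope portion of its queries/keys by broadcasting
# against the shared cos/sin tensors (assumed [tokens, heads, 1, rope_dim] here).
q_pe = torch.randn(num_decode_tokens, num_heads, 1, rope_dim)
q_rotated = q_pe * cos + rotate_every_two(q_pe) * sin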
@@ -726,13 +769,17 @@ def forward(
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]

+         # With torchair graph mode and the v1 engine, precomputing cos and sin
+         # once here avoids recomputing them in every decode layer.
+         rotary_cos, rotary_sin = self.prepare_decoder_rotary_cos_sin(attn_metadata)
+
        for i in range(self.start_layer, self.end_layer):
            layer = self.layers[i]
            hidden_states, residual = layer(
                positions, hidden_states, residual,
                kv_caches[i -
                          self.start_layer] if kv_caches is not None else None,
-                 attn_metadata)
+                 attn_metadata, rotary_cos, rotary_sin)

        if not get_pp_group().is_last_rank:
            return IntermediateTensors({
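The comment added in this last hunk captures the intent of the whole change: with the torchair graph enabled on the v1 engine, the per-step cos/sin lookup moves out of the decoder layers and happens once in the model forward. The toy sketch below shows just that plumbing in isolation; the classes are hypothetical stand-ins, not the real vllm-ascend modules, and the attention body is stubbed out.

from typing import Optional, Tuple
import torch

class ToyAttention:
    def forward(self, hidden_states: torch.Tensor,
                rotary_cos: Optional[torch.Tensor] = None,
                rotary_sin: Optional[torch.Tensor] = None) -> torch.Tensor:
        # With cos/sin supplied (graph decode path) the impl can skip its own
        # table lookup; with None it would fall back to computing them itself,
        # mirroring the Optional defaults added in the diff.
        return hidden_states

class ToyDecoderLayer:
    def __init__(self) -> None:
        self.self_attn = ToyAttention()

    def forward(self, hidden_states: torch.Tensor,
                rotary_cos: Optional[torch.Tensor] = None,
                rotary_sin: Optional[torch.Tensor] = None) -> torch.Tensor:
        # The layer only threads the precomputed tensors through to attention.
        return self.self_attn.forward(hidden_states,
                                      rotary_cos=rotary_cos,
                                      rotary_sin=rotary_sin)

class ToyModel:
    def __init__(self, num_layers: int = 4) -> None:
        self.layers = [ToyDecoderLayer() for _ in range(num_layers)]

    def prepare_decoder_rotary_cos_sin(
            self) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
        # Placeholder for the single gather shown in the diff.
        return torch.ones(1, 1, 1, 64), torch.zeros(1, 1, 1, 64)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # One gather per forward pass ...
        rotary_cos, rotary_sin = self.prepare_decoder_rotary_cos_sin()
        # ... reused by every layer instead of being recomputed num_layers times.
        for layer in self.layers:
            hidden_states = layer.forward(hidden_states,
                                          rotary_cos=rotary_cos,
                                          rotary_sin=rotary_sin)
        return hidden_states

ToyModel().forward(torch.randn(2, 8))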