 import vllm_ascend.envs as envs_ascend
 from vllm_ascend.ascend_config import get_ascend_config
+from vllm_ascend.attention.attention_v1 import AscendAttentionState
 from vllm_ascend.distributed.parallel_state import get_ep_group
 from vllm_ascend.ops.fused_moe import AscendFusedMoE
 from vllm_ascend.quantization.quant_config import AscendLinearMethod
@@ -500,12 +501,13 @@ def __init__(
         self.enable_multistream_mla = \
             ascend_config.torchair_graph_config.enable_multistream_mla

-    def forward(
-            self,
-            positions: torch.Tensor,
-            hidden_states: torch.Tensor,
-            kv_cache: Optional[torch.Tensor] = None,
-            attn_metadata: Optional[AttentionMetadata] = None) -> torch.Tensor:
+    def forward(self,
+                positions: torch.Tensor,
+                hidden_states: torch.Tensor,
+                kv_cache: Optional[torch.Tensor] = None,
+                attn_metadata: Optional[AttentionMetadata] = None,
+                rotary_cos: Optional[torch.Tensor] = None,
+                rotary_sin: Optional[torch.Tensor] = None) -> torch.Tensor:
         if self.q_lora_rank is not None:
             ckq = self.q_a_proj(hidden_states)[0]
             use_multistream_mla = (self.enable_multistream_mla
@@ -526,6 +528,8 @@ def forward(
                 dtype=hidden_states_or_q_c.dtype,
                 device=hidden_states_or_q_c.device)
             forward_kwargs['output'] = output
+            forward_kwargs['rotary_cos'] = rotary_cos
+            forward_kwargs['rotary_sin'] = rotary_sin

         output = self.mla_attn.impl.forward(self.mla_attn,
                                             hidden_states_or_q_c,
@@ -617,6 +621,8 @@ def forward(
         residual: Optional[torch.Tensor],
         kv_cache: Optional[torch.Tensor] = None,
         attn_metadata: Optional[AttentionMetadata] = None,
+        rotary_cos: Optional[torch.Tensor] = None,
+        rotary_sin: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         # Self Attention
         if residual is None:
@@ -636,6 +642,8 @@ def forward(
                 hidden_states=hidden_states,
                 kv_cache=kv_cache,
                 attn_metadata=attn_metadata,
+                rotary_cos=rotary_cos,
+                rotary_sin=rotary_sin,
             )

         if hidden_states.dtype == torch.float16:
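
Note on the hunks above: `rotary_cos` and `rotary_sin` are produced once per step in the model's forward pass and are only passed through by each decoder layer into the attention implementation (via `forward_kwargs`), instead of being recomputed inside every layer. A minimal sketch of that threading pattern, using hypothetical `Toy*` classes rather than the real vllm-ascend modules:

```python
# Sketch only: hypothetical ToyAttention/ToyDecoderLayer/ToyModel classes,
# not the actual vllm-ascend implementations.
from typing import Optional
import torch


class ToyAttention:
    def forward(self, x: torch.Tensor,
                rotary_cos: Optional[torch.Tensor] = None,
                rotary_sin: Optional[torch.Tensor] = None) -> torch.Tensor:
        # If cos/sin were precomputed upstream, use them directly instead of
        # looking them up again inside every layer.
        if rotary_cos is not None and rotary_sin is not None:
            return x * rotary_cos + x * rotary_sin  # stand-in for applying RoPE
        return x


class ToyDecoderLayer:
    def __init__(self) -> None:
        self.self_attn = ToyAttention()

    def forward(self, x: torch.Tensor,
                rotary_cos: Optional[torch.Tensor] = None,
                rotary_sin: Optional[torch.Tensor] = None) -> torch.Tensor:
        # The layer only forwards the tensors; it never recomputes them.
        return self.self_attn.forward(x, rotary_cos=rotary_cos,
                                      rotary_sin=rotary_sin)


class ToyModel:
    def __init__(self, num_layers: int = 2) -> None:
        self.layers = [ToyDecoderLayer() for _ in range(num_layers)]

    def forward(self, x: torch.Tensor, cos: torch.Tensor,
                sin: torch.Tensor) -> torch.Tensor:
        # cos/sin are computed once per forward call and shared by all layers.
        for layer in self.layers:
            x = layer.forward(x, rotary_cos=cos, rotary_sin=sin)
        return x


tokens = torch.ones(3, 8)
out = ToyModel().forward(tokens, cos=torch.ones(3, 8), sin=torch.zeros(3, 8))
```
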
@@ -713,9 +721,47 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
             make_empty_intermediate_tensors_factory(
                 ["hidden_states", "residual"], config.hidden_size))

+        ascend_config = get_ascend_config()
+        self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
+
+        rope_theta = getattr(config, "rope_theta", 10000)
+        rope_scaling = getattr(config, "rope_scaling", None)
+        max_position_embeddings = getattr(config, "max_position_embeddings",
+                                          8192)
+        if rope_scaling:
+            rope_scaling["rope_type"] = 'deepseek_yarn'
+        self.rotary_emb = get_rope(config.qk_rope_head_dim,
+                                   rotary_dim=config.qk_rope_head_dim,
+                                   max_position=max_position_embeddings,
+                                   base=rope_theta,
+                                   rope_scaling=rope_scaling,
+                                   is_neox_style=False)
+
     def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.embed_tokens(input_ids)

+    def prepare_decoder_rotary_cos_sin(
+            self,
+            attn_metadata: Optional[AttentionMetadata] = None
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        if (envs.VLLM_USE_V1 and attn_metadata is not None
+                and attn_metadata.num_decodes is not None
+                and attn_metadata.attn_state is not None):
+            has_decode = attn_metadata.num_decodes > 0
+            running_in_graph = self.torchair_graph_enabled and attn_metadata.attn_state in [
+                AscendAttentionState.DecodeOnly,
+                AscendAttentionState.SpecDecoding
+            ]
+            if has_decode and running_in_graph:
+                cos = self.rotary_emb.cos_cached
+                sin = self.rotary_emb.sin_cached
+                cos = cos[attn_metadata.decode.input_positions]
+                sin = sin[attn_metadata.decode.input_positions]
+                cos = cos[:, None, None, :]
+                sin = sin[:, None, None, :]
+                return cos, sin
+        return None, None
+
     def forward(
         self,
         input_ids: torch.Tensor,
@@ -736,13 +782,18 @@ def forward(
             hidden_states = intermediate_tensors["hidden_states"]
             residual = intermediate_tensors["residual"]

+        # In graph mode with the v1 engine, precomputing cos and sin here
+        # eliminates repeated calculations in each decode layer.
+        rotary_cos, rotary_sin = self.prepare_decoder_rotary_cos_sin(
+            attn_metadata)
+
         for i in range(self.start_layer, self.end_layer):
             layer = self.layers[i]
             hidden_states, residual = layer(
                 positions, hidden_states, residual,
                 kv_caches[i -
                           self.start_layer] if kv_caches is not None else None,
-                attn_metadata)
+                attn_metadata, rotary_cos, rotary_sin)

         if not get_pp_group().is_last_rank:
             return IntermediateTensors({
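
For readers tracing the shapes: assuming `cos_cached`/`sin_cached` are `[max_position, qk_rope_head_dim]` lookup tables (an assumption about the rope implementation, not something stated in this diff), the gather-and-broadcast performed by `prepare_decoder_rotary_cos_sin` behaves like this standalone sketch:

```python
# Standalone sketch of the cos/sin gather and broadcast; cache layout and
# sizes are assumptions, not taken from the vllm-ascend implementation.
import torch

max_position, rope_dim = 4096, 64
cos_cached = torch.randn(max_position, rope_dim)  # assumed [max_position, rope_dim]
sin_cached = torch.randn(max_position, rope_dim)

# One position per decode token in the batch (DecodeOnly / SpecDecoding graphs).
input_positions = torch.tensor([17, 42, 99])

cos = cos_cached[input_positions]  # [num_decode_tokens, rope_dim]
sin = sin_cached[input_positions]
cos = cos[:, None, None, :]        # [num_decode_tokens, 1, 1, rope_dim]
sin = sin[:, None, None, :]

assert cos.shape == (3, 1, 1, rope_dim)
assert sin.shape == (3, 1, 1, rope_dim)
```
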