@@ -9,6 +9,8 @@
                                               MLAAttentionImpl)
 from vllm.attention.backends.utils import PAD_SLOT_ID
 from vllm.config import get_current_vllm_config
+from vllm.forward_context import ForwardContext, get_forward_context
+from vllm.utils import direct_register_custom_op
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                LinearBase, RowParallelLinear,
                                                UnquantizedLinearMethod)
@@ -669,130 +671,180 @@ def forward(
         kv_cache: torch.Tensor,
         attn_metadata: M,
         output: Optional[torch.Tensor] = None,
+        trace_flag: bool = True,
     ) -> torch.Tensor:
         assert output is not None, "Output tensor must be provided."
-        if attn_metadata is None:
-            # Profiling run.
-            return output
-        self.running_in_graph = self.enable_graph_mode and attn_metadata.attn_state == AscendAttentionState.DecodeOnly
-        num_actual_toks = attn_metadata.num_actual_tokens
-        if k_pe is None and not self.running_in_graph:
-            kv_c, k_pe = self.kv_a_proj_with_mqa(
-                hidden_states_or_kv_c_normed)[0].split(
-                    [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
-            kv_c_normed = self.kv_a_layernorm(kv_c.contiguous())
+        if trace_flag:
+            torch.ops.vllm.unified_ascend_mla_attention_with_output(
+                query=hidden_states_or_q_c,
+                key=hidden_states_or_kv_c_normed,
+                value=k_pe,
+                output=output,
+                layer_name=layer.layer_name)
         else:
-            kv_c_normed = hidden_states_or_kv_c_normed
-        assert attn_metadata.num_decodes is not None and \
-            attn_metadata.num_prefills is not None and \
-            attn_metadata.num_decode_tokens is not None
-        has_decode = attn_metadata.num_decodes > 0
-        has_prefill = attn_metadata.num_prefills > 0
-        num_decode_tokens = attn_metadata.num_decode_tokens
-        if not self.running_in_graph:
-            # Inputs and outputs may be padded for CUDA graphs
-            output_padded = output
-            output = output[:num_actual_toks, ...]
-            kv_c_normed = kv_c_normed[:num_actual_toks, ...]
-            prefill_k_c_normed = kv_c_normed[num_decode_tokens:]
-        if not self.running_in_graph:
-            hidden_states_or_q_c = hidden_states_or_q_c[:num_actual_toks, ...]
-            decode_hs_or_q_c = hidden_states_or_q_c[:num_decode_tokens]
-            prefill_hs_or_q_c = hidden_states_or_q_c[num_decode_tokens:]
-            k_pe = k_pe[:num_actual_toks, ...]
-            k_pe = k_pe.unsqueeze(1)
-            decode_k_pe = k_pe[:num_decode_tokens]
-            prefill_k_pe = k_pe[num_decode_tokens:]
-        else:
-            decode_hs_or_q_c = hidden_states_or_q_c
-        if has_decode:
-            decode_k_nope = None
-            assert attn_metadata.decode is not None
-            decode_ql_nope, decode_q_pe = \
-                self._q_proj_and_k_up_proj(decode_hs_or_q_c)
-            if self.running_in_graph:
-                seq_len = self.rotary_emb.max_position_embeddings
-                cos = self.rotary_emb.cos_cached[:seq_len].to(
-                    dtype=decode_q_pe.dtype)
-                sin = self.rotary_emb.sin_cached[:seq_len].to(
-                    dtype=decode_q_pe.dtype)
-                cos = cos[attn_metadata.decode.input_positions]
-                sin = sin[attn_metadata.decode.input_positions]
-                cos = cos[:, None, None, :]
-                sin = sin[:, None, None, :]
-                decode_q_pe = self.rope_single(decode_q_pe, cos, sin)
-                decode_k_pe, decode_k_nope = self.exec_kv(
-                    hidden_states_or_kv_c_normed, cos, sin, kv_cache,
-                    attn_metadata.slot_mapping)
+            if attn_metadata is None:
+                # Profiling run.
+                return output
+            self.running_in_graph = self.enable_graph_mode and attn_metadata.attn_state == AscendAttentionState.DecodeOnly
+            num_actual_toks = attn_metadata.num_actual_tokens
+            if k_pe is None and not self.running_in_graph:
+                kv_c, k_pe = self.kv_a_proj_with_mqa(
+                    hidden_states_or_kv_c_normed)[0].split(
+                        [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
+                kv_c_normed = self.kv_a_layernorm(kv_c.contiguous())
+            else:
+                kv_c_normed = hidden_states_or_kv_c_normed
+            assert attn_metadata.num_decodes is not None and \
+                attn_metadata.num_prefills is not None and \
+                attn_metadata.num_decode_tokens is not None
+            has_decode = attn_metadata.num_decodes > 0
+            has_prefill = attn_metadata.num_prefills > 0
+            num_decode_tokens = attn_metadata.num_decode_tokens
+            if not self.running_in_graph:
+                # Inputs and outputs may be padded for CUDA graphs
+                output_padded = output
+                output = output[:num_actual_toks, ...]
+                kv_c_normed = kv_c_normed[:num_actual_toks, ...]
+                prefill_k_c_normed = kv_c_normed[num_decode_tokens:]
+            if not self.running_in_graph:
+                hidden_states_or_q_c = hidden_states_or_q_c[:num_actual_toks, ...]
+                decode_hs_or_q_c = hidden_states_or_q_c[:num_decode_tokens]
+                prefill_hs_or_q_c = hidden_states_or_q_c[num_decode_tokens:]
+                k_pe = k_pe[:num_actual_toks, ...]
+                k_pe = k_pe.unsqueeze(1)
+                decode_k_pe = k_pe[:num_decode_tokens]
+                prefill_k_pe = k_pe[num_decode_tokens:]
             else:
-                decode_q_pe[...], decode_k_pe[...] = self.rotary_emb(
-                    attn_metadata.decode.input_positions,
-                    decode_q_pe.contiguous(),
-                    decode_k_pe,
-                    max_seq_len=attn_metadata.decode.max_seq_lens)
-        if has_prefill:
-            assert attn_metadata.prefill is not None
-            prefill_q = self.q_proj(prefill_hs_or_q_c)[0]\
-                .view(-1, self.num_heads, self.qk_head_dim)
-            prefill_q_pe = prefill_q[..., self.qk_nope_head_dim:]
-            prefill_q_nope = prefill_q[..., :self.qk_nope_head_dim]
+                decode_hs_or_q_c = hidden_states_or_q_c
+            if has_decode:
+                decode_k_nope = None
+                assert attn_metadata.decode is not None
+                decode_ql_nope, decode_q_pe = \
+                    self._q_proj_and_k_up_proj(decode_hs_or_q_c)
+                if self.running_in_graph:
+                    seq_len = self.rotary_emb.max_position_embeddings
+                    cos = self.rotary_emb.cos_cached[:seq_len].to(
+                        dtype=decode_q_pe.dtype)
+                    sin = self.rotary_emb.sin_cached[:seq_len].to(
+                        dtype=decode_q_pe.dtype)
+                    cos = cos[attn_metadata.decode.input_positions]
+                    sin = sin[attn_metadata.decode.input_positions]
+                    cos = cos[:, None, None, :]
+                    sin = sin[:, None, None, :]
+                    decode_q_pe = self.rope_single(decode_q_pe, cos, sin)
+                    decode_k_pe, decode_k_nope = self.exec_kv(
+                        hidden_states_or_kv_c_normed, cos, sin, kv_cache,
+                        attn_metadata.slot_mapping)
+                else:
+                    decode_q_pe[...], decode_k_pe[...] = self.rotary_emb(
+                        attn_metadata.decode.input_positions,
+                        decode_q_pe.contiguous(),
+                        decode_k_pe,
+                        max_seq_len=attn_metadata.decode.max_seq_lens)
+            if has_prefill:
+                assert attn_metadata.prefill is not None
+                prefill_q = self.q_proj(prefill_hs_or_q_c)[0]\
+                    .view(-1, self.num_heads, self.qk_head_dim)
+                prefill_q_pe = prefill_q[..., self.qk_nope_head_dim:]
+                prefill_q_nope = prefill_q[..., :self.qk_nope_head_dim]
+                if self.enable_graph_mode:
+                    num_tokens = prefill_hs_or_q_c.shape[0]
+                    prefill_k_pe = prefill_k_pe.view(num_tokens, self.num_kv_heads,
+                                                     -1)
+                    if self.rotary_emb.__class__.__name__ == 'RotaryEmbedding':
+                        # NOTE: When scaling not specified
+                        ori_q_pe_shape, ori_k_pe_shape = prefill_q_pe.shape, prefill_k_pe.shape
+                        prefill_q_pe = prefill_q_pe.reshape(num_tokens, -1)
+                        prefill_k_pe = prefill_k_pe.reshape(num_tokens, -1)
+                        prefill_q_pe, prefill_k_pe = self.rotary_emb(
+                            attn_metadata.prefill.input_positions, prefill_q_pe,
+                            prefill_k_pe)
+                        prefill_q_pe = prefill_q_pe.view(ori_q_pe_shape)
+                        prefill_k_pe = prefill_k_pe.view(ori_k_pe_shape)
+                    else:
+                        prefill_q_pe, prefill_k_pe = self.rotary_emb(
+                            attn_metadata.prefill.input_positions, prefill_q_pe,
+                            prefill_k_pe)
+                    prefill_q = torch.cat([prefill_q_nope, prefill_q_pe], dim=-1)
+                else:
+                    prefill_q_pe[...], prefill_k_pe[...] = self.rotary_emb(
+                        attn_metadata.prefill.input_positions,
+                        prefill_q_pe.contiguous(),
+                        prefill_k_pe,
+                        max_seq_len=attn_metadata.prefill.max_seq_lens)
             if self.enable_graph_mode:
-                num_tokens = prefill_hs_or_q_c.shape[0]
-                prefill_k_pe = prefill_k_pe.view(num_tokens, self.num_kv_heads,
-                                                 -1)
-                if self.rotary_emb.__class__.__name__ == 'RotaryEmbedding':
-                    # NOTE: When scaling not specified
-                    ori_q_pe_shape, ori_k_pe_shape = prefill_q_pe.shape, prefill_k_pe.shape
-                    prefill_q_pe = prefill_q_pe.reshape(num_tokens, -1)
-                    prefill_k_pe = prefill_k_pe.reshape(num_tokens, -1)
-                    prefill_q_pe, prefill_k_pe = self.rotary_emb(
-                        attn_metadata.prefill.input_positions, prefill_q_pe,
-                        prefill_k_pe)
-                    prefill_q_pe = prefill_q_pe.view(ori_q_pe_shape)
-                    prefill_k_pe = prefill_k_pe.view(ori_k_pe_shape)
+                if len(kv_cache) > 0 and kv_cache[0].numel(
+                ) > 0 and attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
+                    slots = attn_metadata.slot_mapping
+                    # NOTE: Separate the kv cache in advance to avoid OOM or other issues
+                    torch_npu._npu_reshape_and_cache(key=kv_c_normed.view(
+                        num_tokens, self.num_kv_heads, -1),
+                                                     value=prefill_k_pe,
+                                                     key_cache=kv_cache[0],
+                                                     value_cache=kv_cache[1],
+                                                     slot_indices=slots)
+                elif kv_cache.numel() > 0:
+                    key = torch.cat([
+                        kv_c_normed.view([num_actual_toks, self.num_kv_heads, -1]),
+                        k_pe
+                    ],
+                                    dim=2)
+                    torch_npu._npu_reshape_and_cache_siso(
+                        key=key,
+                        key_cache=kv_cache,
+                        slot_indices=attn_metadata.slot_mapping.flatten())
+            if has_prefill:
+                output[num_decode_tokens:] = self._forward_prefill(
+                    prefill_q, prefill_k_c_normed, prefill_k_pe, kv_cache,
+                    attn_metadata)
+            if has_decode:
+                if self.running_in_graph:
+                    return self._forward_decode(decode_ql_nope, decode_q_pe,
+                                                decode_k_nope, decode_k_pe,
+                                                kv_cache, attn_metadata)
                 else:
-                    prefill_q_pe, prefill_k_pe = self.rotary_emb(
-                        attn_metadata.prefill.input_positions, prefill_q_pe,
-                        prefill_k_pe)
-                prefill_q = torch.cat([prefill_q_nope, prefill_q_pe], dim=-1)
-        else:
-            prefill_q_pe[...], prefill_k_pe[...] = self.rotary_emb(
-                attn_metadata.prefill.input_positions,
-                prefill_q_pe.contiguous(),
-                prefill_k_pe,
-                max_seq_len=attn_metadata.prefill.max_seq_lens)
-        if self.enable_graph_mode:
-            if len(kv_cache) > 0 and kv_cache[0].numel(
-            ) > 0 and attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
-                slots = attn_metadata.slot_mapping
-                # NOTE: Separate the kv cache in advance to avoid OOM or other issues
-                torch_npu._npu_reshape_and_cache(key=kv_c_normed.view(
-                    num_tokens, self.num_kv_heads, -1),
-                                                 value=prefill_k_pe,
-                                                 key_cache=kv_cache[0],
-                                                 value_cache=kv_cache[1],
-                                                 slot_indices=slots)
-            elif kv_cache.numel() > 0:
-                key = torch.cat([
-                    kv_c_normed.view([num_actual_toks, self.num_kv_heads, -1]),
-                    k_pe
-                ],
-                                dim=2)
-                torch_npu._npu_reshape_and_cache_siso(
-                    key=key,
-                    key_cache=kv_cache,
-                    slot_indices=attn_metadata.slot_mapping.flatten())
-        if has_prefill:
-            output[num_decode_tokens:] = self._forward_prefill(
-                prefill_q, prefill_k_c_normed, prefill_k_pe, kv_cache,
-                attn_metadata)
-        if has_decode:
-            if self.running_in_graph:
-                return self._forward_decode(decode_ql_nope, decode_q_pe,
-                                            decode_k_nope, decode_k_pe,
-                                            kv_cache, attn_metadata)
-            else:
-                output[:num_decode_tokens] = self._forward_decode(
-                    decode_ql_nope, decode_q_pe, decode_k_nope, decode_k_pe,
-                    kv_cache, attn_metadata)
-        return output_padded
+                    output[:num_decode_tokens] = self._forward_decode(
+                        decode_ql_nope, decode_q_pe, decode_k_nope, decode_k_pe,
+                        kv_cache, attn_metadata)
+            return output_padded
+
+
+def unified_ascend_mla_attention_with_output(
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    output: torch.Tensor,
+    layer_name: str,
+) -> None:
+    forward_context: ForwardContext = get_forward_context()
+    attn_metadata = forward_context.attn_metadata
+    self = forward_context.no_compile_layers[layer_name]
+    kv_cache = self.kv_cache[forward_context.virtual_engine]
+    self.impl.forward(self,
+                      query,
+                      key,
+                      value,
+                      kv_cache,
+                      attn_metadata,
+                      output,
+                      trace_flag=False)
+    return
+
+
+def unified_mla_attention_with_output_fake(
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    output: torch.Tensor,
+    layer_name: str,
+) -> None:
+    return
+
+
+direct_register_custom_op(
+    op_name="unified_ascend_mla_attention_with_output",
+    op_func=unified_ascend_mla_attention_with_output,
+    mutates_args=["output"],
+    fake_impl=unified_mla_attention_with_output_fake,
+    dispatch_key="PrivateUse1",
+)