 from vllm_ascend.multistream.base import MSAttentionMetadataSplitConfig
 from vllm_ascend.multistream.context import get_multistream_comm_context
 from vllm_ascend.multistream.ms_split import model_input_split_v1_mla_attn
-from vllm_ascend.torchair.utils import (TorchairCommonAttentionMetadata, npu_stream_switch, npu_wait_tensor)
+from vllm_ascend.torchair.utils import (TorchairCommonAttentionMetadata,
+                                        npu_stream_switch, npu_wait_tensor)
 from vllm_ascend.utils import npu_prefetch
 from vllm_ascend.worker.npu_input_batch import InputBatch
@@ -668,35 +669,35 @@ def _forward_prefill(
                                   dtype=q_nope.dtype,
                                   device=q_nope.device)
         attn_lse = torch.empty(self.num_heads,
-                                num_tokens,
-                                dtype=torch.float32,
-                                device=q_nope.device)
+                               num_tokens,
+                               dtype=torch.float32,
+                               device=q_nope.device)
         self.prefill_mask = torch.triu(
             torch.ones(512, 512, device=q_nope.device, dtype=q_nope.dtype),
             1)  # 512: mask only support 512
         if attn_metadata.num_prefills > 1:
-            self.prefill_mask = self.prefill_mask.unsqueeze(0).repeat(attn_metadata.num_prefills, 1,
-                                                                      1)
-        torch_npu.atb.npu_ring_mla(
-            q_nope=q_nope,
-            q_rope=q_pe,
-            k_nope=k_nope,
-            k_rope=k_pe,
-            value=value,
-            mask=self.prefill_mask,
-            seqlen=torch.tensor(attn_metadata.prefill.query_lens,
-                                dtype=torch.int32),
-            head_num=self.num_heads,
-            kv_head_num=self.num_heads,
-            pre_out=None,
-            prev_lse=None,
-            qk_scale=self.scale,
-            kernel_type="kernel_type_high_precision",
-            mask_type="mask_type_triu",
-            input_layout="type_bsnd",
-            calc_type="calc_type_first_ring",
-            output=attn_output,
-            softmax_lse=attn_lse)
+            self.prefill_mask = self.prefill_mask.unsqueeze(0).repeat(
+                attn_metadata.num_prefills, 1, 1)
+        torch_npu.atb.npu_ring_mla(q_nope=q_nope,
+                                   q_rope=q_pe,
+                                   k_nope=k_nope,
+                                   k_rope=k_pe,
+                                   value=value,
+                                   mask=self.prefill_mask,
+                                   seqlen=torch.tensor(
+                                       attn_metadata.prefill.query_lens,
+                                       dtype=torch.int32),
+                                   head_num=self.num_heads,
+                                   kv_head_num=self.num_heads,
+                                   pre_out=None,
+                                   prev_lse=None,
+                                   qk_scale=self.scale,
+                                   kernel_type="kernel_type_high_precision",
+                                   mask_type="mask_type_triu",
+                                   input_layout="type_bsnd",
+                                   calc_type="calc_type_first_ring",
+                                   output=attn_output,
+                                   softmax_lse=attn_lse)
         attn_output, attn_lse = self._compute_prefill_context( \
             q_nope, q_pe, kv_c_and_k_pe_cache, self.qk_rope_head_dim, attn_metadata, attn_output, attn_lse)
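# Illustrative sketch (not part of this diff): the prefill-mask handling above,
# reproduced with plain torch so it runs off-device. The 512x512 size and the
# per-prefill repeat mirror the hunk; num_prefills is an assumed example value,
# and torch_npu.atb.npu_ring_mla requires Ascend hardware, so the kernel call
# itself is not reproduced here.
import torch

num_prefills = 2                                    # assumed example value
# 1s strictly above the diagonal -> causal (upper-triangular) mask
prefill_mask = torch.triu(torch.ones(512, 512), 1)
if num_prefills > 1:
    # one mask slice per prefill request: [num_prefills, 512, 512]
    prefill_mask = prefill_mask.unsqueeze(0).repeat(num_prefills, 1, 1)
print(prefill_mask.shape)                           # torch.Size([2, 512, 512])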
@@ -716,7 +717,8 @@ def exec_kv_decode(
         N = self.num_kv_heads
         S = 1
         # npu_kv_rmsnorm_rope_cache needs [B, N, S, D]
-        kv_no_split = kv_no_split.view(B, N, S, self.kv_lora_rank + self.qk_rope_head_dim)
+        kv_no_split = kv_no_split.view(
+            B, N, S, self.kv_lora_rank + self.qk_rope_head_dim)
         cache_mode = "PA_NZ" if self.enable_kv_nz else "PA"
         k_pe, k_nope, _, _ = torch_npu.npu_kv_rmsnorm_rope_cache(
             kv_no_split,
@@ -743,7 +745,8 @@ def exec_kv_prefill(
         N = self.num_kv_heads
         S = 1
         # npu_kv_rmsnorm_rope_cache needs [B, N, S, D]
-        kv_no_split = kv_no_split.view(B, N, S, self.kv_lora_rank + self.qk_rope_head_dim)
+        kv_no_split = kv_no_split.view(
+            B, N, S, self.kv_lora_rank + self.qk_rope_head_dim)
         cache_mode = "PA_BLK_NZ" if self.enable_kv_nz else "PA"
         _, _, k_pe, k_nope = torch_npu.npu_kv_rmsnorm_rope_cache(
             kv_no_split,
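# Illustrative sketch (not part of this diff): the "[B, N, S, D]" view that the
# comment above requires before calling npu_kv_rmsnorm_rope_cache, shown with
# plain torch and assumed example sizes (kv_lora_rank=512, qk_rope_head_dim=64,
# one KV head, one position per token).
import torch

kv_lora_rank, qk_rope_head_dim = 512, 64            # assumed example sizes
B, N, S = 8, 1, 1                                   # 8 tokens, 1 KV head, S=1
kv_no_split = torch.randn(B, kv_lora_rank + qk_rope_head_dim)
kv_no_split = kv_no_split.view(B, N, S, kv_lora_rank + qk_rope_head_dim)
print(kv_no_split.shape)                            # torch.Size([8, 1, 1, 576])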
@@ -788,15 +791,15 @@ def _forward_decode(
             actual_seq_lengths = None
             if self.enable_kv_nz:
                 k_nope = k_nope.view(-1, self.num_kv_heads,
-                                      self.kv_lora_rank // 16, block_size, 16)
+                                     self.kv_lora_rank // 16, block_size, 16)
                 k_pe = k_pe.view(-1, self.num_kv_heads,
-                                  self.qk_rope_head_dim // 16, block_size, 16)
+                                 self.qk_rope_head_dim // 16, block_size, 16)
                 input_layout = "BSND"
             else:
                 k_nope = k_nope.view(-1, self.num_kv_heads, block_size,
-                                      self.kv_lora_rank)
+                                     self.kv_lora_rank)
                 k_pe = k_pe.view(-1, self.num_kv_heads, block_size,
-                                  self.qk_rope_head_dim)
+                                 self.qk_rope_head_dim)
                 input_layout = "BNSD"

             if attn_metadata.attn_state == AscendAttentionState.SpecDecoding:
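# Illustrative sketch (not part of this diff): the two decode-cache views chosen
# above. With enable_kv_nz the last dim is split into groups of 16
# ([-1, kv_heads, dim // 16, block_size, 16]); otherwise a plain BNSD view is
# used ([-1, kv_heads, block_size, dim]). Sizes are assumed examples and only
# the shape bookkeeping is shown; the actual NZ (fractal) data layout is
# produced by the cache kernels on the NPU.
import torch

num_blocks, num_kv_heads, block_size, kv_lora_rank = 4, 1, 128, 512  # assumed
k_nope_cache = torch.randn(num_blocks, num_kv_heads, block_size, kv_lora_rank)

nz_view = k_nope_cache.view(-1, num_kv_heads, kv_lora_rank // 16, block_size, 16)
bnsd_view = k_nope_cache.view(-1, num_kv_heads, block_size, kv_lora_rank)
assert nz_view.numel() == bnsd_view.numel() == k_nope_cache.numel()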
@@ -846,13 +849,8 @@ def _forward_decode(
                 current_ms_metadata.before_comm_event.wait()
                 return self._v_up_proj(attn_output)

-    def _mla_preprocess(
-        self,
-        hidden_states,
-        kv_cache,
-        attn_metadata,
-        need_gather_q_kv
-    ):
+    def _mla_preprocess(self, hidden_states, kv_cache, attn_metadata,
+                        need_gather_q_kv):
         # MLA Preprocess:
         # 1. Perform q_a_proj and q_a_layernorm to obtain q_c
         # 2. Perform kv_a_proj_with_mqa to obtain kv_no_split
@@ -877,8 +875,7 @@ def _mla_preprocess(
         kv_no_split = self.kv_a_proj_with_mqa(hidden_states)[0]
         # Process for shared_expert_dp
         if need_gather_q_kv:
-            q_c = get_tp_group().all_gather(
-                q_c, 0)
+            q_c = get_tp_group().all_gather(q_c, 0)
             kv_no_split = get_tp_group().all_gather(kv_no_split, 0)
         decode_preprocess_res = None
         prefill_preprocess_res = None
@@ -893,33 +890,37 @@ def _mla_preprocess(
             decode_slots = attn_metadata.slot_mapping[:num_decode_tokens]
             decode_kv_no_split = kv_no_split[:num_decode_tokens]
             decode_k_pe, decode_k_nope = self.exec_kv_decode(
-                decode_kv_no_split, cos, sin, kv_cache,
-                decode_slots)
+                decode_kv_no_split, cos, sin, kv_cache, decode_slots)
             decode_preprocess_res = DecodeMLAPreprocessResult(
                 decode_ql_nope, decode_q_pe, decode_k_nope, decode_k_pe)
         # Preprocess for prefill tokens
         if has_prefill:
-            prefill_kv_no_split = kv_no_split[num_decode_tokens:num_actual_tokens]
+            prefill_kv_no_split = kv_no_split[
+                num_decode_tokens:num_actual_tokens]
             prefill_q_c = q_c[num_decode_tokens:num_actual_tokens]
             prefill_q = self.q_proj(prefill_q_c)[0]\
                 .view(-1, self.num_heads, self.qk_head_dim)
             prefill_q_pe = prefill_q[..., self.qk_nope_head_dim:]
             prefill_q_nope = prefill_q[..., :self.qk_nope_head_dim]
             cos = attn_metadata.prefill.cos
             sin = attn_metadata.prefill.sin
-            prefill_slots = attn_metadata.slot_mapping[num_decode_tokens:num_actual_tokens]
+            prefill_slots = attn_metadata.slot_mapping[
+                num_decode_tokens:num_actual_tokens]
             prefill_q_pe = self.rope_single(prefill_q_pe, cos, sin)
             prefill_k_pe, prefill_k_c_normed = self.exec_kv_prefill(
-                prefill_kv_no_split, cos, sin, kv_cache,
-                prefill_slots)
-            prefill_k_pe = prefill_k_pe.view(prefill_q_c.shape[0], self.num_kv_heads,
-                                             -1)
-            prefill_k_nope, prefill_value = self.kv_b_proj(prefill_k_c_normed)[0].view(
-                -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim).split(
-                [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
-            prefill_k_pe = prefill_k_pe.expand((*prefill_k_nope.shape[:-1], -1))
+                prefill_kv_no_split, cos, sin, kv_cache, prefill_slots)
+            prefill_k_pe = prefill_k_pe.view(prefill_q_c.shape[0],
+                                             self.num_kv_heads, -1)
+            prefill_k_nope, prefill_value = self.kv_b_proj(
+                prefill_k_c_normed)[0].view(
+                    -1, self.num_heads,
+                    self.qk_nope_head_dim + self.v_head_dim).split(
+                        [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
+            prefill_k_pe = prefill_k_pe.expand(
+                (*prefill_k_nope.shape[:-1], -1))
             prefill_preprocess_res = PrefillMLAPreprocessResult(
-                prefill_q_nope, prefill_q_pe, prefill_k_nope, prefill_k_pe, prefill_value)
+                prefill_q_nope, prefill_q_pe, prefill_k_nope, prefill_k_pe,
+                prefill_value)
         return decode_preprocess_res, prefill_preprocess_res

     def forward(
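# Illustrative sketch (not part of this diff): how the kv_b_proj output in the
# prefill branch above is split into the no-RoPE key and value halves. A plain
# nn.Linear stands in for the vLLM layer (which returns a tuple, hence the [0]
# in the diff); all sizes are assumed example values.
import torch

num_tokens, num_heads = 8, 16                       # assumed example sizes
qk_nope_head_dim, v_head_dim, kv_lora_rank = 128, 128, 512
kv_b_proj = torch.nn.Linear(kv_lora_rank,
                            num_heads * (qk_nope_head_dim + v_head_dim),
                            bias=False)

k_c_normed = torch.randn(num_tokens, kv_lora_rank)
k_nope, value = kv_b_proj(k_c_normed).view(
    -1, num_heads, qk_nope_head_dim + v_head_dim).split(
        [qk_nope_head_dim, v_head_dim], dim=-1)
print(k_nope.shape, value.shape)                    # [8, 16, 128] and [8, 16, 128]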
@@ -972,13 +973,10 @@ def forward(
             # FIX: aicore move should be also placed on the comm stream in dbo,
             # otherwise it may affect the accuracy
             # TODO: use an elegant way to overlap
-            output_prefill = self._forward_prefill(prefill_preprocess_res.q_nope,
-                                                   prefill_preprocess_res.q_pe,
-                                                   prefill_preprocess_res.k_nope,
-                                                   prefill_preprocess_res.k_pe,
-                                                   prefill_preprocess_res.value,
-                                                   kv_cache,
-                                                   attn_metadata)
+            output_prefill = self._forward_prefill(
+                prefill_preprocess_res.q_nope, prefill_preprocess_res.q_pe,
+                prefill_preprocess_res.k_nope, prefill_preprocess_res.k_pe,
+                prefill_preprocess_res.value, kv_cache, attn_metadata)
             current_ms_metadata = get_multistream_comm_context()
             if current_ms_metadata is not None:
                 with torch.npu.stream(current_ms_metadata.comm_stream):