 from vllm_ascend.multistream.base import MSAttentionMetadataSplitConfig
 from vllm_ascend.multistream.context import get_multistream_comm_context
 from vllm_ascend.multistream.ms_split import model_input_split_v1_mla_attn
+from vllm_ascend.ops.attention import vanilla_chunked_prefill_mla
 from vllm_ascend.utils import npu_prefetch
 from vllm_ascend.worker.npu_input_batch import InputBatch

@@ -654,43 +655,67 @@ def _forward_prefill(
     ) -> torch.Tensor:
         assert attn_metadata.prefill is not None
         assert len(kv_c_and_k_pe_cache) > 1
+        query = torch.cat([q_nope, q_pe], dim=-1)
         num_tokens = q_nope.size(0)
         attn_output = torch.empty(num_tokens,
                                   self.num_heads,
                                   self.v_head_dim,
-                                  dtype=q_nope.dtype,
-                                  device=q_nope.device)
-        if attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
-            query = torch.cat((q_nope, q_pe), dim=-1)
-            key = torch.cat((k_nope, k_pe), dim=-1)
-            torch_npu._npu_flash_attention(
+                                  dtype=query.dtype,
+                                  device=query.device)
+        k_pe = k_pe.expand((*k_nope.shape[:-1], -1))
+        # Only two kinds of input are possible here: chunked prefill or PrefillNoCache
+        ascend_config = get_ascend_config()
+
+        if attn_metadata.attn_state in [
+                AscendAttentionState.ChunkedPrefill,
+                AscendAttentionState.SpecDecoding,
+                AscendAttentionState.PrefillCacheHit
+        ] and not ascend_config.chunked_prefill_for_mla:
+
+            attn_output_torch = torch.empty(num_tokens,
+                                            self.num_heads * self.v_head_dim,
+                                            dtype=query.dtype,
+                                            device=query.device)
+            # The current requests are chunked in prefill; disable flash attention for chunked prefill
+            vanilla_chunked_prefill_mla(
+                output=attn_output_torch,
                 query=query,
-                key=key,
-                value=value,
-                mask=attn_metadata.attn_mask,
-                seq_len=attn_metadata.prefill.context_lens,
-                scale_value=self.scale,
-                num_heads=self.num_heads,
-                num_kv_heads=self.num_heads,
-                out=attn_output)
-        else:
+                kv_cache=kv_c_and_k_pe_cache,
+                block_tables=attn_metadata.prefill.block_table,
+                query_lens=attn_metadata.prefill.query_lens,
+                context_lens=attn_metadata.prefill.context_lens,
+                kv_b_proj=self.kv_b_proj,
+                max_query_len=attn_metadata.prefill.max_query_len,
+                max_context_len=attn_metadata.prefill.max_seq_lens,
+                nope_dim=self.qk_nope_head_dim,
+                rope_dim=self.qk_rope_head_dim,
+                v_head_dim=self.v_head_dim,
+                scale=self.scale,
+                alibi_slopes=None,
+                causal=True)
+        elif attn_metadata.attn_state in [
+                AscendAttentionState.ChunkedPrefill,
+                AscendAttentionState.SpecDecoding,
+                AscendAttentionState.PrefillCacheHit
+        ]:
+            query = torch.cat([q_nope, q_pe], dim=-1)
             attn_lse = torch.empty(self.num_heads,
                                    num_tokens,
                                    dtype=torch.float32,
                                    device=q_nope.device)
-            self.prefill_mask = torch.triu(
-                torch.ones(512, 512, device=q_nope.device, dtype=q_nope.dtype),
+            mask = torch.triu(
+                torch.ones(512, 512, device=query.device, dtype=query.dtype),
                 1)  # 512: mask only support 512
             if attn_metadata.num_prefills > 1:
-                self.prefill_mask = self.prefill_mask.unsqueeze(0).repeat(
-                    attn_metadata.num_prefills, 1, 1)
+                mask = mask.unsqueeze(0).repeat(attn_metadata.num_prefills, 1,
+                                                1)
             torch_npu.atb.npu_ring_mla(
                 q_nope=q_nope,
                 q_rope=q_pe,
                 k_nope=k_nope,
                 k_rope=k_pe,
                 value=value,
-                mask=self.prefill_mask,
+                mask=mask,
                 seqlen=torch.tensor(attn_metadata.prefill.query_lens,
                                     dtype=torch.int32),
                 head_num=self.num_heads,
@@ -705,10 +730,34 @@ def _forward_prefill(
                 output=attn_output,
                 softmax_lse=attn_lse)
             attn_output, attn_lse = self._compute_prefill_context( \
-                q_nope, q_pe, kv_c_and_k_pe_cache, self.qk_rope_head_dim, attn_metadata, attn_output, attn_lse)
+                query, kv_c_and_k_pe_cache, self.qk_rope_head_dim, attn_metadata, attn_output, attn_lse)

+        elif attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
+            key = torch.cat((k_nope, k_pe), dim=-1)
+            torch_npu._npu_flash_attention(
+                query=query,
+                key=key,
+                value=value,
+                mask=attn_metadata.attn_mask,
+                seq_len=attn_metadata.prefill.context_lens,
+                scale_value=self.scale,
+                num_heads=self.num_heads,
+                num_kv_heads=self.num_heads,
+                out=attn_output)
+            attn_output = attn_output.view(-1, self.num_heads, self.v_head_dim)
+        else:
+            raise RuntimeError(
+                "Unexpected path reached, AscendMLAImpl should only have PrefillNoCache, PrefillCacheHit, ChunkedPrefill and SpecDecoding scenario in forward prefill, please file a bug to vllm-ascend !"
+            )
         attn_output = attn_output.reshape(
             [num_tokens, self.num_heads * self.v_head_dim])
+        if attn_metadata.attn_state in [
+                AscendAttentionState.ChunkedPrefill,
+                AscendAttentionState.SpecDecoding,
+                AscendAttentionState.PrefillCacheHit
+        ] and not ascend_config.chunked_prefill_for_mla:
+            attn_output = attn_output_torch
+
         return attn_output

     def exec_kv_decode(
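As a reviewer aid, here is a minimal, self-contained sketch of the dispatch this change introduces in _forward_prefill: chunked-prefill-like states fall back to vanilla_chunked_prefill_mla when chunked_prefill_for_mla is disabled, otherwise they use the fused ring-MLA kernel plus context merging, and PrefillNoCache keeps the plain flash-attention path. The AttnState enum and select_prefill_path helper below are illustrative stand-ins, not part of the vllm-ascend API.

from enum import Enum, auto


class AttnState(Enum):
    # Mirrors the AscendAttentionState values handled above (illustrative copy).
    ChunkedPrefill = auto()
    SpecDecoding = auto()
    PrefillCacheHit = auto()
    PrefillNoCache = auto()


# States that share the chunked-prefill handling in the new branches.
CHUNKED_LIKE = {
    AttnState.ChunkedPrefill,
    AttnState.SpecDecoding,
    AttnState.PrefillCacheHit,
}


def select_prefill_path(state: AttnState, chunked_prefill_for_mla: bool) -> str:
    """Return which kernel the prefill forward pass would use (names only)."""
    if state in CHUNKED_LIKE and not chunked_prefill_for_mla:
        # Pure-torch fallback added by this change.
        return "vanilla_chunked_prefill_mla"
    if state in CHUNKED_LIKE:
        # Fused ring-MLA kernel plus merging of the cached prefill context.
        return "npu_ring_mla + _compute_prefill_context"
    if state is AttnState.PrefillNoCache:
        # Plain flash attention over key = cat([k_nope, k_pe], dim=-1).
        return "_npu_flash_attention"
    raise RuntimeError("unexpected attention state in prefill")


if __name__ == "__main__":
    print(select_prefill_path(AttnState.ChunkedPrefill, chunked_prefill_for_mla=False))
    print(select_prefill_path(AttnState.SpecDecoding, chunked_prefill_for_mla=True))
    print(select_prefill_path(AttnState.PrefillNoCache, chunked_prefill_for_mla=True))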
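The new k_pe = k_pe.expand((*k_nope.shape[:-1], -1)) line broadcasts the single shared rope key across the head dimension without copying data. A small sketch, with toy shapes assumed purely for illustration:

import torch

# Toy shapes only: k_nope is [num_tokens, num_heads, qk_nope_head_dim],
# k_pe is one shared rope key per token, [num_tokens, 1, qk_rope_head_dim].
k_nope = torch.randn(4, 16, 128)
k_pe = torch.randn(4, 1, 64)

# Same expression as in the diff: expand k_pe to one copy per head (a view,
# nothing is duplicated), so the rope part can be concatenated per head.
k_pe = k_pe.expand((*k_nope.shape[:-1], -1))
key = torch.cat((k_nope, k_pe), dim=-1)

print(k_pe.shape)  # torch.Size([4, 16, 64])
print(key.shape)   # torch.Size([4, 16, 192])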