 from vllm_ascend.multistream.base import MSAttentionMetadataSplitConfig
 from vllm_ascend.multistream.context import get_multistream_comm_context
 from vllm_ascend.multistream.ms_split import model_input_split_v1_mla_attn
+from vllm_ascend.ops.attention import vanilla_chunked_prefill_mla
 from vllm_ascend.utils import npu_prefetch
 from vllm_ascend.worker.npu_input_batch import InputBatch
 
@@ -184,10 +185,7 @@ def __init__(self,
                            self.block_size - 1) // self.block_size
         self.chunked_prefill_enabled = scheduler_config.chunked_prefill_enabled
 
-        if vllm_config.speculative_config is not None:
-            self.decode_threshold = vllm_config.speculative_config.num_speculative_tokens + 1
-        else:
-            self.decode_threshold = 1
+        self.decode_threshold = 1
 
         if self.chunked_prefill_enabled:
             self.chunked_prefill_workspace_size = min(
@@ -483,6 +481,9 @@ def __init__(
         self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp
         self.enable_mla_prefetch = ascend_config.enable_mla_prefetch
         self.enable_kv_nz = ascend_config.torchair_graph_config.enable_kv_nz
+        self.chunked_prefill_for_mla = ascend_config.chunked_prefill_for_mla
+
+        self.prefill_mask = None
 
         # Adapt torch air graph mode with spec decoding.
         speculative_config = get_current_vllm_config().speculative_config
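
Assuming the new `chunked_prefill_for_mla` field is populated from vLLM's platform `additional_config` the same way the other `AscendConfig` options above are, enabling it from user code could look like the sketch below; the config key plumbing and the model name are illustrative assumptions, not confirmed by this hunk.

```python
from vllm import LLM

# Hypothetical usage sketch: assumes chunked_prefill_for_mla is surfaced via the
# Ascend platform's additional_config dict like other AscendConfig fields.
llm = LLM(
    model="deepseek-ai/DeepSeek-V2-Lite",  # any MLA model; the choice is illustrative
    additional_config={"chunked_prefill_for_mla": True},
)
```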
@@ -673,14 +674,18 @@ def _forward_prefill(
                 num_heads=self.num_heads,
                 num_kv_heads=self.num_heads,
                 out=attn_output)
-        else:
+        elif self.chunked_prefill_for_mla:
             attn_lse = torch.empty(self.num_heads,
                                    num_tokens,
                                    dtype=torch.float32,
                                    device=q_nope.device)
-            self.prefill_mask = torch.triu(
-                torch.ones(512, 512, device=q_nope.device, dtype=q_nope.dtype),
-                1)  # 512: mask only support 512
+            if self.prefill_mask is None:
+                self.prefill_mask = torch.triu(
+                    torch.ones(512,
+                               512,
+                               device=q_nope.device,
+                               dtype=q_nope.dtype),
+                    1)  # 512: mask only support 512
             if attn_metadata.num_prefills > 1:
                 self.prefill_mask = self.prefill_mask.unsqueeze(0).repeat(
                     attn_metadata.num_prefills, 1, 1)
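
The `if self.prefill_mask is None` guard makes the 512x512 upper-triangular mask a lazily built, reused buffer instead of a tensor reallocated on every prefill call. A minimal standalone sketch of that caching pattern, with illustrative class and method names rather than the module in the diff:

```python
import torch


# Standalone sketch of the lazy mask-caching pattern from the hunk above.
class PrefillMaskCache:

    def __init__(self):
        self.prefill_mask = None  # allocated on first use

    def get(self, device, dtype, num_prefills):
        if self.prefill_mask is None:
            # 512: the fused attention op only supports a 512x512 mask
            self.prefill_mask = torch.triu(
                torch.ones(512, 512, device=device, dtype=dtype), 1)
        mask = self.prefill_mask
        if num_prefills > 1:
            # expand per prefill request without mutating the cached 2-D mask
            mask = mask.unsqueeze(0).repeat(num_prefills, 1, 1)
        return mask
```

Unlike this sketch, the diff writes the repeated 3-D mask back into `self.prefill_mask` when `num_prefills > 1`, so the cached tensor changes shape after a multi-prefill batch.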
@@ -706,9 +711,38 @@ def _forward_prefill(
                 softmax_lse=attn_lse)
             attn_output, attn_lse = self._compute_prefill_context( \
                 q_nope, q_pe, kv_c_and_k_pe_cache, self.qk_rope_head_dim, attn_metadata, attn_output, attn_lse)
+        else:
+            query = torch.cat((q_nope, q_pe), dim=-1)
+            attn_output_torch = torch.empty(num_tokens,
+                                            self.num_heads * self.v_head_dim,
+                                            dtype=query.dtype,
+                                            device=query.device)
+            # the current request is chunked in prefill, so disable flash attention with chunked prefill
+            vanilla_chunked_prefill_mla(
+                output=attn_output_torch,
+                query=query,
+                kv_cache=kv_c_and_k_pe_cache,
+                block_tables=attn_metadata.prefill.block_table,
+                query_lens=attn_metadata.prefill.query_lens,
+                context_lens=attn_metadata.prefill.context_lens,
+                kv_b_proj=self.kv_b_proj,
+                max_query_len=attn_metadata.prefill.max_query_len,
+                max_context_len=attn_metadata.prefill.max_seq_lens,
+                nope_dim=self.qk_nope_head_dim,
+                rope_dim=self.qk_rope_head_dim,
+                v_head_dim=self.v_head_dim,
+                scale=self.scale,
+                alibi_slopes=None,
+                causal=True)
 
         attn_output = attn_output.reshape(
             [num_tokens, self.num_heads * self.v_head_dim])
+        if attn_metadata.attn_state in [
+                AscendAttentionState.ChunkedPrefill,
+                AscendAttentionState.SpecDecoding,
+                AscendAttentionState.PrefillCacheHit
+        ] and not self.chunked_prefill_for_mla:
+            attn_output = attn_output_torch
         return attn_output
 
     def exec_kv_decode(
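
The fused path above merges attention over the cached context with attention over the new tokens using the saved log-sum-exp values (see the `attn_lse` buffer, `softmax_lse=attn_lse`, and `_compute_prefill_context`). A hedged toy of that log-sum-exp merge, with an illustrative function name rather than the vllm-ascend API:

```python
import torch


# Toy LSE merge of two partial attention results (illustrative only).
# out_*: [num_heads, num_tokens, v_dim]; lse_*: [num_heads, num_tokens], holding
# log(sum(exp(scores))) over the keys each partial result attended to.
def merge_attn_chunks(out_a, lse_a, out_b, lse_b):
    lse = torch.logaddexp(lse_a, lse_b)          # combined normalizer
    w_a = torch.exp(lse_a - lse).unsqueeze(-1)   # weight of the context chunk
    w_b = torch.exp(lse_b - lse).unsqueeze(-1)   # weight of the new-token chunk
    return out_a * w_a + out_b * w_b, lse
```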