@@ -170,7 +170,7 @@ def __init__(self,
             if metadata_cls is not None else AscendMLAMetadata  # type: ignore
         self.runner = runner
         scheduler_config = runner.scheduler_config
-        self.chunked_prefill_enabled = scheduler_config.chunked_prefill_enabled
+        self.chunked_prefill_enabled = scheduler_config.chunked_prefill_enabled
         ascend_config = get_ascend_config()
         self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
 
@@ -477,13 +477,7 @@ def __init__(
         self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
         # Adapt torch air graph mode with spec decoding.
         speculative_config = get_current_vllm_config().speculative_config
-        self.fia_sparse_mode = 0
-        self.use_spec_decode = False
-        # We need to set the sparse_mode of fused_infer_attention op to 3
-        # in spec decoding scenario in order to pass in attention mask.
         if speculative_config is not None:
-            self.fia_sparse_mode = 3
-            self.use_spec_decode = True
             self.spec_token_num = speculative_config.num_speculative_tokens
             assert self.spec_token_num > 0
 
@@ -575,7 +569,10 @@ def _forward_prefill(
         num_tokens = query.size(0)
         attn_output = None
         # Here is only 2 possibility of input, ChunkedPrefill or PrefillNoCache
-        if attn_metadata.attn_state == AscendAttentionState.ChunkedPrefill:
+        if attn_metadata.attn_state in [
+                AscendAttentionState.ChunkedPrefill,
+                AscendAttentionState.SpecDecoding
+        ]:
             attn_output = torch.empty(num_tokens,
                                       self.num_heads * self.v_head_dim,
                                       dtype=query.dtype,
@@ -622,7 +619,7 @@ def _forward_prefill(
             attn_output = attn_output.view(-1, self.num_heads, self.v_head_dim)
         else:
             raise RuntimeError(
-                "Unexpected path reached, AscendMLAImpl should only have PrefillNoCache and ChunkedPrefill scenario in forward prefill, please file a bug to vllm-ascend !"
+                "Unexpected path reached, AscendMLAImpl should only have PrefillNoCache, ChunkedPrefill and SpecDecoding scenario in forward prefill, please file a bug to vllm-ascend !"
             )
         attn_output = attn_output.reshape(
             [num_tokens, self.num_heads * self.v_head_dim])
@@ -696,7 +693,7 @@ def _forward_decode(
                                  device=q.device)
         if self.running_in_graph:
             # TorchAir's shape is [bs, num_heads_per_rank, q_seq_len, dim]
-            if self.use_spec_decode:
+            if attn_metadata.attn_state == AscendAttentionState.SpecDecoding:
                 assert num_tokens % self.spec_token_num == 0
                 q_nope = (q_nope.view(
                     num_tokens // (self.spec_token_num + 1),
@@ -710,9 +707,13 @@ def _forward_decode(
                     self.num_heads,
                     -1,
                 ).transpose(1, 2).contiguous())
+                sparse_mode = 3
+                spec_attn_mask = attn_metadata.decode.attn_mask  # type:ignore
             else:
                 q_nope = q_nope.view(num_tokens, self.num_heads, 1, -1)
                 q_pe = q_pe.view(num_tokens, self.num_heads, 1, -1)
+                sparse_mode = 0
+                spec_attn_mask = None
             # shape of knope/k_pe for npu graph mode should be:
             # [num_blocks, num_kv_heads, block_size, self.kv_lora_rank/self.qk_rope_head_dim]
             block_size = kv_c_and_k_pe_cache[0].shape[1]
@@ -730,8 +731,8 @@ def _forward_decode(
                 num_heads=self.num_heads,
                 num_key_value_heads=self.num_kv_heads,
                 input_layout="BNSD",
-                atten_mask=attn_metadata.decode.attn_mask,  # type:ignore
-                sparse_mode=self.fia_sparse_mode,
+                atten_mask=spec_attn_mask,
+                sparse_mode=sparse_mode,
                 scale=self.scale,
                 antiquant_mode=0,
                 antiquant_scale=None,
@@ -773,7 +774,9 @@ def forward(
         if attn_metadata is None:
             # Profiling run.
             return output
-        self.running_in_graph = self.torchair_graph_enabled and attn_metadata.attn_state == AscendAttentionState.DecodeOnly
+        self.running_in_graph = self.torchair_graph_enabled and attn_metadata.attn_state in [
+            AscendAttentionState.DecodeOnly, AscendAttentionState.SpecDecoding
+        ]
         num_actual_toks = attn_metadata.num_actual_tokens
         if k_pe is None and not self.running_in_graph:
             kv_c, k_pe = self.kv_a_proj_with_mqa(
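Taken together, the decode-path hunks drop the `fia_sparse_mode` / `use_spec_decode` flags that were set once in `__init__` and instead derive the attention mask and sparse mode from `attn_metadata.attn_state` on each call. Below is a minimal, self-contained sketch of that selection logic; the `AscendAttentionState` members and the `DecodeMetadata` stand-in are simplified assumptions for illustration, not the actual vllm-ascend definitions, and the fused infer-attention op is only mentioned in a comment, not called.

from dataclasses import dataclass
from enum import Enum
from typing import Optional, Tuple


class AscendAttentionState(Enum):
    # Only the states relevant to this sketch; the real enum defines more members.
    PrefillNoCache = 0
    ChunkedPrefill = 1
    DecodeOnly = 2
    SpecDecoding = 3


@dataclass
class DecodeMetadata:
    attn_mask: Optional[object] = None  # stand-in for the real mask tensor


def select_fia_mask_and_mode(
        attn_state: AscendAttentionState,
        decode: DecodeMetadata) -> Tuple[Optional[object], int]:
    """Pick the (atten_mask, sparse_mode) pair for the fused infer-attention
    call, mirroring the per-call selection introduced by the diff."""
    if attn_state == AscendAttentionState.SpecDecoding:
        # Spec decoding passes the attention mask, so sparse_mode is 3.
        return decode.attn_mask, 3
    # Plain decode: no mask, default sparse_mode.
    return None, 0


# Usage: the decode path would feed the returned pair into the fused
# infer-attention op (atten_mask=mask, sparse_mode=mode) instead of reading
# flags cached on the attention impl.
mask, mode = select_fia_mask_and_mode(AscendAttentionState.SpecDecoding,
                                      DecodeMetadata(attn_mask="mask"))
assert mode == 3

mask, mode = select_fia_mask_and_mode(AscendAttentionState.DecodeOnly,
                                      DecodeMetadata(attn_mask="mask"))
assert mask is None and mode == 0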