Commit 1c9fc7a

feat: use attention state to mark spec decoding
Signed-off-by: whx-sjtu <2952154980@qq.com>
1 parent d91a306 · commit 1c9fc7a

File tree

3 files changed (+25, -16 lines)

vllm_ascend/attention/attention_v1.py

Lines changed: 1 addition & 0 deletions
@@ -100,6 +100,7 @@ class AscendAttentionState(Enum):
     PrefillCacheHit = 1
     DecodeOnly = 2
     ChunkedPrefill = 3
+    SpecDecoding = 4


 @dataclass
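
For reference, a minimal runnable sketch of the extended state enum and the kind of check the other files in this commit now perform against it. The PrefillNoCache = 0 member is assumed from the surrounding numbering (it is not shown in this hunk), and is_graph_decode is a hypothetical helper, not part of the diff:

from enum import Enum


class AscendAttentionState(Enum):
    PrefillNoCache = 0  # assumed from the numbering; not shown in this hunk
    PrefillCacheHit = 1
    DecodeOnly = 2
    ChunkedPrefill = 3
    SpecDecoding = 4  # new in this commit: marks speculative-decoding batches


def is_graph_decode(torchair_graph_enabled: bool,
                    attn_state: AscendAttentionState) -> bool:
    # Hypothetical helper mirroring the forward() change in mla_v1.py:
    # graph mode now covers plain decode and speculative decoding.
    return torchair_graph_enabled and attn_state in (
        AscendAttentionState.DecodeOnly,
        AscendAttentionState.SpecDecoding,
    )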

vllm_ascend/attention/mla_v1.py

Lines changed: 16 additions & 13 deletions
@@ -170,7 +170,7 @@ def __init__(self,
             if metadata_cls is not None else AscendMLAMetadata  # type: ignore
         self.runner = runner
         scheduler_config = runner.scheduler_config
-        self.chunked_prefill_enabled = scheduler_config.chunked_prefill_enabled
+        self.chunked_prefill_enabled = scheduler_config.chunked_prefill_enabled
         ascend_config = get_ascend_config()
         self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled

@@ -477,13 +477,7 @@ def __init__(
         self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
         # Adapt torch air graph mode with spec decoding.
         speculative_config = get_current_vllm_config().speculative_config
-        self.fia_sparse_mode = 0
-        self.use_spec_decode = False
-        # We need to set the sparse_mode of fused_infer_attention op to 3
-        # in spec decoding scenario in order to pass in attention mask.
         if speculative_config is not None:
-            self.fia_sparse_mode = 3
-            self.use_spec_decode = True
             self.spec_token_num = speculative_config.num_speculative_tokens
             assert self.spec_token_num > 0

@@ -575,7 +569,10 @@ def _forward_prefill(
         num_tokens = query.size(0)
         attn_output = None
         # Here is only 2 possibility of input, ChunkedPrefill or PrefillNoCache
-        if attn_metadata.attn_state == AscendAttentionState.ChunkedPrefill:
+        if attn_metadata.attn_state in [
+                AscendAttentionState.ChunkedPrefill,
+                AscendAttentionState.SpecDecoding
+        ]:
             attn_output = torch.empty(num_tokens,
                                       self.num_heads * self.v_head_dim,
                                       dtype=query.dtype,
@@ -622,7 +619,7 @@ def _forward_prefill(
             attn_output = attn_output.view(-1, self.num_heads, self.v_head_dim)
         else:
             raise RuntimeError(
-                "Unexpected path reached, AscendMLAImpl should only have PrefillNoCache and ChunkedPrefill scenario in forward prefill, please file a bug to vllm-ascend !"
+                "Unexpected path reached, AscendMLAImpl should only have PrefillNoCache, ChunkedPrefill and SpecDecoding scenario in forward prefill, please file a bug to vllm-ascend !"
             )
         attn_output = attn_output.reshape(
             [num_tokens, self.num_heads * self.v_head_dim])
@@ -696,7 +693,7 @@ def _forward_decode(
                              device=q.device)
         if self.running_in_graph:
             # TorchAir's shape is [bs, num_heads_per_rank, q_seq_len, dim]
-            if self.use_spec_decode:
+            if attn_metadata.attn_state == AscendAttentionState.SpecDecoding:
                 assert num_tokens % self.spec_token_num == 0
                 q_nope = (q_nope.view(
                     num_tokens // (self.spec_token_num + 1),
@@ -710,9 +707,13 @@ def _forward_decode(
                     self.num_heads,
                     -1,
                 ).transpose(1, 2).contiguous())
+                sparse_mode = 3
+                spec_attn_mask = attn_metadata.decode.attn_mask  # type:ignore
             else:
                 q_nope = q_nope.view(num_tokens, self.num_heads, 1, -1)
                 q_pe = q_pe.view(num_tokens, self.num_heads, 1, -1)
+                sparse_mode = 0
+                spec_attn_mask = None
             # shape of knope/k_pe for npu graph mode should be:
             # [num_blocks, num_kv_heads, block_size, self.kv_lora_rank/self.qk_rope_head_dim]
             block_size = kv_c_and_k_pe_cache[0].shape[1]
@@ -730,8 +731,8 @@ def _forward_decode(
                 num_heads=self.num_heads,
                 num_key_value_heads=self.num_kv_heads,
                 input_layout="BNSD",
-                atten_mask=attn_metadata.decode.attn_mask,  # type:ignore
-                sparse_mode=self.fia_sparse_mode,
+                atten_mask=spec_attn_mask,
+                sparse_mode=sparse_mode,
                 scale=self.scale,
                 antiquant_mode=0,
                 antiquant_scale=None,
@@ -773,7 +774,9 @@ def forward(
         if attn_metadata is None:
             # Profiling run.
             return output
-        self.running_in_graph = self.torchair_graph_enabled and attn_metadata.attn_state == AscendAttentionState.DecodeOnly
+        self.running_in_graph = self.torchair_graph_enabled and attn_metadata.attn_state in [
+            AscendAttentionState.DecodeOnly, AscendAttentionState.SpecDecoding
+        ]
         num_actual_toks = attn_metadata.num_actual_tokens
         if k_pe is None and not self.running_in_graph:
             kv_c, k_pe = self.kv_a_proj_with_mqa(
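
Taken together, the mla_v1.py changes drop the fia_sparse_mode and use_spec_decode attributes and instead derive the mask and sparse mode per call from the attention state. A self-contained sketch of that selection follows; select_fia_mask_and_mode is a hypothetical name, and the real logic lives inline in _forward_decode:

from typing import Optional, Tuple

import torch

from vllm_ascend.attention.attention_v1 import AscendAttentionState


def select_fia_mask_and_mode(
        attn_state: AscendAttentionState,
        decode_attn_mask: Optional[torch.Tensor]
) -> Tuple[Optional[torch.Tensor], int]:
    # In the SpecDecoding state the fused_infer_attention call needs
    # sparse_mode=3 together with an explicit attention mask; the plain
    # DecodeOnly path keeps sparse_mode=0 and passes no mask.
    if attn_state == AscendAttentionState.SpecDecoding:
        return decode_attn_mask, 3
    return None, 0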

vllm_ascend/worker/model_runner_v1.py

Lines changed: 8 additions & 3 deletions
@@ -851,10 +851,13 @@ def _process_reqs(
         if np.array_equal(self.seq_lens_np[:num_reqs], num_scheduled_tokens):
             attn_state = AscendAttentionState.PrefillNoCache
         # We assume it is the decode stage, where prefill occurs but only one token is not hit in cache.
-        elif (np.all(num_valid_tokens == 1) and self.enable_torchair_graph_mode) or (np.all(num_scheduled_tokens == 1)):
+        elif np.all(num_scheduled_tokens == 1):
             attn_state = AscendAttentionState.DecodeOnly
+        # Speculative decoding.
+        elif np.all(num_valid_tokens == 1):
+            attn_state = AscendAttentionState.SpecDecoding
         # splitfuse
-        elif not ascend_config.ascend_scheduler_config.enabled or self.chunked_prefill_enabled or use_spec_decode:
+        elif not ascend_config.ascend_scheduler_config.enabled or self.chunked_prefill_enabled:
             attn_state = AscendAttentionState.ChunkedPrefill
         else:
             attn_state = AscendAttentionState.PrefillCacheHit
@@ -883,7 +886,9 @@ def _process_reqs(
         seq_lens = self.seq_lens[:num_reqs]
         common_attn_metadata = CommonAttentionMetadata(
             query_start_loc=query_start_loc, seq_lens=seq_lens)
-        with_prefill = attn_state != AscendAttentionState.DecodeOnly
+        with_prefill = attn_state not in [
+            AscendAttentionState.DecodeOnly, AscendAttentionState.SpecDecoding
+        ]

         if self.dp_size > 1:
             max_num_tokens, with_prefill = self._get_forward_metadata_across_dp(
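
A standalone sketch of the branch order _process_reqs follows after this commit when classifying a batch; classify_attn_state and its array parameters are hypothetical stand-ins for state the runner actually reads from its own attributes:

import numpy as np

from vllm_ascend.attention.attention_v1 import AscendAttentionState


def classify_attn_state(seq_lens: np.ndarray,
                        num_scheduled_tokens: np.ndarray,
                        num_valid_tokens: np.ndarray,
                        ascend_scheduler_enabled: bool,
                        chunked_prefill_enabled: bool) -> AscendAttentionState:
    # Full prefill with no cache hit: scheduled tokens cover the whole sequence.
    if np.array_equal(seq_lens, num_scheduled_tokens):
        return AscendAttentionState.PrefillNoCache
    # Decode stage: every request schedules exactly one token.
    if np.all(num_scheduled_tokens == 1):
        return AscendAttentionState.DecodeOnly
    # Speculative decoding branch added by this commit.
    if np.all(num_valid_tokens == 1):
        return AscendAttentionState.SpecDecoding
    # splitfuse / chunked prefill.
    if not ascend_scheduler_enabled or chunked_prefill_enabled:
        return AscendAttentionState.ChunkedPrefill
    return AscendAttentionState.PrefillCacheHit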
