Skip to content

Commit a531f36

Browse files
committed
[Bugfix] Fix spec decoding in the chunked-prefill scenario
Signed-off-by: xuyexiong <xuyexiong@huawei.com>
1 parent 0c04bf1 commit a531f36

File tree

2 files changed

+12
-12
lines changed

2 files changed

+12
-12
lines changed

vllm_ascend/attention/mla_v1.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -495,11 +495,7 @@ def __init__(
495495
self.ring_mla_mask_size = 512
496496
self.prefill_mask = None
497497

498-
# Adapt torch air graph mode with spec decoding.
499-
speculative_config = vllm_config.speculative_config
500-
if speculative_config is not None:
501-
self.spec_token_num = speculative_config.num_speculative_tokens
502-
assert self.spec_token_num > 0
498+
self.speculative_config = vllm_config.speculative_config
503499

504500
def _v_up_proj(self, x):
505501
# Convert from (B, N, L) to (N, B, L)
@@ -811,7 +807,11 @@ def _forward_decode(
811807
self.qk_rope_head_dim)
812808
input_layout = "BNSD"
813809

814-
if attn_metadata.attn_state == AscendAttentionState.SpecDecoding:
810+
if attn_metadata.attn_state in [
811+
AscendAttentionState.SpecDecoding,
812+
AscendAttentionState.ChunkedPrefill
813+
] and self.speculative_config is not None:
814+
# Use TND layout for pure SpecDecoding and SpecDecoding in ChunkedPrefill
815815
input_layout = "TND"
816816
# [bs * q_seq_len, num_heads_per_rank, dim]
817817
q_nope = q_nope.view(num_tokens, self.num_heads, -1)

vllm_ascend/torchair/torchair_mla.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -676,11 +676,7 @@ def __init__(
676676
self.prefill_mask = None
677677
self.ring_mla_mask_size = 512
678678

679-
# Adapt torch air graph mode with spec decoding.
680-
speculative_config = get_current_vllm_config().speculative_config
681-
if speculative_config is not None:
682-
self.spec_token_num = speculative_config.num_speculative_tokens
683-
assert self.spec_token_num > 0
679+
self.speculative_config = get_current_vllm_config().speculative_config
684680

685681
def _v_up_proj_and_o_proj(self, x, enable_multistream_mla: bool = False):
686682
# Convert from (B, N, L) to (N, B, L)
@@ -1012,7 +1008,11 @@ def _forward_decode(
10121008
self.qk_rope_head_dim)
10131009
input_layout = "BNSD"
10141010

1015-
if attn_metadata.attn_state == AscendAttentionState.SpecDecoding:
1011+
if attn_metadata.attn_state in [
1012+
AscendAttentionState.SpecDecoding,
1013+
AscendAttentionState.ChunkedPrefill
1014+
] and self.speculative_config is not None:
1015+
# Use TND layout for pure SpecDecoding and SpecDecoding in ChunkedPrefill
10161016
input_layout = "TND"
10171017
# [bs * q_seq_len, num_heads_per_rank, dim]
10181018
q_nope = q_nope.view(num_tokens, self.num_heads, -1)

0 commit comments

Comments
 (0)