Commit 75060b6

add vanilla_chunk_prefill_mla back
Signed-off-by: whx-sjtu <2952154980@qq.com>
1 parent 8d9677e commit 75060b6

File tree

2 files changed: +44, -8 lines

vllm_ascend/ascend_config.py

Lines changed: 2 additions & 0 deletions
@@ -45,6 +45,8 @@ def __init__(self, vllm_config):
             ascend_scheduler_config)
 
         self.expert_map_path = additional_config.get("expert_map_path", None)
+        self.chunked_prefill_for_mla = additional_config.get(
+            "chunked_prefill_for_mla", False)
         self.enable_shared_expert_dp = additional_config.get(
             "enable_shared_expert_dp", False
         ) and not self.torchair_graph_config.enabled and vllm_config.parallel_config.enable_expert_parallel
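
For reference, this option is read from vLLM's additional_config, so it can be switched on when the engine is created. A minimal sketch, assuming the standard vLLM entry point; the model name below is only an illustrative MLA model, not something this commit prescribes:

from vllm import LLM

# Hedged sketch: pass the new flag through additional_config at engine
# construction. Per the diff above, it defaults to False when absent.
llm = LLM(
    model="deepseek-ai/DeepSeek-V2-Lite",  # illustrative MLA model name
    additional_config={
        "chunked_prefill_for_mla": True,
    },
)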

vllm_ascend/attention/mla_v1.py

Lines changed: 42 additions & 8 deletions
@@ -20,6 +20,7 @@
 from vllm_ascend.multistream.base import MSAttentionMetadataSplitConfig
 from vllm_ascend.multistream.context import get_multistream_comm_context
 from vllm_ascend.multistream.ms_split import model_input_split_v1_mla_attn
+from vllm_ascend.ops.attention import vanilla_chunked_prefill_mla
 from vllm_ascend.utils import npu_prefetch
 from vllm_ascend.worker.npu_input_batch import InputBatch

@@ -184,10 +185,7 @@ def __init__(self,
                            self.block_size - 1) // self.block_size
         self.chunked_prefill_enabled = scheduler_config.chunked_prefill_enabled
 
-        if vllm_config.speculative_config is not None:
-            self.decode_threshold = vllm_config.speculative_config.num_speculative_tokens + 1
-        else:
-            self.decode_threshold = 1
+        self.decode_threshold = 1
 
         if self.chunked_prefill_enabled:
             self.chunked_prefill_workspace_size = min(
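
The removed branch let a speculative-decoding request with up to num_speculative_tokens + 1 scheduled tokens still count as a decode; with the hardcoded threshold, only single-token requests do. A stand-alone sketch of the kind of split such a threshold typically drives; the helper name and shapes here are illustrative, not the builder's actual code:

import torch

def count_decodes(query_lens: torch.Tensor, decode_threshold: int = 1) -> int:
    # Illustrative only: treat a request as a decode when its scheduled
    # query length is at or below the threshold.
    return int((query_lens <= decode_threshold).sum().item())

# With decode_threshold = 1, only the two single-token requests are decodes.
assert count_decodes(torch.tensor([1, 1, 5, 7]), decode_threshold=1) == 2
# The old branch (threshold = num_speculative_tokens + 1, e.g. 3) would also
# have classified a 3-token speculative request as a decode.
assert count_decodes(torch.tensor([1, 1, 3, 7]), decode_threshold=3) == 3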
@@ -483,6 +481,9 @@ def __init__(
         self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp
         self.enable_mla_prefetch = ascend_config.enable_mla_prefetch
         self.enable_kv_nz = ascend_config.torchair_graph_config.enable_kv_nz
+        self.chunked_prefill_for_mla = ascend_config.chunked_prefill_for_mla
+
+        self.prefill_mask = None
 
         # Adapt torch air graph mode with spec decoding.
         speculative_config = get_current_vllm_config().speculative_config
@@ -673,14 +674,18 @@ def _forward_prefill(
                 num_heads=self.num_heads,
                 num_kv_heads=self.num_heads,
                 out=attn_output)
-        else:
+        elif self.chunked_prefill_for_mla:
             attn_lse = torch.empty(self.num_heads,
                                    num_tokens,
                                    dtype=torch.float32,
                                    device=q_nope.device)
-            self.prefill_mask = torch.triu(
-                torch.ones(512, 512, device=q_nope.device, dtype=q_nope.dtype),
-                1)  # 512: mask only support 512
+            if self.prefill_mask is None:
+                self.prefill_mask = torch.triu(
+                    torch.ones(512,
+                               512,
+                               device=q_nope.device,
+                               dtype=q_nope.dtype),
+                    1)  # 512: mask only support 512
             if attn_metadata.num_prefills > 1:
                 self.prefill_mask = self.prefill_mask.unsqueeze(0).repeat(
                     attn_metadata.num_prefills, 1, 1)
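
The hunk above also stops rebuilding the 512x512 causal mask on every prefill call and instead builds it lazily once, caching it on the instance. A stand-alone sketch of that lazy-caching pattern, with an illustrative class name and CPU tensors:

import torch

class MaskCache:
    """Illustrative stand-alone version of the cached-mask pattern."""

    def __init__(self) -> None:
        self.prefill_mask = None  # built lazily on first use

    def get(self, device, dtype):
        if self.prefill_mask is None:
            # Strictly-upper-triangular ones mark the future positions to
            # mask out; 512 matches the kernel's maximum mask size.
            self.prefill_mask = torch.triu(
                torch.ones(512, 512, device=device, dtype=dtype), 1)
        return self.prefill_mask

cache = MaskCache()
m1 = cache.get("cpu", torch.float16)
m2 = cache.get("cpu", torch.float16)
assert m1 is m2  # the second call reuses the cached tensor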
@@ -706,9 +711,38 @@ def _forward_prefill(
                 softmax_lse=attn_lse)
             attn_output, attn_lse = self._compute_prefill_context( \
                 q_nope, q_pe, kv_c_and_k_pe_cache, self.qk_rope_head_dim, attn_metadata, attn_output, attn_lse)
+        else:
+            query = torch.cat((q_nope, q_pe), dim=-1)
+            attn_output_torch = torch.empty(num_tokens,
+                                            self.num_heads * self.v_head_dim,
+                                            dtype=query.dtype,
+                                            device=query.device)
+            # current requests is chunked in prefill, disable flash attention with chunked prefill
+            vanilla_chunked_prefill_mla(
+                output=attn_output_torch,
+                query=query,
+                kv_cache=kv_c_and_k_pe_cache,
+                block_tables=attn_metadata.prefill.block_table,
+                query_lens=attn_metadata.prefill.query_lens,
+                context_lens=attn_metadata.prefill.context_lens,
+                kv_b_proj=self.kv_b_proj,
+                max_query_len=attn_metadata.prefill.max_query_len,
+                max_context_len=attn_metadata.prefill.max_seq_lens,
+                nope_dim=self.qk_nope_head_dim,
+                rope_dim=self.qk_rope_head_dim,
+                v_head_dim=self.v_head_dim,
+                scale=self.scale,
+                alibi_slopes=None,
+                causal=True)
 
         attn_output = attn_output.reshape(
             [num_tokens, self.num_heads * self.v_head_dim])
+        if attn_metadata.attn_state in [
+                AscendAttentionState.ChunkedPrefill,
+                AscendAttentionState.SpecDecoding,
+                AscendAttentionState.PrefillCacheHit
+        ] and not self.chunked_prefill_for_mla:
+            attn_output = attn_output_torch
         return attn_output
 
     def exec_kv_decode(
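
The restored vanilla_chunked_prefill_mla (in vllm_ascend/ops/attention.py) is a pure-PyTorch fallback: it gathers the latent KV cache for each request through its block table, up-projects the latent part with kv_b_proj into per-head keys and values, and runs causal attention over context plus new tokens. The real function additionally handles batching, LSE bookkeeping, and vLLM's layer types; below is only a heavily simplified single-request sketch of that idea, with every name and shape illustrative and kv_b_proj assumed to be a plain nn.Linear:

import torch

def toy_chunked_prefill_mla(query, kv_cache, block_table, context_len,
                            kv_b_proj, nope_dim, rope_dim, v_head_dim, scale):
    # query:    [q_len, num_heads, nope_dim + rope_dim]
    # kv_cache: [num_blocks, block_size, kv_lora_rank + rope_dim]  (latent KV)
    q_len, num_heads, _ = query.shape
    total_len = context_len + q_len

    # Gather this request's cached latent KV through its block table.
    blocks = kv_cache[block_table]                      # [n_blk, blk_size, D]
    kv = blocks.reshape(-1, blocks.shape[-1])[:total_len]
    kv_c, k_pe = kv.split([kv.shape[-1] - rope_dim, rope_dim], dim=-1)

    # Up-project the latent part into per-head no-PE keys and values.
    kv_up = kv_b_proj(kv_c).view(total_len, num_heads, nope_dim + v_head_dim)
    k_nope, v = kv_up.split([nope_dim, v_head_dim], dim=-1)
    k = torch.cat(
        [k_nope, k_pe.unsqueeze(1).expand(-1, num_heads, -1)], dim=-1)

    # Causal mask: a new token sees all context plus earlier new tokens.
    q_pos = torch.arange(context_len, total_len).unsqueeze(1)
    k_pos = torch.arange(total_len).unsqueeze(0)
    masked = k_pos > q_pos                              # True = disallowed

    scores = torch.einsum("qhd,khd->hqk", query, k) * scale
    scores = scores.masked_fill(masked, float("-inf"))
    out = torch.einsum("hqk,khd->qhd", scores.softmax(dim=-1), v)
    return out.reshape(q_len, num_heads * v_head_dim)

# Tiny smoke test with made-up sizes (2 heads, 4 cache blocks of 2 slots).
H, nope, rope, v_dim, latent = 2, 8, 4, 8, 16
kv_b_proj = torch.nn.Linear(latent, H * (nope + v_dim), bias=False)
out = toy_chunked_prefill_mla(
    query=torch.randn(3, H, nope + rope),
    kv_cache=torch.randn(4, 2, latent + rope),
    block_table=torch.tensor([0, 2, 1]),
    context_len=3,
    kv_b_proj=kv_b_proj,
    nope_dim=nope, rope_dim=rope, v_head_dim=v_dim, scale=0.125)
assert out.shape == (3, H * v_dim)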
