
Commit e42217a

roll back forward prefill
Signed-off-by: SunnyLee219 <3294305115@qq.com>
Parent: b91f45a · Commit: e42217a

File tree: 2 files changed (+72, -21 lines)


vllm_ascend/ascend_config.py

Lines changed: 2 additions & 0 deletions
@@ -50,6 +50,8 @@ def __init__(self, vllm_config):
         ) and not self.torchair_graph_config.enabled and vllm_config.parallel_config.enable_expert_parallel
         self.enable_mla_prefetch = additional_config.get(
             "enable_mla_prefetch", True)
+        self.chunked_prefill_for_mla = additional_config.get(
+            "chunked_prefill_for_mla", False)
 
 
 class TorchairGraphConfig:
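
For context: the new option is read from additional_config with a default of False, so the vanilla chunked-prefill fallback stays active unless a user opts in. Below is a minimal usage sketch, assuming the offline LLM entrypoint and the additional_config argument that vllm-ascend uses for its other Ascend-specific options; the model name is only a placeholder.

from vllm import LLM

# Hypothetical example: opt in to the fused chunked-prefill path for MLA.
# With the default (chunked_prefill_for_mla=False), _forward_prefill routes
# chunked / cache-hit / spec-decoding prefills through vanilla_chunked_prefill_mla.
llm = LLM(
    model="deepseek-ai/DeepSeek-V2-Lite",  # placeholder model name
    additional_config={
        "chunked_prefill_for_mla": True,
        # other Ascend options such as "enable_mla_prefetch" go here too
    },
)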

vllm_ascend/attention/mla_v1.py

Lines changed: 70 additions & 21 deletions
@@ -20,6 +20,7 @@
 from vllm_ascend.multistream.base import MSAttentionMetadataSplitConfig
 from vllm_ascend.multistream.context import get_multistream_comm_context
 from vllm_ascend.multistream.ms_split import model_input_split_v1_mla_attn
+from vllm_ascend.ops.attention import vanilla_chunked_prefill_mla
 from vllm_ascend.utils import npu_prefetch
 from vllm_ascend.worker.npu_input_batch import InputBatch
 
@@ -654,43 +655,67 @@ def _forward_prefill(
     ) -> torch.Tensor:
         assert attn_metadata.prefill is not None
         assert len(kv_c_and_k_pe_cache) > 1
+        query = torch.cat([q_nope, q_pe], dim=-1)
         num_tokens = q_nope.size(0)
         attn_output = torch.empty(num_tokens,
                                   self.num_heads,
                                   self.v_head_dim,
-                                  dtype=q_nope.dtype,
-                                  device=q_nope.device)
-        if attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
-            query = torch.cat((q_nope, q_pe), dim=-1)
-            key = torch.cat((k_nope, k_pe), dim=-1)
-            torch_npu._npu_flash_attention(
+                                  dtype=query.dtype,
+                                  device=query.device)
+        k_pe = k_pe.expand((*k_nope.shape[:-1], -1))
+        # Here is only 2 possibility of input, ChunkedPrefill or PrefillNoCache
+        ascend_config = get_ascend_config()
+
+        if attn_metadata.attn_state in [
+                AscendAttentionState.ChunkedPrefill,
+                AscendAttentionState.SpecDecoding,
+                AscendAttentionState.PrefillCacheHit
+        ] and not ascend_config.chunked_prefill_for_mla:
+
+            attn_output_torch = torch.empty(num_tokens,
+                                            self.num_heads * self.v_head_dim,
+                                            dtype=query.dtype,
+                                            device=query.device)
+            # current requests is chunked in prefill, disable flash attention with chunked prefill
+            vanilla_chunked_prefill_mla(
+                output=attn_output_torch,
                 query=query,
-                key=key,
-                value=value,
-                mask=attn_metadata.attn_mask,
-                seq_len=attn_metadata.prefill.context_lens,
-                scale_value=self.scale,
-                num_heads=self.num_heads,
-                num_kv_heads=self.num_heads,
-                out=attn_output)
-        else:
+                kv_cache=kv_c_and_k_pe_cache,
+                block_tables=attn_metadata.prefill.block_table,
+                query_lens=attn_metadata.prefill.query_lens,
+                context_lens=attn_metadata.prefill.context_lens,
+                kv_b_proj=self.kv_b_proj,
+                max_query_len=attn_metadata.prefill.max_query_len,
+                max_context_len=attn_metadata.prefill.max_seq_lens,
+                nope_dim=self.qk_nope_head_dim,
+                rope_dim=self.qk_rope_head_dim,
+                v_head_dim=self.v_head_dim,
+                scale=self.scale,
+                alibi_slopes=None,
+                causal=True)
+        elif attn_metadata.attn_state in [
+                AscendAttentionState.ChunkedPrefill,
+                AscendAttentionState.SpecDecoding,
+                AscendAttentionState.PrefillCacheHit
+        ]:
+            query = torch.cat([q_nope, q_pe], dim=-1)
             attn_lse = torch.empty(self.num_heads,
                                    num_tokens,
                                    dtype=torch.float32,
                                    device=q_nope.device)
-            self.prefill_mask = torch.triu(
-                torch.ones(512, 512, device=q_nope.device, dtype=q_nope.dtype),
+            mask = torch.triu(
+                torch.ones(512, 512, device=query.device, dtype=query.dtype),
                 1)  # 512: mask only support 512
             if attn_metadata.num_prefills > 1:
-                self.prefill_mask = self.prefill_mask.unsqueeze(0).repeat(
-                    attn_metadata.num_prefills, 1, 1)
+                mask = mask.unsqueeze(0).repeat(attn_metadata.num_prefills, 1,
+                                                1)
             torch_npu.atb.npu_ring_mla(
                 q_nope=q_nope,
                 q_rope=q_pe,
                 k_nope=k_nope,
                 k_rope=k_pe,
                 value=value,
-                mask=self.prefill_mask,
+                mask=mask,
                 seqlen=torch.tensor(attn_metadata.prefill.query_lens,
                                     dtype=torch.int32),
                 head_num=self.num_heads,
@@ -705,10 +730,34 @@ def _forward_prefill(
                 output=attn_output,
                 softmax_lse=attn_lse)
             attn_output, attn_lse = self._compute_prefill_context( \
-                q_nope, q_pe, kv_c_and_k_pe_cache, self.qk_rope_head_dim, attn_metadata, attn_output, attn_lse)
+                query, kv_c_and_k_pe_cache, self.qk_rope_head_dim, attn_metadata, attn_output, attn_lse)
 
+        elif attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
+            key = torch.cat((k_nope, k_pe), dim=-1)
+            torch_npu._npu_flash_attention(
+                query=query,
+                key=key,
+                value=value,
+                mask=attn_metadata.attn_mask,
+                seq_len=attn_metadata.prefill.context_lens,
+                scale_value=self.scale,
+                num_heads=self.num_heads,
+                num_kv_heads=self.num_heads,
+                out=attn_output)
+            attn_output = attn_output.view(-1, self.num_heads, self.v_head_dim)
+        else:
+            raise RuntimeError(
+                "Unexpected path reached, AscendMLAImpl should only have PrefillNoCache, PrefillCacheHit, ChunkedPrefill and SpecDecoding scenario in forward prefill, please file a bug to vllm-ascend !"
+            )
         attn_output = attn_output.reshape(
             [num_tokens, self.num_heads * self.v_head_dim])
+        if attn_metadata.attn_state in [
+                AscendAttentionState.ChunkedPrefill,
+                AscendAttentionState.SpecDecoding,
+                AscendAttentionState.PrefillCacheHit
+        ] and not ascend_config.chunked_prefill_for_mla:
+            attn_output = attn_output_torch
+
         return attn_output
 
     def exec_kv_decode(
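
To make the restored control flow in _forward_prefill easier to follow, here is a simplified, self-contained sketch of the dispatch above. It is illustrative only: AttnState, use_fused_chunked_prefill, and the returned strings stand in for AscendAttentionState, ascend_config.chunked_prefill_for_mla, and the real NPU kernel calls.

from enum import Enum, auto


class AttnState(Enum):
    # Stand-in for AscendAttentionState; only the states used in the diff.
    PrefillNoCache = auto()
    PrefillCacheHit = auto()
    ChunkedPrefill = auto()
    SpecDecoding = auto()


_CHUNKED_STATES = {
    AttnState.ChunkedPrefill,
    AttnState.SpecDecoding,
    AttnState.PrefillCacheHit,
}


def forward_prefill_dispatch(attn_state: AttnState,
                             use_fused_chunked_prefill: bool) -> str:
    """Return the kernel path this commit selects for a prefill request."""
    if attn_state in _CHUNKED_STATES and not use_fused_chunked_prefill:
        # Default path: pure-PyTorch vanilla_chunked_prefill_mla fallback.
        return "vanilla_chunked_prefill_mla"
    elif attn_state in _CHUNKED_STATES:
        # Opt-in path: fused ring MLA kernel plus _compute_prefill_context.
        return "torch_npu.atb.npu_ring_mla"
    elif attn_state is AttnState.PrefillNoCache:
        # Fresh prefill with no cached context: plain flash attention.
        return "torch_npu._npu_flash_attention"
    raise RuntimeError("unexpected attention state in forward prefill")


# With the default chunked_prefill_for_mla=False, a chunked prefill request
# takes the vanilla fallback path.
assert forward_prefill_dispatch(AttnState.ChunkedPrefill, False) == \
    "vanilla_chunked_prefill_mla"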
