
Commit 7eca06d

fems14 authored and wangxiaoxin (A) committed
[cherry-pick][0.9.1] vllm-ascend support chunked prefill (vllm-project#1240)
vllm-ascend supports chunked prefill for MLA. Related PR on main: vllm-project#1172

---------

Signed-off-by: fems14 <1804143737@qq.com>
1 parent 7a63e5c commit 7eca06d
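
For reference, chunked prefill is driven by vLLM's standard engine arguments rather than by anything Ascend-specific in this commit. A minimal usage sketch, assuming a vLLM build that includes this patch; the model name and token budget below are illustrative assumptions, not taken from the commit:

```python
# Hedged sketch: enabling chunked prefill through vLLM's engine arguments.
# `enable_chunked_prefill` and `max_num_batched_tokens` are standard vLLM
# EngineArgs fields; the model choice and budget are assumptions.
from vllm import LLM, SamplingParams

llm = LLM(
    model="deepseek-ai/DeepSeek-V2-Lite",  # an MLA-based model (assumption)
    enable_chunked_prefill=True,           # split long prefills into chunks
    max_num_batched_tokens=4096,           # per-step token budget
)
out = llm.generate(["Hello, world"], SamplingParams(max_tokens=8))
print(out[0].outputs[0].text)
```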

File tree

2 files changed: +77 −0 lines changed

File renamed without changes.

vllm_ascend/attention/mla_v1.py

Lines changed: 77 additions & 0 deletions
```diff
@@ -730,6 +730,83 @@ def _compute_prefill_context(
                 softmax_lse=prefix_lse)
         return prefix_output, prefix_lse
 
+    def _compute_prefill_context(
+        self,
+        query: torch.Tensor,
+        kv_c_and_k_pe_cache: torch.Tensor,
+        rope_dim: int,
+        attn_metadata: AscendMLAMetadata,
+        prefix_output: torch.Tensor,
+        prefix_lse: torch.Tensor,
+    ):
+        prefill_metadata = attn_metadata.prefill
+        if prefill_metadata is None or prefill_metadata.chunked_context is None:
+            return prefix_output, prefix_lse
+
+        iters = len(prefill_metadata.chunked_context.seq_tot)
+        q_pe = query[..., self.qk_nope_head_dim:]
+        q_nope = query[..., :self.qk_nope_head_dim]
+
+        seq_len1 = torch.tensor(prefill_metadata.query_lens, dtype=torch.int32)
+        latent_kv_dim = kv_c_and_k_pe_cache.size(3) - rope_dim
+        cache_kv_c = kv_c_and_k_pe_cache[:, :, :, :latent_kv_dim]
+        cache_k_pe = kv_c_and_k_pe_cache[:, :, :, latent_kv_dim:]
+        for i in range(iters):
+            toks = prefill_metadata.chunked_context.seq_tot[i]
+
+            seq_len2 = prefill_metadata.chunked_context.chunk_seq_lens[i]
+            seq_len = torch.stack([seq_len1, seq_len2])
+            kv_c_normed = torch.empty(toks,
+                                      kv_c_and_k_pe_cache.size(2),
+                                      latent_kv_dim,
+                                      dtype=query.dtype,
+                                      device=query.device)
+            k_pe = torch.empty(toks,
+                               kv_c_and_k_pe_cache.size(2),
+                               rope_dim,
+                               dtype=query.dtype,
+                               device=query.device)
+
+            torch_npu.atb.npu_paged_cache_load(
+                cache_kv_c,
+                cache_k_pe,
+                prefill_metadata.block_table,
+                seq_len2.to(query.device),
+                seq_starts=prefill_metadata.chunked_context.starts[i],
+                key=kv_c_normed,
+                value=k_pe,
+            )
+
+            kv_c_normed = kv_c_normed.squeeze()
+            kv_nope = self.kv_b_proj(kv_c_normed)[0].view( \
+                -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
+            k_nope, v = kv_nope\
+                .split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
+            k_pe = k_pe.expand((*k_nope.shape[:-1], -1))
+            mask = torch.triu(
+                torch.ones(512, 512, device=query.device, dtype=query.dtype),
+                1)
+            torch_npu.atb.npu_ring_mla(
+                q_nope=q_nope,
+                q_rope=q_pe,
+                k_nope=k_nope,
+                k_rope=k_pe,
+                value=v,
+                mask=mask,
+                seqlen=seq_len,
+                head_num=self.num_heads,
+                kv_head_num=self.num_heads,
+                pre_out=prefix_output,
+                prev_lse=prefix_lse,
+                qk_scale=self.scale,
+                kernel_type="kernel_type_high_precision",
+                mask_type="no_mask",
+                input_layout="type_bsnd",
+                calc_type="calc_type_default",
+                output=prefix_output,
+                softmax_lse=prefix_lse)
+        return prefix_output, prefix_lse
+
     def _forward_prefill(
         self,
         query: torch.Tensor,
```
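
The loop above first materializes each context chunk's latent KV and rope keys from the block-paged cache via torch_npu.atb.npu_paged_cache_load. A minimal pure-PyTorch sketch of that gather step, under assumed simplified shapes (a 3-D cache with no kv-head axis) and a hypothetical helper name; this is an inference about the kernel's semantics, not the ATB implementation:

```python
import torch

def paged_gather(cache: torch.Tensor,        # [num_blocks, block_size, kv_dim]
                 block_table: torch.Tensor,  # [num_seqs, max_blocks_per_seq]
                 chunk_lens: torch.Tensor,   # [num_seqs] tokens in this chunk
                 seq_starts: torch.Tensor):  # [num_seqs] chunk offset per seq
    """Copy one context chunk per sequence out of a paged KV cache."""
    block_size = cache.size(1)
    pieces = []
    for seq, (n, start) in enumerate(zip(chunk_lens.tolist(),
                                         seq_starts.tolist())):
        pos = torch.arange(start, start + n)            # absolute positions
        blocks = block_table[seq, pos // block_size]    # block id per token
        pieces.append(cache[blocks, pos % block_size])  # [n, kv_dim]
    return torch.cat(pieces, dim=0)  # contiguous [sum(chunk_lens), kv_dim]
```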
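npu_ring_mla then attends over the gathered chunk and folds the result into the running prefix through pre_out/prev_lse, i.e. a log-sum-exp merge of partial softmax-attention outputs. A minimal sketch of that merge rule in plain PyTorch, assuming [tokens, heads, dim] outputs and [tokens, heads] LSEs; the function name is hypothetical:

```python
import torch

def merge_attn_chunks(out_a: torch.Tensor, lse_a: torch.Tensor,
                      out_b: torch.Tensor, lse_b: torch.Tensor):
    """Combine attention over two disjoint key sets given per-chunk LSEs."""
    lse = torch.logaddexp(lse_a, lse_b)          # normalizer over both chunks
    w_a = torch.exp(lse_a - lse).unsqueeze(-1)   # rescale chunk A's softmax
    w_b = torch.exp(lse_b - lse).unsqueeze(-1)   # rescale chunk B's softmax
    return w_a * out_a + w_b * out_b, lse        # merged output and new LSE
```

Iterating this merge chunk by chunk reproduces exact attention over the full prefix, which is what allows the prefill to be split without changing results.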
