Commit 5446db6

chunked prefill: integrate the splitfuse operator
1 parent 88ca8a0 commit 5446db6

File tree

2 files changed: +17 -11 lines


vllm_ascend/attention/attention_v1.py

Lines changed: 16 additions & 9 deletions
@@ -455,18 +455,25 @@ def _forward_v1_style(
         attn_metadata.seq_lens = \
             attn_metadata.seq_lens.to(device=query.device)
 
-        torch_npu._npu_paged_attention_splitfuse(
+        num_block, block_size, head_num, head_dim = self.key_cache.shape
+        key = self.key_cache.view(num_block, block_size, -1)
+        value = self.value_cache.view(num_block, block_size, -1)
+
+        output, _ = torch_npu.npu_fused_infer_attention_score(
             query=query,
-            key_cache=self.key_cache,
-            value_cache=self.value_cache,
-            mask=attn_metadata.attn_mask,
+            key=key,
+            value=value,
+            atten_mask=attn_metadata.attn_mask.to(device=query.device),
             block_table=attn_metadata.block_tables,
-            seq_len=attn_metadata.query_lens,
-            context_lens=attn_metadata.seq_lens,
-            num_kv_heads=self.num_kv_heads,
+            input_layout="TND",
+            block_size=block_size,
+            actual_seq_lengths=attn_metadata.query_start_loc[1:],
+            actual_seq_lengths_kv=attn_metadata.seq_lens,
+            num_key_value_heads=self.num_kv_heads,
             num_heads=self.num_heads,
-            scale_value=self.scale,
-            out=output)
+            scale=self.scale,
+            sparse_mode=3,
+        )
         return output
 
     def forward(
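
For orientation, here is a minimal sketch of how the new kernel is driven, using the keyword arguments from this diff. The shapes, request lengths, and dummy metadata are illustrative assumptions, not values from the repo; the argument forms mirror the diff (tensor metadata, "TND" layout), and running it requires an Ascend NPU with torch_npu installed.

# Hedged sketch: drive npu_fused_infer_attention_score over a paged KV
# cache the way the diff does. Shapes and metadata below are made up.
import torch
import torch_npu

num_block, block_size = 64, 128
num_heads, num_kv_heads, head_dim = 8, 1, 128
total_tokens = 16  # sum of the per-request query lengths

# Query in TND layout: (total_tokens, num_heads, head_dim).
query = torch.randn(total_tokens, num_heads, head_dim,
                    dtype=torch.float16).npu()

# Paged KV cache, flattened to (num_block, block_size, kv_heads * head_dim)
# exactly as the diff's .view(num_block, block_size, -1) does.
key_cache = torch.randn(num_block, block_size, num_kv_heads, head_dim,
                        dtype=torch.float16).npu()
value_cache = torch.randn_like(key_cache)
key = key_cache.view(num_block, block_size, -1)
value = value_cache.view(num_block, block_size, -1)

# Two requests with query lengths 4 and 12: actual_seq_lengths takes the
# cumulative end offsets (query_start_loc[1:]); actual_seq_lengths_kv the
# per-request KV (context) lengths.
query_start_loc = torch.tensor([0, 4, 16], dtype=torch.int32).npu()
seq_lens = torch.tensor([100, 200], dtype=torch.int32).npu()
block_tables = torch.zeros(2, num_block, dtype=torch.int32).npu()

# Static causal template, consumed via sparse_mode=3 (see the mask change
# in model_runner_v1.py below).
attn_mask = torch.triu(torch.ones(2048, 2048),
                       diagonal=1).to(torch.int8).npu()

output, _ = torch_npu.npu_fused_infer_attention_score(
    query=query,
    key=key,
    value=value,
    atten_mask=attn_mask,
    block_table=block_tables,
    input_layout="TND",
    block_size=block_size,
    actual_seq_lengths=query_start_loc[1:],
    actual_seq_lengths_kv=seq_lens,
    num_key_value_heads=num_kv_heads,
    num_heads=num_heads,
    scale=1.0 / head_dim ** 0.5,
    sparse_mode=3,
)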

vllm_ascend/worker/model_runner_v1.py

Lines changed: 1 addition & 2 deletions
@@ -797,8 +797,7 @@ def _make_attention_mask(self, seq_lens, position,
                              attn_state) -> torch.Tensor:
         # Chunk Prefill situation.
         if attn_state == AscendAttentionState.ChunkedPrefill and not self.vllm_config.model_config.use_mla:
-            return self.attn_mask_builder.get_splitfuse_attn_mask(
-                seq_lens, position, self.dtype, self.device)
+            return torch.triu(torch.ones(2048, 2048), diagonal=1).to(torch.int8)
         # Prefill without cache situation.
         elif attn_state == AscendAttentionState.PrefillNoCache:
             max_seq_len = max(seq_lens, default=0)
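
A note on the mask change: the dynamically built splitfuse mask is replaced by a single fixed template. My reading, an assumption based on the op's sparse_mode convention rather than anything stated in the commit, is that sparse_mode=3 selects a right-aligned causal pattern for which the kernel only needs this fixed 2048x2048 band and tiles it itself, so no per-batch mask has to be rebuilt. A plain-torch illustration of what the template encodes:

# What the fixed template encodes (no NPU needed for this part):
# int8 1 strictly above the diagonal marks masked (future) positions,
# 0 on/below the diagonal keeps current and past tokens visible.
import torch

mask = torch.triu(torch.ones(2048, 2048), diagonal=1).to(torch.int8)
assert mask[0, 1] == 1  # future token: masked
assert mask[1, 0] == 0  # past token: visible
assert mask[5, 5] == 0  # current token: visible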
