
Commit 1e78ecb

[Perf] Add FIA interface in FA case (#3321)
### What this PR does / why we need it?
Add the new npu_fused_infer_attention_score op to improve performance in the flash attention case.

### Does this PR introduce _any_ user-facing change?

### How was this patch tested?
- vLLM version: v0.11.0rc3
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0

---------

Signed-off-by: ZYang6263 <zy626375@gmail.com>
1 parent 4b3bd4f commit 1e78ecb

2 files changed: +50 -12 lines changed

vllm_ascend/attention/attention_v1.py

Lines changed: 44 additions & 9 deletions
@@ -348,15 +348,50 @@ def _forward_prefill_no_cache(
             mask = torch_npu.npu_format_cast(mask.contiguous(),
                                              ACL_FORMAT_FRACTAL_NZ)
 
-        torch_npu._npu_flash_attention(query=query,
-                                       key=key,
-                                       value=value,
-                                       mask=mask,
-                                       seq_len=attn_metadata.seq_lens,
-                                       scale_value=self.scale,
-                                       num_heads=self.num_heads,
-                                       num_kv_heads=self.num_kv_heads,
-                                       out=output)
+        num_tokens = query.shape[0]
+        if torch.version.cann.startswith("8.3") and self.head_size != 256:
+            query_start_loc = attn_metadata.actual_seq_lengths_q
+            num_tokens = query_start_loc[-1]
+            softmax_lse = torch.empty(num_tokens,
+                                      dtype=query.dtype,
+                                      device=query.device)
+            workspace = torch_npu._npu_fused_infer_attention_score_get_max_workspace(
+                query=query,
+                key=key,
+                value=value,
+                atten_mask=attn_metadata.attn_mask,
+                input_layout="TND",
+                actual_seq_lengths=attn_metadata.actual_seq_lengths_q,
+                actual_seq_lengths_kv=attn_metadata.actual_seq_lengths_q,
+                num_key_value_heads=self.num_kv_heads,
+                num_heads=self.num_heads,
+                scale=self.scale,
+                sparse_mode=3)
+            torch_npu.npu_fused_infer_attention_score.out(
+                query=query,
+                key=key,
+                value=value,
+                atten_mask=attn_metadata.attn_mask,
+                input_layout="TND",
+                actual_seq_lengths=attn_metadata.actual_seq_lengths_q,
+                actual_seq_lengths_kv=attn_metadata.actual_seq_lengths_q,
+                num_key_value_heads=self.num_kv_heads,
+                num_heads=self.num_heads,
+                scale=self.scale,
+                sparse_mode=3,
+                workspace=workspace,
+                out=[output, softmax_lse])
+
+        else:
+            torch_npu._npu_flash_attention(query=query,
+                                           key=key,
+                                           value=value,
+                                           mask=mask,
+                                           seq_len=attn_metadata.seq_lens,
+                                           scale_value=self.scale,
+                                           num_heads=self.num_heads,
+                                           num_kv_heads=self.num_kv_heads,
+                                           out=output)
         assert output is not None
         return output[:num_tokens, :, :]
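
For readers unfamiliar with the fused kernel, the new branch follows a two-step pattern: first query the maximum workspace the op may need, then launch the fused kernel so it writes into a preallocated output buffer together with a softmax log-sum-exp buffer. The helper below is a minimal sketch of that pattern only, not part of the change; it assumes a torch_npu build on CANN 8.3+ that exposes these ops, TND-layout inputs, and a cumulative query-length list `actual_seq_lengths_q` as used in the diff.

```python
import torch
import torch_npu  # assumes an Ascend build on CANN 8.3+ exposing the FIA ops


def fia_prefill_no_cache(query, key, value, attn_mask, actual_seq_lengths_q,
                         num_heads, num_kv_heads, scale, output):
    """Illustrative wrapper around the fused-infer-attention prefill path."""
    # With the TND layout, the last cumulative query length is the token count.
    num_tokens = actual_seq_lengths_q[-1]
    # The fused op also emits the softmax log-sum-exp, so a buffer is needed
    # even though this caller only consumes `output`.
    softmax_lse = torch.empty(num_tokens, dtype=query.dtype, device=query.device)
    # Step 1: ask the kernel for the maximum workspace it may need.
    workspace = torch_npu._npu_fused_infer_attention_score_get_max_workspace(
        query=query, key=key, value=value,
        atten_mask=attn_mask, input_layout="TND",
        actual_seq_lengths=actual_seq_lengths_q,
        actual_seq_lengths_kv=actual_seq_lengths_q,
        num_key_value_heads=num_kv_heads, num_heads=num_heads,
        scale=scale, sparse_mode=3)
    # Step 2: run the fused kernel into the preallocated buffers.
    torch_npu.npu_fused_infer_attention_score.out(
        query=query, key=key, value=value,
        atten_mask=attn_mask, input_layout="TND",
        actual_seq_lengths=actual_seq_lengths_q,
        actual_seq_lengths_kv=actual_seq_lengths_q,
        num_key_value_heads=num_kv_heads, num_heads=num_heads,
        scale=scale, sparse_mode=3,
        workspace=workspace, out=[output, softmax_lse])
    return output[:num_tokens]
```

Note that the gate in the diff additionally skips this path when head_size == 256, falling back to the existing torch_npu._npu_flash_attention kernel in that case.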

vllm_ascend/worker/model_runner_v1.py

Lines changed: 6 additions & 3 deletions
@@ -898,9 +898,12 @@ def _make_attention_mask(self, seq_lens, position,
 
         # Prefill without cache situation.
         elif attn_state == AscendAttentionState.PrefillNoCache:
-            max_seq_len = max(seq_lens.max().item(), 0)
-            return self.attn_mask_builder.get_attn_mask(
-                max_seq_len, self.dtype, self.device)
+            if torch.version.cann.startswith("8.3"):
+                return self.attn_mask_builder.get_splitfuse_attn_mask()
+            else:
+                max_seq_len = max(seq_lens, default=0)
+                return self.attn_mask_builder.get_attn_mask(
+                    max_seq_len, self.dtype, self.device)
         # Prefill with cache hit.
         elif attn_state == AscendAttentionState.PrefillCacheHit:
             return self.attn_mask_builder.get_attn_mask(
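
On the mask side, the same CANN-version gate decides which mask the builder returns for the prefill-without-cache case: on 8.3+ the split-fuse mask is reused (presumably what reaches the fused kernel as attn_metadata.attn_mask), while older runtimes keep building a dense mask sized to the longest sequence. A minimal sketch of that branch, assuming an attn_mask_builder object exposing the two methods shown in the diff and a plain Python list for seq_lens:

```python
import torch


def prefill_no_cache_mask(attn_mask_builder, seq_lens, dtype, device):
    """Pick the split-fuse mask on CANN 8.3+, else a dense mask (sketch only)."""
    # torch.version.cann only exists on Ascend builds of torch_npu;
    # guard with getattr so the sketch also runs on plain torch.
    cann_version = getattr(torch.version, "cann", None)
    if cann_version is not None and cann_version.startswith("8.3"):
        return attn_mask_builder.get_splitfuse_attn_mask()
    # Fallback: dense mask sized to the longest prefill sequence.
    max_seq_len = max(seq_lens, default=0)
    return attn_mask_builder.get_attn_mask(max_seq_len, dtype, device)
```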
