Skip to content

Commit 3520d35

Browse files
committed
[V0.9.1] Replace FA ops with FA_V2 to optimize perf
Signed-off-by: rjg-lyh <1318825571@qq.com>
1 parent 57664f0 commit 3520d35

File tree

2 files changed

+13
-16
lines changed

2 files changed

+13
-16
lines changed

vllm_ascend/attention/attention_v1.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -321,15 +321,16 @@ def forward(
             assert attn_metadata is not None
             assert attn_metadata.attn_mask is not None
             mask = attn_metadata.attn_mask
-            torch_npu._npu_flash_attention(query=query,
-                                           key=key,
-                                           value=value,
-                                           mask=mask,
-                                           seq_len=attn_metadata.seq_lens,
-                                           scale_value=self.scale,
-                                           num_heads=self.num_heads,
-                                           num_kv_heads=self.num_kv_heads,
-                                           out=output)
+            torch_npu.atb._npu_flash_attention_v2(query=query,
+                                                  key=key,
+                                                  value=value,
+                                                  mask=mask,
+                                                  mask_type=3,
+                                                  seq_len=attn_metadata.seq_lens,
+                                                  scale_value=self.scale,
+                                                  num_heads=self.num_heads,
+                                                  num_kv_heads=self.num_kv_heads,
+                                                  out=output)
         elif attn_metadata.attn_state == AscendAttentionState.PrefillCacheHit:
             assert attn_metadata is not None
             assert attn_metadata.attn_mask is not None

vllm_ascend/worker/model_runner_v1.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -657,13 +657,9 @@ def _make_attention_mask(self, seq_lens, query_lens, position,
         if attn_state == AscendAttentionState.ChunkedPrefill:
             return self.attn_mask_builder.get_splitfuse_attn_mask(
                 seq_lens, query_lens, position, self.dtype, self.device)
-        # Prefill without cache situation.
-        elif attn_state == AscendAttentionState.PrefillNoCache:
-            max_seq_len = max(seq_lens, default=0)
-            return self.attn_mask_builder.get_attn_mask(
-                max_seq_len, self.dtype, self.device)
-        # Prefill with cache hit.
-        elif attn_state == AscendAttentionState.PrefillCacheHit:
+        # Prefill situation.
+        elif attn_state == AscendAttentionState.PrefillNoCache or \
+                attn_state == AscendAttentionState.PrefillCacheHit:
             return self.attn_mask_builder.get_attn_mask(
                 128, self.dtype, self.device)
         # Decode-only situation.

0 commit comments

Comments (0)