Commit 782e774

splitfuse access optimize
1 parent 5446db6 commit 782e774

3 files changed: 5 additions & 18 deletions

vllm_ascend/attention/attention_mask.py

Lines changed: 2 additions & 16 deletions
@@ -44,6 +44,7 @@ def __init__(
 
         self._seq_len_cached = attn_mask.shape[0]
         self.attn_mask_cache = attn_mask
+        self.chunked_prefill_attn_mask = torch.triu(torch.ones(2048, 2048), diagonal=1).to(torch.int8)
 
     @staticmethod
     def get_mask_scale_factor(dtype: torch.dtype = torch.float16):
@@ -66,24 +67,9 @@ def get_attn_mask(self, max_seq_len: int, dtype: torch.dtype,
 
     def get_splitfuse_attn_mask(
         self,
-        seq_lens: torch.Tensor,
-        position: torch.Tensor,
-        dtype: torch.dtype,
         device: torch.device,
     ) -> torch.Tensor:
-        if dtype not in [torch.float16, torch.bfloat16]:
-            raise ValueError(
-                "splitfuse_attn_mask now only supports bf16 and fp16")
-        max_seq_len = max(seq_lens, default=0)
-        self._update_attn_cache(max_seq_len, dtype)
-        # FIXME: Currently the mask value of chunked-prefill situation and Prefill-Only situation
-        # is not the same. Fix this in the future when kernel is ready.
-        mask_scale_factor = AttentionMaskBuilder.get_mask_scale_factor(dtype)
-        attn_mask = torch.index_select(self.attn_mask_cache,
-                                       dim=0,
-                                       index=position)[:, :max_seq_len]
-        attn_mask *= mask_scale_factor
-        return attn_mask.contiguous().to(device, non_blocking=True)
+        return self.chunked_prefill_attn_mask.to(device)
 
     def _update_attn_cache(self, seqlen: int, dtype: torch.dtype):
         if seqlen > self._seq_len_cached:
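
In short, the builder now constructs the 2048 x 2048 upper-triangular int8 mask once at init time, and get_splitfuse_attn_mask simply hands out that cached tensor instead of re-selecting and re-scaling rows of attn_mask_cache on every call. A minimal sketch of the resulting pattern, with the constructor signature and the surrounding class details assumed rather than taken from the diff:

import torch

class AttentionMaskBuilder:
    """Sketch of the cached-mask pattern; not the full vllm-ascend class."""

    def __init__(self, attn_mask: torch.Tensor):
        self._seq_len_cached = attn_mask.shape[0]
        self.attn_mask_cache = attn_mask
        # Built exactly once: the splitfuse (chunked-prefill) path uses a
        # fixed 2048 x 2048 upper-triangular int8 mask, so nothing dtype-
        # or seq-len-dependent is left to recompute per batch.
        self.chunked_prefill_attn_mask = torch.triu(
            torch.ones(2048, 2048), diagonal=1).to(torch.int8)

    def get_splitfuse_attn_mask(self, device: torch.device) -> torch.Tensor:
        # Only the device transfer remains on the hot path.
        return self.chunked_prefill_attn_mask.to(device)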

vllm_ascend/attention/attention_v1.py

Lines changed: 1 addition & 1 deletion
@@ -463,7 +463,7 @@ def _forward_v1_style(
             query=query,
             key=key,
             value=value,
-            atten_mask=attn_metadata.attn_mask.to(device=query.device),
+            atten_mask=attn_metadata.attn_mask,
             block_table=attn_metadata.block_tables,
             input_layout="TND",
             block_size=block_size,
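
Because the mask returned by the builder is already placed on the target device, the per-call .to(device=query.device) copy at the kernel call site becomes redundant. A hedged before/after illustration of just the mask handling (the actual NPU kernel invocation is omitted, and the helper names are illustrative only):

import torch

def mask_old(attn_mask: torch.Tensor, query: torch.Tensor) -> torch.Tensor:
    # before: a host-to-device (or device-to-device) copy on every forward
    return attn_mask.to(device=query.device)

def mask_new(attn_mask: torch.Tensor, query: torch.Tensor) -> torch.Tensor:
    # after: the runner is expected to have placed the mask on query.device
    # already, so it is passed through untouched
    assert attn_mask.device == query.device, "mask placed on device up front"
    return attn_mask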

vllm_ascend/worker/model_runner_v1.py

Lines changed: 2 additions & 1 deletion
@@ -797,7 +797,8 @@ def _make_attention_mask(self, seq_lens, position,
                               attn_state) -> torch.Tensor:
         # Chunk Prefill situation.
         if attn_state == AscendAttentionState.ChunkedPrefill and not self.vllm_config.model_config.use_mla:
-            return torch.triu(torch.ones(2048, 2048), diagonal=1).to(torch.int8)
+            return self.attn_mask_builder.get_splitfuse_attn_mask(
+                self.device)
         # Prefill without cache situation.
         elif attn_state == AscendAttentionState.PrefillNoCache:
             max_seq_len = max(seq_lens, default=0)
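
The runner-side effect is that _make_attention_mask no longer allocates a fresh torch.triu mask for every ChunkedPrefill step; it asks the builder for the tensor cached at init. A rough, purely illustrative micro-benchmark of the allocation being avoided (not from the repo; timings will vary by host):

import time
import torch

cached_mask = torch.triu(torch.ones(2048, 2048), diagonal=1).to(torch.int8)

def mask_per_step() -> torch.Tensor:
    # old behaviour: rebuild the causal mask on every ChunkedPrefill batch
    return torch.triu(torch.ones(2048, 2048), diagonal=1).to(torch.int8)

def mask_cached() -> torch.Tensor:
    # new behaviour: reuse the tensor built once at init time
    return cached_mask

for fn in (mask_per_step, mask_cached):
    start = time.perf_counter()
    for _ in range(100):
        fn()
    print(f"{fn.__name__}: {(time.perf_counter() - start) * 1e3:.2f} ms")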
