
Commit 5c35f61

tt545571022 authored and sunchend committed
Eliminate the redundant handling of index_select in get_splitfuse_attn_mask
When max_position_embeddings is large and max_seq_len is small, index_select copies a large amount of unused data, driving up CPU utilization, and the index_select operator executed on the CPU takes a very long time. This change eliminates the redundant index_select in get_splitfuse_attn_mask.

Signed-off-by: tt545571022 <tjl545571022@hotmail.com>
1 parent 05a700d · commit 5c35f61
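
To make the motivation concrete, here is a minimal standalone sketch (not part of the commit; all sizes are hypothetical) of how much data each indexing path copies when the cached mask is much wider than the sequence actually being scheduled:

import torch

# Hypothetical sizes mirroring the commit message: the cached mask is built
# from max_position_embeddings, but only max_seq_len columns are needed.
max_position_embeddings = 8192
max_seq_len = 256
num_tokens = 128

attn_mask_cache = torch.ones(max_position_embeddings, max_position_embeddings,
                             dtype=torch.float16)
position = torch.randint(0, max_position_embeddings, (num_tokens,))

# Old path: index_select gathers full-width rows on the CPU
# (num_tokens * max_position_embeddings elements) and only then slices.
old = torch.index_select(attn_mask_cache, dim=0,
                         index=position)[:, :max_seq_len]

# New path: advanced indexing combined with the slice copies only the
# columns that are needed (num_tokens * max_seq_len elements).
new = attn_mask_cache[position, :max_seq_len]

assert torch.equal(old, new)
print(old.shape, new.shape)  # both (128, 256); the old path copied 32x more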

File tree

1 file changed (+4, -5)

vllm_ascend/attention/attention_mask.py

Lines changed: 4 additions & 5 deletions
@@ -62,7 +62,7 @@ def get_attn_mask(self, max_seq_len: int, dtype: torch.dtype,
                       device: torch.device):
         self._update_attn_cache(max_seq_len, dtype)
         return self.attn_mask_cache[:max_seq_len, :max_seq_len].contiguous(
-        ).to(device)
+        ).to(device, non_blocking=True)
 
     def get_splitfuse_attn_mask(
             self,
@@ -79,11 +79,10 @@ def get_splitfuse_attn_mask(
         # FIXME: Currently the mask value of chunked-prefill situation and Prefill-Only situation
         # is not the same. Fix this in the future when kernel is ready.
         mask_scale_factor = AttentionMaskBuilder.get_mask_scale_factor(dtype)
-        attn_mask = torch.index_select(self.attn_mask_cache,
-                                       dim=0,
-                                       index=position)[:, :max_seq_len]
+        attn_mask = self.attn_mask_cache[position, :max_seq_len].to(
+            device, non_blocking=True)
         attn_mask *= mask_scale_factor
-        return attn_mask.contiguous().to(device, non_blocking=True)
+        return attn_mask
 
     def _update_attn_cache(self, seqlen: int, dtype: torch.dtype):
         if seqlen > self._seq_len_cached:
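
As a sanity check, the following sketch (again not part of the commit; names mirror the diff, sizes are made up, and the NPU device is replaced by the CPU so it runs anywhere) confirms that the old and new return paths of get_splitfuse_attn_mask produce the same mask:

import torch

cache = torch.randn(1024, 1024)            # stand-in for self.attn_mask_cache
position = torch.tensor([3, 7, 512, 900])  # positions of the scheduled tokens
max_seq_len = 64
mask_scale_factor = -10000.0               # stand-in for the dtype-based factor
device = torch.device("cpu")               # "npu" in the real code

# Before: gather full rows, slice, scale on the host, then copy to the device.
before = torch.index_select(cache, dim=0, index=position)[:, :max_seq_len]
before = (before * mask_scale_factor).contiguous().to(device,
                                                      non_blocking=True)

# After: gather only the needed columns, copy first, scale on the device.
# Note: non_blocking=True only overlaps the copy with compute when the source
# lives in pinned host memory; otherwise it degrades to a normal copy.
after = cache[position, :max_seq_len].to(device, non_blocking=True)
after *= mask_scale_factor

assert torch.equal(before, after)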
