Commit 0a81166

optimize again
1 parent 782e774 commit 0a81166
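
Summary of the change: the 2048×2048 chunked-prefill attention mask is now moved to the target device once, when AttentionMaskBuilder is constructed, so get_splitfuse_attn_mask() returns the cached device tensor instead of performing a host-to-device copy on every call.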

2 files changed: 5 additions & 6 deletions


vllm_ascend/attention/attention_mask.py

Lines changed: 3 additions & 3 deletions
@@ -39,12 +39,13 @@ def __init__(
         self,
         max_seq_len: int,
         dtype: torch.dtype,
+        device: torch.device,
     ):
         attn_mask = _generate_attn_mask(max_seq_len, dtype)

         self._seq_len_cached = attn_mask.shape[0]
         self.attn_mask_cache = attn_mask
-        self.chunked_prefill_attn_mask = torch.triu(torch.ones(2048, 2048), diagonal=1).to(torch.int8)
+        self.chunked_prefill_attn_mask = torch.triu(torch.ones(2048, 2048), diagonal=1).to(torch.int8).to(device)

     @staticmethod
     def get_mask_scale_factor(dtype: torch.dtype = torch.float16):
@@ -67,9 +68,8 @@ def get_attn_mask(self, max_seq_len: int, dtype: torch.dtype,

     def get_splitfuse_attn_mask(
         self,
-        device: torch.device,
     ) -> torch.Tensor:
-        return self.chunked_prefill_attn_mask.to(device)
+        return self.chunked_prefill_attn_mask

     def _update_attn_cache(self, seqlen: int, dtype: torch.dtype):
         if seqlen > self._seq_len_cached:
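
The effect of this hunk: the chunked-prefill mask is materialized on the device once in __init__, and get_splitfuse_attn_mask() simply returns the cached tensor, removing a host-to-device copy from the hot path. Below is a minimal sketch of the before/after pattern; the class names are hypothetical and it uses a generic torch.device for illustration rather than the Ascend NPU device the real code targets.

    import torch

    # Illustrative device only; the real code runs on an Ascend NPU device.
    device = torch.device("cpu")

    class MaskCacheBefore:
        """Old pattern: mask lives on the host and is copied to the device per call."""

        def __init__(self) -> None:
            self.mask = torch.triu(torch.ones(2048, 2048), diagonal=1).to(torch.int8)

        def get(self, device: torch.device) -> torch.Tensor:
            return self.mask.to(device)  # host-to-device copy on every call

    class MaskCacheAfter:
        """New pattern: mask is moved to the device once, at construction."""

        def __init__(self, device: torch.device) -> None:
            self.mask = torch.triu(torch.ones(2048, 2048),
                                   diagonal=1).to(torch.int8).to(device)

        def get(self) -> torch.Tensor:
            return self.mask  # no copy; tensor is already resident on the device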

vllm_ascend/worker/model_runner_v1.py

Lines changed: 2 additions & 3 deletions
@@ -293,7 +293,7 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):
         )

         self.attn_mask_builder = AttentionMaskBuilder(
-            self.model_config.max_model_len, self.dtype)
+            self.model_config.max_model_len, self.dtype, self.device)

         # Set up speculative decoding.
         self.spec_attn_mask = None
@@ -797,8 +797,7 @@ def _make_attention_mask(self, seq_lens, position,
                              attn_state) -> torch.Tensor:
         # Chunk Prefill situation.
         if attn_state == AscendAttentionState.ChunkedPrefill and not self.vllm_config.model_config.use_mla:
-            return self.attn_mask_builder.get_splitfuse_attn_mask(
-                self.device)
+            return self.attn_mask_builder.get_splitfuse_attn_mask()
         # Prefill without cache situation.
         elif attn_state == AscendAttentionState.PrefillNoCache:
             max_seq_len = max(seq_lens, default=0)
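
This is the matching caller-side change: the runner passes self.device to AttentionMaskBuilder at construction and calls get_splitfuse_attn_mask() with no arguments, so any other call site of get_splitfuse_attn_mask must likewise drop its device argument to match the new zero-argument signature.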
