
Commit 82a38b0

fix ut
Signed-off-by: whx-sjtu <2952154980@qq.com>
1 parent d3d5f57 commit 82a38b0

File tree

1 file changed, +9 -7 lines changed


vllm_ascend/attention/mla_v1.py

Lines changed: 9 additions & 7 deletions
@@ -484,13 +484,8 @@ def __init__(
         self.chunked_prefill_for_mla = ascend_config.chunked_prefill_for_mla
 
         vllm_config = get_current_vllm_config()
-        RING_MLA_MASK_SIZE = 512
-        self.prefill_mask = torch.triu(
-            torch.ones(RING_MLA_MASK_SIZE,
-                       RING_MLA_MASK_SIZE,
-                       device="npu",
-                       dtype=vllm_config.model_config.dtype),
-            1)
+        self.ring_mla_mask_size = 512
+        self.prefill_mask = None
 
         # Adapt torch air graph mode with spec decoding.
         speculative_config = vllm_config.speculative_config
@@ -686,6 +681,13 @@ def _forward_prefill(
             num_tokens,
             dtype=torch.float32,
             device=q_nope.device)
+        if self.prefill_mask is None:
+            self.prefill_mask = torch.triu(
+                torch.ones(self.ring_mla_mask_size,
+                           self.ring_mla_mask_size,
+                           device=q_nope.device,
+                           dtype=q_nope.dtype),
+                1)
         torch_npu.atb.npu_ring_mla(
             q_nope=q_nope,
             q_rope=q_pe,
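
Net effect of the change: the 512x512 ring-MLA prefill mask is no longer allocated eagerly in __init__ on a hard-coded device="npu" with vllm_config.model_config.dtype; it is built once, lazily, on the first prefill call and inherits the device and dtype of q_nope. Below is a minimal, self-contained sketch of that lazy-initialization pattern, assuming nothing beyond the diff itself: the MaskHolder class name and the CPU usage are illustrative stand-ins, while the 512 mask size and the torch.triu(..., 1) construction are taken directly from the commit.

    import torch


    class MaskHolder:
        # Illustrative stand-in for the attention impl in mla_v1.py.
        def __init__(self, mask_size: int = 512):
            # As in the commit: record the size only, defer the allocation.
            self.ring_mla_mask_size = mask_size
            self.prefill_mask = None

        def get_prefill_mask(self, q_nope: torch.Tensor) -> torch.Tensor:
            # Build the strictly upper-triangular (causal) mask on first use,
            # matching the device and dtype of the incoming query tensor,
            # then reuse the cached tensor on every later call.
            if self.prefill_mask is None:
                self.prefill_mask = torch.triu(
                    torch.ones(self.ring_mla_mask_size,
                               self.ring_mla_mask_size,
                               device=q_nope.device,
                               dtype=q_nope.dtype),
                    1)
            return self.prefill_mask


    # Usage (CPU tensors purely for illustration):
    holder = MaskHolder()
    q_nope = torch.randn(4, 8, dtype=torch.float32)
    mask = holder.get_prefill_mask(q_nope)
    assert mask.shape == (512, 512)
    assert mask.dtype == q_nope.dtype and mask.device == q_nope.device

Deferring the allocation this way avoids touching the NPU device at construction time and keeps the mask's dtype in sync with the actual query tensor rather than with the model config dtype, which is presumably what the "fix ut" unit-test fix needed; the cost is a one-time allocation on the first prefill step.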
