Commit ed91379

fix attn_mask bug for ring mla
Signed-off-by: whx-sjtu <2952154980@qq.com>
1 parent 0829b48 commit ed91379

File tree: 1 file changed, +7 -11 lines changed

vllm_ascend/attention/mla_v1.py

Lines changed: 7 additions & 11 deletions
@@ -483,7 +483,13 @@ def __init__(
         self.enable_kv_nz = ascend_config.torchair_graph_config.enable_kv_nz
         self.chunked_prefill_for_mla = ascend_config.chunked_prefill_for_mla
 
-        self.prefill_mask = None
+        vllm_config = get_current_vllm_config()
+        self.prefill_mask = torch.triu(
+            torch.ones(512,
+                       512,
+                       device="npu",
+                       dtype=vllm_config.model_config.dtype),
+            1)  # 512: mask only support 512
 
         # Adapt torch air graph mode with spec decoding.
         speculative_config = get_current_vllm_config().speculative_config
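The added lines build the causal mask once at construction time, on the NPU and in the model's dtype, instead of lazily per request. Below is a minimal standalone sketch of the same mask (assumptions: a CPU tensor and an explicit float16 dtype here, whereas the committed code uses device="npu" and vllm_config.model_config.dtype); the second diff hunk follows it.

    import torch

    # Minimal sketch of the precomputed 512x512 mask (assumption: CPU + float16;
    # the real code uses device="npu" and the model dtype from the vLLM config).
    # diagonal=1 keeps the strict upper triangle, so each position is masked
    # from attending to later positions.
    prefill_mask = torch.triu(torch.ones(512, 512, dtype=torch.float16), 1)
    print(prefill_mask.shape)   # torch.Size([512, 512])
    print(prefill_mask[0, :3])  # tensor([0., 1., 1.], dtype=torch.float16)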
@@ -679,16 +685,6 @@ def _forward_prefill(
             num_tokens,
             dtype=torch.float32,
             device=q_nope.device)
-        if self.prefill_mask is None:
-            self.prefill_mask = torch.triu(
-                torch.ones(512,
-                           512,
-                           device=q_nope.device,
-                           dtype=q_nope.dtype),
-                1)  # 512: mask only support 512
-        if attn_metadata.num_prefills > 1:
-            self.prefill_mask = self.prefill_mask.unsqueeze(0).repeat(
-                attn_metadata.num_prefills, 1, 1)
         torch_npu.atb.npu_ring_mla(
             q_nope=q_nope,
             q_rope=q_pe,

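The deleted block in _forward_prefill is where the bug appears to have lived: the mask was cached on self the first time it was built and then, whenever attn_metadata.num_prefills > 1, overwritten in place with a batched copy. On the next prefill step the cached tensor is already 3-D, so repeating it again fails (or yields the wrong shape if the prefill count changes). A small standalone sketch of that failure mode, assuming two consecutive steps with num_prefills == 2:

    import torch

    # Sketch of the removed caching logic over two consecutive prefill steps
    # (assumption: num_prefills == 2 both times).
    mask = torch.triu(torch.ones(512, 512), 1)    # built once, cached on self
    mask = mask.unsqueeze(0).repeat(2, 1, 1)      # step 1: cached as (2, 512, 512)

    try:
        mask = mask.unsqueeze(0).repeat(2, 1, 1)  # step 2: input is now 4-D
    except RuntimeError as exc:
        # repeat() needs at least as many repeat dims as the tensor has
        # dimensions, so the second step raises instead of producing a
        # (2, 512, 512) mask.
        print(exc)

The fix avoids the per-call mutation entirely: the mask is built once in __init__ and left untouched afterwards.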