diff --git a/vllm_ascend/attention/attention_v1.py b/vllm_ascend/attention/attention_v1.py
index 3417bb87fb..5cdfb9e13e 100644
--- a/vllm_ascend/attention/attention_v1.py
+++ b/vllm_ascend/attention/attention_v1.py
@@ -321,15 +321,17 @@ def forward(
             assert attn_metadata is not None
             assert attn_metadata.attn_mask is not None
             mask = attn_metadata.attn_mask
-            torch_npu._npu_flash_attention(query=query,
-                                           key=key,
-                                           value=value,
-                                           mask=mask,
-                                           seq_len=attn_metadata.seq_lens,
-                                           scale_value=self.scale,
-                                           num_heads=self.num_heads,
-                                           num_kv_heads=self.num_kv_heads,
-                                           out=output)
+            torch_npu.atb._npu_flash_attention_v2(
+                query=query,
+                key=key,
+                value=value,
+                mask=mask,
+                mask_type=3,
+                seq_len=attn_metadata.seq_lens,
+                scale_value=self.scale,
+                num_heads=self.num_heads,
+                num_kv_heads=self.num_kv_heads,
+                out=output)
         elif attn_metadata.attn_state == AscendAttentionState.PrefillCacheHit:
             assert attn_metadata is not None
             assert attn_metadata.attn_mask is not None
diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py
index 25e151e6ef..78f4f86599 100644
--- a/vllm_ascend/worker/model_runner_v1.py
+++ b/vllm_ascend/worker/model_runner_v1.py
@@ -657,13 +657,9 @@ def _make_attention_mask(self, seq_lens, query_lens, position,
         if attn_state == AscendAttentionState.ChunkedPrefill:
            return self.attn_mask_builder.get_splitfuse_attn_mask(
                seq_lens, query_lens, position, self.dtype, self.device)
-        # Prefill without cache situation.
-        elif attn_state == AscendAttentionState.PrefillNoCache:
-            max_seq_len = max(seq_lens, default=0)
-            return self.attn_mask_builder.get_attn_mask(
-                max_seq_len, self.dtype, self.device)
-        # Prefill with cache hit.
-        elif attn_state == AscendAttentionState.PrefillCacheHit:
+        # Prefill situation.
+        elif attn_state == AscendAttentionState.PrefillNoCache or \
+                attn_state == AscendAttentionState.PrefillCacheHit:
             return self.attn_mask_builder.get_attn_mask(
                 128, self.dtype, self.device)
         # Decode-only situation.
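
Note (not part of the patch): the second hunk drops the max_seq_len-sized prefill mask and requests a fixed 128-entry mask for both PrefillNoCache and PrefillCacheHit, which suggests the _npu_flash_attention_v2 call with mask_type=3 consumes a small fixed-size causal mask and handles longer sequences internally. The sketch below only illustrates what such a fixed 128 x 128 causal mask could look like; the function name is hypothetical, the real mask comes from self.attn_mask_builder.get_attn_mask(128, self.dtype, self.device), and the exact semantics of mask_type=3 are an assumption here.

    # Hypothetical illustration, not the attn_mask_builder implementation.
    import torch

    def build_causal_mask_sketch(size: int = 128,
                                 dtype: torch.dtype = torch.float16,
                                 device: torch.device = torch.device("cpu")) -> torch.Tensor:
        # Positions above the diagonal correspond to future tokens and are
        # filled with the most negative representable value so they are
        # ignored by the softmax inside the attention kernel.
        future = torch.triu(
            torch.ones(size, size, dtype=torch.bool, device=device), diagonal=1)
        mask = torch.zeros(size, size, dtype=dtype, device=device)
        mask.masked_fill_(future, torch.finfo(dtype).min)
        return mask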