From c95366c1406ef9f0324ba0b93ca105938ceedd27 Mon Sep 17 00:00:00 2001
From: p00465316
Date: Fri, 12 Sep 2025 22:19:34 +0800
Subject: [PATCH] [Bugfix]: replace npu_incre_flash_attention with
 npu_fused_infer_attention_score

Signed-off-by: p00465316
---
 tests/ut/torchair/test_torchair_attention.py | 95 ++++++++++++++++++++
 vllm_ascend/torchair/torchair_attention.py   | 25 ++++--
 2 files changed, 111 insertions(+), 9 deletions(-)
 create mode 100644 tests/ut/torchair/test_torchair_attention.py

diff --git a/tests/ut/torchair/test_torchair_attention.py b/tests/ut/torchair/test_torchair_attention.py
new file mode 100644
index 0000000000..dd262dc898
--- /dev/null
+++ b/tests/ut/torchair/test_torchair_attention.py
@@ -0,0 +1,95 @@
+from unittest.mock import MagicMock, patch
+
+import torch
+from vllm.attention.backends.abstract import AttentionType
+from vllm.distributed.parallel_state import GroupCoordinator
+
+from tests.ut.base import TestBase
+from vllm_ascend.attention.attention_v1 import AscendAttentionState
+from vllm_ascend.torchair.torchair_attention import \
+    AscendAttentionTorchairBackendImpl
+
+
+class TestAscendAttentionTorchairBackendImpl(TestBase):
+
+    @patch("torch.zeros")
+    @patch('vllm.distributed.parallel_state._TP',
+           new_callable=lambda: MagicMock(spec=GroupCoordinator))
+    @patch("vllm.distributed.get_tensor_model_parallel_world_size",
+           return_value=2)
+    @patch("vllm.config.get_current_vllm_config")
+    @patch("vllm_ascend.torchair.torchair_mla.get_ascend_config")
+    def setUp(self, ascend_config, vllm_config, mock_get_tp_size, mock_tp,
+              mock_zeros):
+        mock_tp.world_size = 2
+        ascend_config.torchair_graph_config.enabled = True
+        ascend_config.torchair_graph_config.enable_kv_nz = False
+        speculative_config = MagicMock()
+        speculative_config.num_speculative_tokens = 4
+        vllm_config.speculative_config = speculative_config
+
+        num_heads = 32
+        head_size = 128
+        scale = 0.1
+        num_kv_heads = 4
+        kv_cache_dtype = "auto"
+        attn_type = AttentionType.DECODER
+        mock_zeros.return_value = torch.ones((),
+                                             device='cpu',
+                                             dtype=torch.int32)
+
+        self.impl = AscendAttentionTorchairBackendImpl(
+            num_heads=num_heads,
+            head_size=head_size,
+            scale=scale,
+            num_kv_heads=num_kv_heads,
+            alibi_slopes=None,
+            sliding_window=None,
+            kv_cache_dtype=kv_cache_dtype,
+            blocksparse_params=None,
+            logits_soft_cap=None,
+            attn_type=attn_type,
+            kv_sharing_target_layer_name=None)
+
+    @patch("torch_npu.npu_scatter_nd_update_")
+    @patch("torch_npu.npu_fused_infer_attention_score")
+    def test_forward_with_decode_only(self, mock_fused, _):
+        layer = MagicMock()
+        layer._k_scale_float = 1.0
+        layer._v_scale_float = 1.0
+
+        seq_len = 1
+        num_tokens = 100
+        num_blocks = 256
+        block_size = 4
+
+        query = torch.randn(num_tokens, seq_len,
+                            self.impl.num_heads * self.impl.head_size)
+        key = torch.randn(num_tokens, seq_len,
+                          self.impl.num_kv_heads * self.impl.head_size)
+        value = torch.randn(num_tokens, seq_len,
+                            self.impl.num_kv_heads * self.impl.head_size)
+        kv_cache = (torch.randn(num_blocks, block_size,
+                                self.impl.num_heads * self.impl.head_size),
+                    torch.randn(num_blocks, block_size,
+                                self.impl.num_heads * self.impl.head_size))
+        output = torch.randn(num_tokens, self.impl.num_heads,
+                             self.impl.head_size)
+
+        decode = MagicMock()
+        decode.seq_lens_list = [2] * num_tokens
+        decode.block_table = torch.ones(num_tokens, 8, dtype=torch.int32)
+        decode.attn_mask = None
+
+        metadata = MagicMock()
+        metadata.attn_state = AscendAttentionState.DecodeOnly
+        metadata.slot_mapping = torch.arange(num_tokens, dtype=torch.int32)
+        metadata.decode = decode
+
+        mock_fused.return_value = (torch.ones(num_tokens, self.impl.num_heads,
+                                              self.impl.head_size),
+                                   torch.ones(1))
+
+        result = self.impl.forward(layer, query, key, value, kv_cache,
+                                   metadata, output, False)
+        self.assertEqual(result.shape[0], num_tokens)
diff --git a/vllm_ascend/torchair/torchair_attention.py b/vllm_ascend/torchair/torchair_attention.py
index d2443ad442..8d454636c4 100644
--- a/vllm_ascend/torchair/torchair_attention.py
+++ b/vllm_ascend/torchair/torchair_attention.py
@@ -436,17 +436,24 @@ def forward(
             block_size = key_cache.shape[1]
             query = query.view(num_tokens, 1,
                                self.num_heads * self.head_size).contiguous()
-            output = torch_npu.npu_incre_flash_attention(
-                query,
-                key_cache,
-                value_cache,
-                num_key_value_heads=self.num_kv_heads,
+            output, _ = torch_npu.npu_fused_infer_attention_score(
+                query=query,
+                key=key_cache,
+                value=value_cache,
+                query_rope=None,
+                key_rope=None,
                 num_heads=self.num_heads,
-                actual_seq_lengths=seq_lens,
-                scale_value=self.scale,
-                block_table=block_table,
+                num_key_value_heads=self.num_kv_heads,
                 input_layout='BSH',
-                block_size=block_size)
+                atten_mask=decode_meta.attn_mask,
+                sparse_mode=0,
+                scale=self.scale,
+                antiquant_mode=0,
+                antiquant_scale=None,
+                block_table=block_table,
+                block_size=block_size,
+                actual_seq_lengths_kv=seq_lens,
+            )
         else:
             raise NotImplementedError(
                 "Torchair graph mode with non-MLA attention backend is still experimental."
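
For reviewers, the sketch below summarizes the operator swap made by this patch. It is illustrative only and not part of the change: torch_npu is stood in for by a MagicMock (the real package needs Ascend hardware), and the tensor shapes, block table, sequence lengths, and stubbed return value are toy values chosen for the demo, not taken from the patch. The keyword arguments themselves are copied from the diff, which also shows scale_value becoming scale, actual_seq_lengths becoming actual_seq_lengths_kv, and the new op's result being unpacked as a two-element tuple whose second element is discarded.

# --- illustrative sketch only; not part of the patch ---
from unittest.mock import MagicMock

import torch

# Stand-in for torch_npu so the snippet runs without an NPU.
torch_npu = MagicMock()

# Toy sizes (assumptions for the demo).
num_tokens, num_heads, num_kv_heads, head_size = 4, 32, 4, 128
num_blocks, block_size = 16, 4

query = torch.randn(num_tokens, 1, num_heads * head_size)  # BSH layout, S=1 for decode
key_cache = torch.randn(num_blocks, block_size, num_kv_heads * head_size)
value_cache = torch.randn(num_blocks, block_size, num_kv_heads * head_size)
block_table = torch.zeros(num_tokens, 8, dtype=torch.int32)
seq_lens = [2] * num_tokens
scale = head_size ** -0.5

# Old call removed by the patch:
#   torch_npu.npu_incre_flash_attention(
#       query, key_cache, value_cache,
#       num_key_value_heads=num_kv_heads, num_heads=num_heads,
#       actual_seq_lengths=seq_lens, scale_value=scale,
#       block_table=block_table, input_layout='BSH', block_size=block_size)

# Stub a two-element return so the unpacking below works with the mock.
torch_npu.npu_fused_infer_attention_score.return_value = (
    torch.zeros(num_tokens, 1, num_heads * head_size), None)

# New call added by the patch (keyword arguments copied from the diff;
# atten_mask is None here where the patch passes decode_meta.attn_mask).
output, _ = torch_npu.npu_fused_infer_attention_score(
    query=query,
    key=key_cache,
    value=value_cache,
    query_rope=None,
    key_rope=None,
    num_heads=num_heads,
    num_key_value_heads=num_kv_heads,
    input_layout='BSH',
    atten_mask=None,
    sparse_mode=0,
    scale=scale,
    antiquant_mode=0,
    antiquant_scale=None,
    block_table=block_table,
    block_size=block_size,
    actual_seq_lengths_kv=seq_lens,
)
print(output.shape)  # torch.Size([4, 1, 4096]) with the stubbed return value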