
Commit 6906790

Author: p00465316 (committed)

[Bugfix]: replace npu_incre_flash_attention with npu_fused_infer_attention_score

Signed-off-by: p00465316 <panchao13@huawei.com>

1 parent 756b8a1 · commit 6906790

File tree

3 files changed: +151 −54 lines


tests/e2e/multicard/test_torchair_graph_mode.py

Lines changed: 41 additions & 45 deletions
@@ -21,8 +21,10 @@
 """
 import os
 from typing import Dict
+from unittest.mock import patch

 from tests.e2e.conftest import VllmRunner
+from vllm_ascend.ascend_forward_context import _get_fused_moe_state

 os.environ["PYTORCH_NPU_ALLOC_CONF"] = "max_split_size_mb:256"

@@ -167,53 +169,47 @@ def _qwen_torchair_test_fixture(
     tp,
     enable_expert_parallel,
 ):
-    # The current access control does not support 16 cards,
-    # so the MC2 operator in Qwen's graph mode cannot run.
-    # Once 16-card support is available,
-    # this e2e can be switched to graph mode.
-    example_prompts = [
-        "Hello, my name is",
-        "The president of the United States is",
-        "The capital of France is",
-        "The future of AI is",
-    ]
-
-    additional_config = {
-        "torchair_graph_config": {
-            "enabled": False,
-        },
-        "ascend_scheduler_config": {
-            "enabled": True,
-        },
-        "refresh": True,
-    }

-    with VllmRunner(
-            model,
-            dtype="half",
-            tensor_parallel_size=tp,
-            distributed_executor_backend="mp",
-            enforce_eager=True,
-            additional_config=additional_config,
-            enable_expert_parallel=enable_expert_parallel,
-    ) as vllm_model:
-        # use greedy sampler to make sure the generated results are fix
-        vllm_output = vllm_model.generate_greedy(example_prompts, 5)
-
-    # NOTE: vllm-ascend/pangu-pro-moe-pruing is only part of PanguProMoE
-    # with 2 hidden layers, thus the golden results seems inaccurate.
-    # This will only change if accuracy changes with the official weights
-    # of PanguProMoE.
-    golden_results = [
-        'Hello, my name is Remempondeprecatedmiot忱',
-        'The president of the United States is Remem下的一个 rever ceremoni Segnali',
-        'The capital of France is Rememvoud administrativ Remem投',
-        'The future of AI isotope Segnali Zoeken精细化 supus',
-    ]
+    def stubbed_get_state(ep_size, with_prefill, is_deepseek_v3_r1):
+        return _get_fused_moe_state(16, with_prefill, is_deepseek_v3_r1)
+
+    with patch("vllm_ascend.ascend_forward_context._get_fused_moe_state",
+               stubbed_get_state):
+        # The current access control does not support 16 cards,
+        # so the MC2 operator in Qwen's graph mode cannot run.
+        # Once 16-card support is available,
+        # this e2e can be switched to graph mode.
+        example_prompts = [
+            "Hello, my name is",
+            "The president of the United States is",
+            "The capital of France is",
+            "The future of AI is",
+        ]
+
+        additional_config = {
+            "torchair_graph_config": {
+                "enabled": True,
+            },
+            "ascend_scheduler_config": {
+                "enabled": True,
+            },
+            "refresh": True,
+        }

-    assert len(golden_results) == len(vllm_output)
-    for i in range(len(vllm_output)):
-        print(f"Generated text: {vllm_output[i][1]!r}")
+        with VllmRunner(
+                model,
+                dtype="half",
+                tensor_parallel_size=tp,
+                distributed_executor_backend="mp",
+                enforce_eager=False,
+                additional_config=additional_config,
+                enable_expert_parallel=enable_expert_parallel,
+        ) as vllm_model:
+            # use greedy sampler to make sure the generated results are fix
+            vllm_output = vllm_model.generate_greedy(example_prompts, 5)
+
+        for i in range(len(vllm_output)):
+            print(f"Generated text: {vllm_output[i][1]!r}")


 def test_e2e_qwen2_with_torchair():
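
For context on the fixture change above: `_get_fused_moe_state` is monkey-patched so the fused-MoE dispatch is evaluated as if 16 cards were available, which lets the Qwen torchair fixture run in graph mode (torchair_graph_config enabled, enforce_eager=False) on a smaller test machine. Below is a minimal sketch of the same patching pattern, assuming vllm_ascend is importable; the body of the with block is left as a placeholder.

from unittest.mock import patch

from vllm_ascend.ascend_forward_context import _get_fused_moe_state


def stubbed_get_state(ep_size, with_prefill, is_deepseek_v3_r1):
    # Ignore the real ep_size and evaluate the fused-MoE state as if 16 cards
    # were present, matching what the e2e fixture above does.
    return _get_fused_moe_state(16, with_prefill, is_deepseek_v3_r1)


with patch("vllm_ascend.ascend_forward_context._get_fused_moe_state",
           stubbed_get_state):
    # Anything executed here that resolves the patched symbol through the
    # ascend_forward_context module now sees the 16-card MoE state.
    pass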
New file · Lines changed: 94 additions & 0 deletions
@@ -0,0 +1,94 @@
+from unittest.mock import MagicMock, patch
+
+import torch
+from vllm.attention.backends.abstract import AttentionType
+from vllm.distributed.parallel_state import GroupCoordinator
+
+from tests.ut.base import TestBase
+from vllm_ascend.attention.attention_v1 import AscendAttentionState
+from vllm_ascend.torchair.torchair_attention import \
+    AscendAttentionTorchairBackendImpl
+
+
+class TestAscendAttentionTorchairBackendImpl(TestBase):
+
+    @patch("torch.zeros")
+    @patch('vllm.distributed.parallel_state._TP',
+           new_callable=lambda: MagicMock(spec=GroupCoordinator))  # TODO
+    @patch("vllm.distributed.get_tensor_model_parallel_world_size",
+           return_value=2)  # TODO
+    @patch("vllm.config.get_current_vllm_config")  # TODO
+    @patch("vllm_ascend.torchair.torchair_mla.get_ascend_config")  # TODO
+    def setUp(self, ascend_config, vllm_config, mock_get_tp_size, mock_tp,
+              mock_zeros):
+        mock_tp.world_size = 2  # TODO
+        ascend_config.torchair_graph_config.enabled = True  # TODO
+        ascend_config.torchair_graph_config.enable_kv_nz = False  # TODO
+        speculative_config = MagicMock()
+        speculative_config.num_speculative_tokens = 4
+        vllm_config.speculative_config = speculative_config
+
+        num_heads = 32
+        head_size = 128  # TODO
+        scale = 0.1  # TODO
+        num_kv_heads = 4
+        kv_cache_dtype = "auto"
+        attn_type = AttentionType.DECODER
+        mock_zeros.return_value = torch.ones((),
+                                             device='cpu',
+                                             dtype=torch.int32)
+
+        self.impl = AscendAttentionTorchairBackendImpl(
+            num_heads=num_heads,
+            head_size=head_size,
+            scale=scale,
+            num_kv_heads=num_kv_heads,
+            alibi_slopes=None,
+            sliding_window=None,
+            kv_cache_dtype=kv_cache_dtype,
+            blocksparse_params=None,
+            logits_soft_cap=None,
+            attn_type=attn_type,
+            kv_sharing_target_layer_name=None)
+
+    @patch("torch_npu.npu_scatter_nd_update_")
+    @patch("torch_npu.npu_fused_infer_attention_score")
+    def test_forward_with_decode_only(self, mock_fused, _):
+        layer = MagicMock()
+        layer._k_scale_float = 1.0
+        layer._v_scale_float = 1.0
+
+        seq_len = 1
+        num_tokens = 100
+        num_blocks = 256
+        block_size = 4
+
+        query = torch.randn(num_tokens, seq_len,
+                            self.impl.num_heads * self.impl.head_size)
+        key = torch.randn(num_tokens, seq_len,
+                          self.impl.num_kv_heads * self.impl.head_size)
+        value = torch.randn(num_tokens, seq_len,
+                            self.impl.num_kv_heads * self.impl.head_size)
+        kv_cache = (torch.randn(num_blocks, block_size,
+                                self.impl.num_heads * self.impl.head_size),
+                    torch.randn(num_blocks, block_size,
+                                self.impl.num_heads * self.impl.head_size))
+        output = torch.randn(num_tokens, self.impl.num_heads,
+                             self.impl.head_size)
+
+        decode = MagicMock()  # TODO
+        decode.seq_lens_list = [2] * num_tokens
+        decode.block_table = torch.ones(num_tokens, 8, dtype=torch.int32)
+        decode.attn_mask = None
+
+        metadata = MagicMock()
+        metadata.attn_state = AscendAttentionState.DecodeOnly
+        metadata.decode = decode
+
+        mock_fused.return_value = (torch.ones(num_tokens, self.impl.num_heads,
+                                              self.impl.head_size),
+                                   torch.ones(1))
+
+        result = self.impl.forward(layer, query, key, value, kv_cache,
+                                   metadata, output, False)
+        self.assertEqual(result.shape[0], num_tokens)
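
A detail worth noting in the test above: npu_fused_infer_attention_score returns a tuple rather than a single tensor, so mock_fused.return_value must be a 2-tuple and the backend unpacks it as `output, _ = ...`. Below is a minimal, runnable sketch of that contract with a MagicMock standing in for the torch_npu op; the shapes are illustrative only.

import torch
from unittest.mock import MagicMock

# Stand-in for torch_npu.npu_fused_infer_attention_score: the fused op returns
# a tuple, so the mock must return two values for the call site to unpack.
mock_fused = MagicMock()
num_tokens, num_heads, head_size = 100, 32, 128
mock_fused.return_value = (torch.ones(num_tokens, num_heads, head_size),
                           torch.ones(1))

# The patched call site in torchair_attention.py unpacks the result the same way.
output, _ = mock_fused(query=None, key=None, value=None)
assert output.shape == (num_tokens, num_heads, head_size)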

vllm_ascend/torchair/torchair_attention.py

Lines changed: 16 additions & 9 deletions
@@ -436,17 +436,24 @@ def forward(
             block_size = key_cache.shape[1]
             query = query.view(num_tokens, 1,
                                self.num_heads * self.head_size).contiguous()
-            output = torch_npu.npu_incre_flash_attention(
-                query,
-                key_cache,
-                value_cache,
-                num_key_value_heads=self.num_kv_heads,
+            output, _ = torch_npu.npu_fused_infer_attention_score(
+                query=query,
+                key=key_cache,
+                value=value_cache,
+                query_rope=None,
+                key_rope=None,
                 num_heads=self.num_heads,
-                actual_seq_lengths=seq_lens,
-                scale_value=self.scale,
-                block_table=block_table,
+                num_key_value_heads=self.num_kv_heads,
                 input_layout='BSH',
-                block_size=block_size)
+                atten_mask=decode_meta.attn_mask,
+                sparse_mode=0,
+                scale=self.scale,
+                antiquant_mode=0,
+                antiquant_scale=None,
+                block_table=block_table,
+                block_size=block_size,
+                actual_seq_lengths_kv=seq_lens,
+            )
         else:
             raise NotImplementedError(
                 "Torchair graph mode with non-MLA attention backend is still experimental."
