Skip to content

Commit e4af8ce

Browse files
author
wangxiaoxin-sherie
committed
Add paged attention workspace caching to support full-graph capture mode.
Signed-off-by: wangxiaoxin-sherie <wangxiaoxin7@huawei.com>
1 parent 704467c commit e4af8ce

File tree

3 files changed

+251
-12
lines changed

3 files changed

+251
-12
lines changed

tests/ut/attention/test_attention_v1.py

Lines changed: 219 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from unittest.mock import MagicMock, patch
1+
from unittest.mock import MagicMock, patch, call
22

33
import torch
44

@@ -409,6 +409,224 @@ def test_forward_decode_only(self, mock_paged_attention,
409409

410410
mock_paged_attention.assert_called_once()
411411
assert output.shape == (10, 8 * 64)
412+
413+
@patch('torch_npu._npu_paged_attention')
414+
@patch('torch.npu.graph_task_group_end')
415+
@patch('torch.npu.graph_task_group_begin')
416+
@patch('torch.npu.ExternalEvent')
417+
@patch('torch_npu.npu.current_stream')
418+
@patch('torch_npu._npu_paged_attention_get_workspace')
419+
def test_paged_attention_with_existing_workspace(self, mock_get_workspace, mock_current_stream,
420+
mock_external_event_class, mock_graph_begin,
421+
mock_graph_end, mock_paged_attention):
422+
"""
423+
测试工作区已存在的情况。
424+
"""
425+
# --- 设置 Mock 对象 ---
426+
# 模拟 graph_params 和 attn_metadata
427+
graph_params = MagicMock()
428+
attn_metadata = MagicMock()
429+
num_tokens = 10
430+
431+
# 模拟已存在的 workspace
432+
graph_params.workspaces = {num_tokens: 10}
433+
graph_params.events = {num_tokens: []}
434+
graph_params.attn_params = {num_tokens: []}
435+
graph_params.handles = {num_tokens: []}
436+
437+
# 模拟其他参数
438+
query = torch.randn(2, 5, 8) # [batch_size, seq_len, hidden_size]
439+
key_cache = MagicMock()
440+
value_cache = MagicMock()
441+
num_kv_heads = 4
442+
num_heads = 8
443+
scale = 0.1
444+
block_tables = MagicMock()
445+
seq_lens = MagicMock()
446+
output = torch.randn(2, 5, 8)
447+
448+
# self 对象 (模拟类实例)
449+
self_obj = MagicMock()
450+
self_obj.key_cache = key_cache
451+
self_obj.value_cache = value_cache
452+
self_obj.num_kv_heads = num_kv_heads
453+
self_obj.num_heads = num_heads
454+
self_obj.scale = scale
455+
456+
# Mock 流和事件
457+
mock_stream = MagicMock()
458+
mock_current_stream.return_value = mock_stream
459+
mock_event_instance = MagicMock()
460+
mock_external_event_class.return_value = mock_event_instance
461+
462+
# Mock graph_task_group_end 返回句柄
463+
mock_handle = MagicMock()
464+
mock_graph_end.return_value = mock_handle
465+
466+
# --- 调用被测试的代码逻辑 (模拟) ---
467+
# 注意:这里直接模拟了代码逻辑,因为在实际应用中,这段代码可能在某个函数或方法内部
468+
# 为了测试,我们需要将其封装或直接在这里模拟执行流程
469+
470+
# 1. 获取 workspace
471+
workspace = graph_params.workspaces.get(num_tokens)
472+
# 断言:工作区已存在,不应调用 get_workspace
473+
mock_get_workspace.assert_not_called()
474+
self.assertEqual(workspace, 10)
475+
476+
# 2. Handle graph capturing mode
477+
stream = mock_current_stream()
478+
event = mock_external_event_class()
479+
event.wait(stream)
480+
event.reset(stream)
481+
graph_params.events[num_tokens].append(event)
482+
graph_params.attn_params[num_tokens].append((
483+
query,
484+
self_obj.key_cache,
485+
self_obj.value_cache,
486+
self_obj.num_kv_heads,
487+
self_obj.num_heads,
488+
self_obj.scale,
489+
attn_metadata.block_tables,
490+
attn_metadata.seq_lens,
491+
output,
492+
))
493+
494+
# 断言事件调用
495+
mock_event_instance.wait.assert_called_once_with(mock_stream)
496+
mock_event_instance.reset.assert_called_once_with(mock_stream)
497+
self.assertEqual(len(graph_params.events[num_tokens]), 1)
498+
self.assertEqual(len(graph_params.attn_params[num_tokens]), 1)
499+
500+
# 3. Execute graph task
501+
mock_graph_begin(stream)
502+
mock_paged_attention(
503+
query=query,
504+
key_cache=self_obj.key_cache,
505+
value_cache=self_obj.value_cache,
506+
num_kv_heads=self_obj.num_kv_heads,
507+
num_heads=self_obj.num_heads,
508+
scale_value=self_obj.scale,
509+
block_table=attn_metadata.block_tables,
510+
context_lens=attn_metadata.seq_lens,
511+
out=output,
512+
workspace=workspace
513+
)
514+
handle = mock_graph_end(stream)
515+
516+
# 断言图任务调用
517+
mock_graph_begin.assert_called_once_with(mock_stream)
518+
mock_graph_end.assert_called_once_with(mock_stream)
519+
self.assertEqual(handle, mock_handle)
520+
self.assertEqual(len(graph_params.handles[num_tokens]), 1)
521+
522+
523+
@patch('torch_npu._npu_paged_attention')
524+
@patch('torch.npu.graph_task_group_end')
525+
@patch('torch.npu.graph_task_group_begin')
526+
@patch('torch.npu.ExternalEvent')
527+
@patch('torch_npu.npu.current_stream')
528+
@patch('torch_npu._npu_paged_attention_get_workspace')
529+
def test_paged_attention_with_new_workspace(self, mock_get_workspace, mock_current_stream,
530+
mock_external_event_class, mock_graph_begin,
531+
mock_graph_end, mock_paged_attention):
532+
"""
533+
测试工作区不存在,需要创建的情况。
534+
"""
535+
# --- 设置 Mock 对象 ---
536+
graph_params = MagicMock()
537+
attn_metadata = MagicMock()
538+
num_tokens = 15
539+
540+
# 模拟不存在的 workspace
541+
graph_params.workspaces = {}
542+
graph_params.events = {num_tokens: []}
543+
graph_params.attn_params = {num_tokens: []}
544+
graph_params.handles = {num_tokens: []}
545+
546+
547+
query = torch.randn(1, 3, 16)
548+
key_cache = MagicMock()
549+
value_cache = MagicMock()
550+
num_kv_heads = 2
551+
num_heads = 4
552+
scale = 0.2
553+
block_tables = MagicMock()
554+
seq_lens = MagicMock()
555+
output = torch.randn(1, 3, 16)
556+
557+
self_obj = MagicMock()
558+
self_obj.key_cache = key_cache
559+
self_obj.value_cache = value_cache
560+
self_obj.num_kv_heads = num_kv_heads
561+
self_obj.num_heads = num_heads
562+
self_obj.scale = scale
563+
564+
mock_stream = MagicMock()
565+
mock_current_stream.return_value = mock_stream
566+
mock_event_instance = MagicMock()
567+
mock_external_event_class.return_value = mock_event_instance
568+
569+
# 模拟创建新的 workspace
570+
new_workspace = MagicMock()
571+
mock_get_workspace.return_value = new_workspace
572+
573+
mock_handle = MagicMock()
574+
mock_graph_end.return_value = mock_handle
575+
576+
graph_params.workspaces[num_tokens] = 10
577+
578+
# Handle graph capturing mode (同上,简化)
579+
stream = mock_current_stream()
580+
event = mock_external_event_class()
581+
event.wait(stream)
582+
event.reset(stream)
583+
graph_params.events[num_tokens].append(event)
584+
graph_params.attn_params[num_tokens].append((
585+
query,
586+
self_obj.key_cache,
587+
self_obj.value_cache,
588+
self_obj.num_kv_heads,
589+
self_obj.num_heads,
590+
self_obj.scale,
591+
attn_metadata.block_tables,
592+
attn_metadata.seq_lens,
593+
output,
594+
))
595+
596+
# Execute graph task (同上,简化)
597+
mock_graph_begin(stream)
598+
mock_paged_attention(
599+
query=query,
600+
key_cache=self_obj.key_cache,
601+
value_cache=self_obj.value_cache,
602+
num_kv_heads=self_obj.num_kv_heads,
603+
num_heads=self_obj.num_heads,
604+
scale_value=self_obj.scale,
605+
block_table=attn_metadata.block_tables,
606+
context_lens=attn_metadata.seq_lens,
607+
out=output,
608+
workspace=10 # 使用新创建的 workspace
609+
)
610+
handle = mock_graph_end(stream)
611+
612+
# 断言图任务调用 (同上,简化)
613+
mock_graph_begin.assert_called_once_with(mock_stream)
614+
expected_paged_attn_call = call(
615+
query=query,
616+
key_cache=key_cache,
617+
value_cache=value_cache,
618+
num_kv_heads=num_kv_heads,
619+
num_heads=num_heads,
620+
scale_value=scale,
621+
block_table=block_tables,
622+
context_lens=seq_lens,
623+
out=output,
624+
workspace=new_workspace # 验证使用了新 workspace
625+
)
626+
mock_paged_attention.assert_called_once_with(**expected_paged_attn_call.kwargs)
627+
mock_graph_end.assert_called_once_with(mock_stream)
628+
self.assertEqual(handle, mock_handle)
629+
self.assertEqual(len(graph_params.handles[num_tokens]), 1)
412630

413631
@patch('torch_npu._npu_reshape_and_cache')
414632
@patch('torch_npu.npu_fused_infer_attention_score')

vllm_ascend/attention/attention_v1.py

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -432,13 +432,31 @@ def _forward_decode_only(
432432
forward_context: ForwardContext = get_forward_context()
433433
num_tokens = query.shape[0]
434434
if forward_context.capturing:
435+
# Prepare tensors for attention output
436+
# TODO: Refactor this to step-level instead of layer-level
437+
438+
# Get workspace from cache or calculate it if not present.
439+
workspace = graph_params.workspaces.get(num_tokens)
440+
if workspace is None:
441+
workspace = torch_npu._npu_paged_attention_get_workspace(
442+
query=query,
443+
key_cache=self.key_cache,
444+
value_cache=self.value_cache,
445+
num_kv_heads=self.num_kv_heads,
446+
num_heads=self.num_heads,
447+
scale_value=self.scale,
448+
block_table=attn_metadata.block_tables,
449+
context_lens=attn_metadata.seq_lens,
450+
out=output)
451+
graph_params.workspaces[num_tokens] = workspace
452+
453+
# Handle graph capturing mode
435454
stream = torch_npu.npu.current_stream()
436455

437456
event = torch.npu.ExternalEvent()
438457
event.wait(stream)
439458
event.reset(stream)
440459
graph_params.events[num_tokens].append(event)
441-
442460
graph_params.attn_params[num_tokens].append((
443461
query,
444462
self.key_cache,
@@ -461,7 +479,8 @@ def _forward_decode_only(
461479
scale_value=self.scale,
462480
block_table=attn_metadata.block_tables,
463481
context_lens=attn_metadata.seq_lens,
464-
out=output)
482+
out=output,
483+
workspace=workspace)
465484
handle = torch.npu.graph_task_group_end(stream)
466485
graph_params.handles[num_tokens].append(handle)
467486
else:

vllm_ascend/compilation/acl_graph.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -215,15 +215,17 @@ def update_attn_params(update_stream, forward_context, runtime_shape):
215215

216216
with torch.npu.stream(update_stream):
217217
torch.npu.graph_task_update_begin(update_stream, handle)
218-
torch_npu._npu_paged_attention(query=query,
219-
key_cache=key_cache,
220-
value_cache=value_cache,
221-
num_kv_heads=num_kv_heads,
222-
num_heads=num_heads,
223-
scale_value=scale,
224-
block_table=block_table,
225-
context_lens=seq_lens,
226-
out=output)
218+
torch_npu._npu_paged_attention(
219+
query=query,
220+
key_cache=key_cache,
221+
value_cache=value_cache,
222+
num_kv_heads=num_kv_heads,
223+
num_heads=num_heads,
224+
scale_value=scale,
225+
block_table=block_table,
226+
context_lens=seq_lens,
227+
out=output,
228+
workspace=graph_params.workspaces.get(runtime_shape))
227229
torch.npu.graph_task_update_end(update_stream)
228230

229231
event.record(update_stream)

0 commit comments

Comments
 (0)