Commit e9b4221

refactor: Move graph parameter logic to acl_graph module
Moves graph parameter management components, including `GraphParams`, `get_graph_params`, and `set_graph_params`, from the generic `utils.py` to the more specific `compilation/acl_graph.py`. Additionally, extracts the `update_attn_params` logic from the `NPUModelRunner` class into a standalone function within the `acl_graph` module. This refactoring improves code organization by centralizing ACL graph-related logic in its own dedicated module, enhancing modularity and clarity.

Signed-off-by: Yizhou Liu <liu_yizhou@outlook.com>
1 parent bb1f0d5 commit e9b4221
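In practical terms, call sites that previously imported these helpers from `vllm_ascend.utils` now import them from `vllm_ascend.compilation.acl_graph`. A minimal before/after sketch of the import change, with module paths taken from the diffs below:

```python
# Before this commit (these names are removed from vllm_ascend/utils.py):
# from vllm_ascend.utils import GraphParams, get_graph_params, set_graph_params

# After this commit:
from vllm_ascend.compilation.acl_graph import (GraphParams, get_graph_params,
                                               set_graph_params,
                                               update_attn_params)
```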

File tree

4 files changed (+84, -81 lines)

vllm_ascend/attention/attention_v1.py

Lines changed: 3 additions & 3 deletions
@@ -36,10 +36,10 @@
 from vllm.v1.kv_cache_interface import AttentionSpec

 from vllm_ascend.attention.utils import AscendCommonAttentionMetadata
+from vllm_ascend.compilation.acl_graph import get_graph_params
 from vllm_ascend.ops.attention import vanilla_chunked_prefill
-from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, aligned_16,
-                               get_graph_params, is_310p, nd_to_nz_2d,
-                               nd_to_nz_spec)
+from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, aligned_16, is_310p,
+                               nd_to_nz_2d, nd_to_nz_spec)


 def wait_for_kv_layer_from_connector(layer_name: str):

vllm_ascend/compilation/acl_graph.py

Lines changed: 73 additions & 0 deletions
@@ -3,10 +3,12 @@

 import dataclasses
 from contextlib import ExitStack
+from dataclasses import dataclass
 from typing import Any, Callable, Optional
 from unittest.mock import patch

 import torch
+import torch_npu
 import vllm.envs as envs
 from vllm.compilation.counter import compilation_counter
 from vllm.compilation.cuda_graph import CUDAGraphOptions
@@ -185,3 +187,74 @@ def __call__(self, *args, **kwargs):
         logger.info_once("Replaying aclgraph")
         entry.aclgraph.replay()
         return entry.output
+
+
+def update_attn_params(update_stream, forward_context, runtime_shape):
+    graph_params = get_graph_params()
+    # FIXME: Behold! We are using a temporary hack here to update the args
+    # for each layer's attention op in the graph.
+    for key, param, handle, event in zip(
+            forward_context.attn_metadata,
+            graph_params.attn_params[runtime_shape],
+            graph_params.handles[runtime_shape],
+            graph_params.events[runtime_shape],
+    ):
+        (
+            query,
+            key_cache,
+            value_cache,
+            num_kv_heads,
+            num_heads,
+            scale,
+            block_table,
+            seq_lens,
+            output,
+        ) = param
+        # block_table = forward_context.attn_metadata[key].block_tables
+        seq_lens = forward_context.attn_metadata[key].seq_lens
+
+        with torch.npu.stream(update_stream):
+            torch.npu.graph_task_update_begin(update_stream, handle)
+            torch_npu._npu_paged_attention(query=query,
+                                           key_cache=key_cache,
+                                           value_cache=value_cache,
+                                           num_kv_heads=num_kv_heads,
+                                           num_heads=num_heads,
+                                           scale_value=scale,
+                                           block_table=block_table,
+                                           context_lens=seq_lens,
+                                           out=output)
+            torch.npu.graph_task_update_end(update_stream)
+
+        event.record(update_stream)
+
+
+@dataclass
+class GraphParams:
+    events: dict[int, list[torch.npu.ExternalEvent]]
+    workspaces: dict[int, torch.Tensor]
+    handles: dict[int, list[torch_npu._C._NPUTaskGroupHandle]]
+    attn_params: dict[int, list[tuple]]
+
+
+_graph_params: Optional[GraphParams] = None
+
+
+def set_graph_params(aclgraph_capture_sizes: set[int]):
+    global _graph_params
+    if _graph_params is not None:
+        raise ValueError("Graph parameters have already been set!")
+    _graph_params = GraphParams(
+        {size: []
+         for size in aclgraph_capture_sizes},
+        {size: None
+         for size in aclgraph_capture_sizes},
+        {size: []
+         for size in aclgraph_capture_sizes},
+        {size: []
+         for size in aclgraph_capture_sizes},
+    )
+
+
+def get_graph_params():
+    return _graph_params
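A short usage sketch of the relocated helpers, based on the definitions above; the capture sizes are invented for illustration, and running it assumes an NPU-enabled environment with `torch_npu` installed:

```python
from vllm_ascend.compilation.acl_graph import (GraphParams, get_graph_params,
                                               set_graph_params)

set_graph_params({4, 16})              # one bucket per ACL graph capture size
params = get_graph_params()            # module-level singleton (None until set)
assert isinstance(params, GraphParams)
assert params.attn_params[16] == []    # per-shape launch args, appended at capture time
assert params.workspaces[16] is None   # workspace tensors are attached later

try:
    set_graph_params({4, 16})          # the singleton may only be initialized once
except ValueError as exc:
    print(exc)                         # "Graph parameters have already been set!"
```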

vllm_ascend/utils.py

Lines changed: 1 addition & 33 deletions
@@ -22,13 +22,12 @@
 import math
 import os
 from contextlib import contextmanager, nullcontext
-from dataclasses import dataclass
 from enum import Enum
 from threading import Lock
 from typing import TYPE_CHECKING, Any, List, Optional, Tuple, Union

 import torch
-import torch_npu  # noqa: F401 # noqa: F401
+import torch_npu  # noqa: F401
 from packaging.version import InvalidVersion, Version
 from torch_npu.npu.streams import Event
 from vllm.logger import logger
@@ -635,34 +634,3 @@ def npu_stream_switch(target_stream: torch.npu.Stream,
         return nullcontext()
     assert target_stream is not None
     return torch.npu.stream(target_stream)
-
-
-@dataclass
-class GraphParams:
-    events: dict[int, list[torch.npu.ExternalEvent]]
-    workspaces: dict[int, torch.Tensor]
-    handles: dict[int, list[torch_npu._C._NPUTaskGroupHandle]]
-    attn_params: dict[int, list[tuple]]
-
-
-_graph_params: Optional[GraphParams] = None
-
-
-def set_graph_params(aclgraph_capture_sizes: set[int]):
-    global _graph_params
-    if _graph_params is not None:
-        raise ValueError("Graph parameters have already been set!")
-    _graph_params = GraphParams(
-        {size: []
-         for size in aclgraph_capture_sizes},
-        {size: None
-         for size in aclgraph_capture_sizes},
-        {size: []
-         for size in aclgraph_capture_sizes},
-        {size: []
-         for size in aclgraph_capture_sizes},
-    )
-
-
-def get_graph_params():
-    return _graph_params

vllm_ascend/worker/model_runner_v1.py

Lines changed: 7 additions & 45 deletions
@@ -98,7 +98,9 @@
 from vllm_ascend.attention.attention_mask import AttentionMaskBuilder
 from vllm_ascend.attention.attention_v1 import AscendAttentionState
 from vllm_ascend.attention.utils import AscendCommonAttentionMetadata
-from vllm_ascend.compilation.acl_graph import ACLGraphWrapper
+from vllm_ascend.compilation.acl_graph import (ACLGraphWrapper,
+                                               set_graph_params,
+                                               update_attn_params)
 from vllm_ascend.eplb.adaptor.vllm_adaptor import VllmEplbAdaptor
 from vllm_ascend.eplb.core.eplb_device_transfer_loader import \
     D2DExpertWeightLoader
@@ -116,9 +118,8 @@
 from vllm_ascend.spec_decode.mtp_proposer import MtpProposer
 from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, ACL_FORMAT_FRACTAL_NZ,
                                AscendSocVersion, ProfileExecuteDuration,
-                               get_ascend_soc_version, get_graph_params,
-                               is_310p, lmhead_tp_enable, set_graph_params,
-                               vllm_version_is)
+                               get_ascend_soc_version, is_310p,
+                               lmhead_tp_enable, vllm_version_is)
 from vllm_ascend.worker.npu_input_batch import CachedRequestState, InputBatch

 if TYPE_CHECKING:
@@ -1570,9 +1571,8 @@ def _generate_process_reqs_hidden_states(self, attn_metadata, with_prefill,

         forward_context = get_forward_context()
         if forward_context.cudagraph_runtime_mode == CUDAGraphMode.FULL:
-            graph_params = get_graph_params()
-            self.update_attn_params(graph_params, forward_context,
-                                    positions.shape[0])
+            update_attn_params(self.update_stream, forward_context,
+                               positions.shape[0])

         if get_forward_context().flashcomm_v1_enabled:
             hidden_states = tensor_model_parallel_all_gather(hidden_states, 0)
@@ -1581,44 +1581,6 @@ def _generate_process_reqs_hidden_states(self, attn_metadata, with_prefill,
             hidden_states = hidden_states[:-pad_size, :]
         return hidden_states

-    def update_attn_params(self, graph_params, forward_context, runtime_shape):
-        # FIXME: Behold! We are using a temporary hack here to update the args
-        # for each layer's attention op in the graph.
-        for key, param, handle, event in zip(
-                forward_context.attn_metadata,
-                graph_params.attn_params[runtime_shape],
-                graph_params.handles[runtime_shape],
-                graph_params.events[runtime_shape],
-        ):
-            (
-                query,
-                key_cache,
-                value_cache,
-                num_kv_heads,
-                num_heads,
-                scale,
-                block_table,
-                seq_lens,
-                output,
-            ) = param
-            # block_table = forward_context.attn_metadata[key].block_tables
-            seq_lens = forward_context.attn_metadata[key].seq_lens
-
-            with torch.npu.stream(self.update_stream):
-                torch.npu.graph_task_update_begin(self.update_stream, handle)
-                torch_npu._npu_paged_attention(query=query,
-                                               key_cache=key_cache,
-                                               value_cache=value_cache,
-                                               num_kv_heads=num_kv_heads,
-                                               num_heads=num_heads,
-                                               scale_value=scale,
-                                               block_table=block_table,
-                                               context_lens=seq_lens,
-                                               out=output)
-                torch.npu.graph_task_update_end(self.update_stream)
-
-            event.record(self.update_stream)
-
     def _build_attn_state(self, num_reqs, num_scheduled_tokens,
                           num_valid_tokens):
         ascend_config = get_ascend_config()
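For reference, the standalone `update_attn_params` only needs a dedicated NPU stream, the current forward context, and the runtime shape used as the bucket key. The sketch below is a hedged illustration of a minimal call site; the stream creation and the literal runtime shape are assumptions, not code from this diff:

```python
import torch
import torch_npu  # noqa: F401  # registers the torch.npu backend

from vllm.forward_context import get_forward_context

from vllm_ascend.compilation.acl_graph import update_attn_params

update_stream = torch.npu.Stream()       # assumed: a dedicated side stream for task updates
forward_context = get_forward_context()  # must carry per-layer attn_metadata
runtime_shape = 16                       # assumed: padded token count matching a capture size

update_attn_params(update_stream, forward_context, runtime_shape)
```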
