
Commit 7a852b7

yiz-liu authored and mercykid committed
[Feat][Graph] Support FULL_DECODE_ONLY mode for GQA/MHA models (vllm-project#2128)
Note: This depends on [vLLM #25161](vllm-project/vllm#25161) and the torch_npu release from September 30.

### What this PR does / why we need it?

This pull request adds `FULL_DECODE_ONLY` mode for GQA/MHA models (MLA models like DeepSeek V3/R1 are not included). Key improvements include:

* **Reduced dispatch latency:** By replaying the entire model execution graph at once, we cut overhead compared with multiple smaller replays.
* **Stabilized multi-device performance:** Capturing the whole model as one static graph also mitigates dispatch fluctuations across devices.
* **Stream/resource savings:** Consolidating graph captures frees up streams, allowing more graphs to be captured.

**Known issues:**

1. `_npu_paged_attention` currently manages its own workspace in `torch_npu`, which can deadlock when synchronizing during graph replay; we're working on a fix.

There may be other corner cases. This PR is the first in a planned series; we'll continue to iterate and address remaining issues in follow-ups.

This is essentially a port of vllm-project#1503 and vllm-project#1677, but includes two major changes:

1. Let `graph_dispatcher` decide the graph mode instead of hard-coding it in the backend, which decouples Full Graph and Piecewise Graph and could make it possible to remove dynamo.
2. Adapt to the new `attn_group` logic, but leave a small hack in `update_graph_params`; multi-attention models may or may not be fully supported yet.

### Does this PR introduce _any_ user-facing change?

```python
compilation_config={
    "cudagraph_mode": "FULL_DECODE_ONLY",
},
```

### How was this patch tested?

Tests included.

- vLLM version: v0.10.2
- vLLM main: vllm-project/vllm@9607d5e

---------

Signed-off-by: Yizhou Liu <liu_yizhou@outlook.com>
Signed-off-by: Che Ruan <cr623@ic.ac.uk>
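As a hedged illustration (not code from this PR), the `compilation_config` fragment above would typically be supplied through vLLM's offline `LLM` entry point; the model name below is an arbitrary GQA placeholder.

```python
# Minimal sketch, assuming vLLM's offline LLM API; the model name is a
# placeholder for any GQA/MHA model (MLA models such as DeepSeek V3/R1
# are not covered by FULL_DECODE_ONLY yet).
from vllm import LLM, SamplingParams

llm = LLM(
    model="Qwen/Qwen2.5-7B-Instruct",  # placeholder model
    compilation_config={
        "cudagraph_mode": "FULL_DECODE_ONLY",
    },
)
outputs = llm.generate(["Hello, my name is"],
                       SamplingParams(max_tokens=32))
print(outputs[0].outputs[0].text)
```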
1 parent 60d3384 commit 7a852b7

File tree: 14 files changed (+390 -91 lines)


tests/ut/attention/test_attention_v1.py (12 additions, 5 deletions)

@@ -101,7 +101,7 @@ def test_build_prefill_no_cache(self, mock_is_310p, mock_nd_to_nz_2d,
             max_query_len=5,
             decode_token_per_req=torch.tensor([1, 1]),
             block_table_tensor=torch.zeros((10, 10)),
-            slot_mapping_cpu=torch.tensor(range(20)),
+            slot_mapping=torch.tensor(range(20)),
             actual_seq_lengths_q=torch.tensor([0, 1]),
             positions=torch.tensor([10, 10]),
             attn_mask=torch.ones((10, 10)),
@@ -134,7 +134,7 @@ def test_build_chunked_prefill(self, mock_ascend_attention_state,
             max_query_len=6,
             decode_token_per_req=torch.tensor([1, 1, 1]),
             block_table_tensor=torch.zeros((10, 10)),
-            slot_mapping_cpu=torch.tensor(range(20)),
+            slot_mapping=torch.tensor(range(20)),
             actual_seq_lengths_q=torch.tensor([0, 1, 2]),
             positions=torch.tensor([10, 10]),
             attn_mask=torch.ones((15, 15)),
@@ -165,7 +165,7 @@ def test_build_non_310p(self, mock_is_310p, mock_ascend_metadata):
             max_query_len=6,
             decode_token_per_req=torch.tensor([1, 1, 1]),
             block_table_tensor=torch.zeros((10, 10)),
-            slot_mapping_cpu=torch.tensor(range(20)),
+            slot_mapping=torch.tensor(range(20)),
             actual_seq_lengths_q=torch.tensor([0, 1, 2]),
             positions=torch.tensor([10, 10]),
             attn_mask=torch.ones((15, 15)),
@@ -378,10 +378,12 @@ def test_forward_prefill_cache_hit(self, mock_flash_attention_qlens,
         mock_flash_attention_qlens.assert_called_once()
         assert output.shape == (10, 8 * 64)

+    @patch('vllm_ascend.attention.attention_v1.get_forward_context')
     @patch('torch_npu._npu_reshape_and_cache')
     @patch('torch_npu._npu_paged_attention')
     def test_forward_decode_only(self, mock_paged_attention,
-                                 mock_npu_reshape_and_cache):
+                                 mock_npu_reshape_and_cache,
+                                 mock_get_forward_context):
         """Test forward pass in DecodeOnly state"""
         query = torch.randn(10, 8 * 64)
         key = torch.randn(10, 8 * 64)
@@ -395,6 +397,8 @@ def test_forward_decode_only(self, mock_paged_attention,
         metadata.slot_mapping = torch.zeros(10, dtype=torch.long)
         layer = self.layer_no_quant

+        mock_get_forward_context.return_value = MagicMock(capturing=False)
+
         output = self.impl.forward(layer,
                                    query,
                                    key,
@@ -435,12 +439,13 @@ def test_forward_decode_only_swa(self, mock_fused_infer_attention_score,
         mock_fused_infer_attention_score.assert_called_once()
         assert output.shape == (10, 8 * 64)

+    @patch('vllm_ascend.attention.attention_v1.get_forward_context')
     @patch('torch_npu._npu_reshape_and_cache')
     @patch('torch_npu._npu_paged_attention')
     @patch('torch_npu.npu_fused_infer_attention_score')
     def test_forward_decode_only_swa_seq_len_mismatch(
             self, mock_fused_infer_attention_score, mock_paged_attention,
-            mock_npu_reshape_and_cache):
+            mock_npu_reshape_and_cache, mock_get_forward_context):
         """Test forward pass in DecodeOnly state when seq)len_mismatch"""
         query = torch.randn(10, 8 * 64)
         key = torch.randn(10, 8 * 64)
@@ -457,6 +462,8 @@ def test_forward_decode_only_swa_seq_len_mismatch(
         mock_fused_infer_attention_score.return_value = (torch.ones(10, 8,
                                                                     64), 1)

+        mock_get_forward_context.return_value = MagicMock(capturing=False)
+
         output = self.impl_swa.forward(self.layer_no_quant,
                                        query,
                                        key,

tests/ut/torchair/test_torchair_mla.py (1 addition, 1 deletion)

@@ -463,7 +463,7 @@ def test_build_decode(self, mock_ascend_config):
             max_query_len=1,
             decode_token_per_req=torch.tensor([1, 1, 1]),
             block_table_tensor=torch.zeros((10, 10)),
-            slot_mapping_cpu=torch.tensor(range(20)),
+            slot_mapping=torch.tensor(range(20)),
             actual_seq_lengths_q=torch.tensor([0, 1, 2]),
             positions=torch.tensor([1, 1]),
             attn_mask=torch.ones((15, 15)),

vllm_ascend/attention/attention_v1.py (77 additions, 18 deletions)

@@ -31,13 +31,15 @@
     is_v1_kv_transfer_group)
 from vllm.forward_context import ForwardContext, get_forward_context
 from vllm.utils import cdiv, direct_register_custom_op
+from vllm.v1.attention.backends.utils import AttentionCGSupport
 from vllm.v1.core.sched.output import SchedulerOutput
 from vllm.v1.kv_cache_interface import AttentionSpec

 from vllm_ascend.attention.utils import AscendCommonAttentionMetadata
 from vllm_ascend.ops.attention import vanilla_chunked_prefill
-from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, aligned_16, is_310p,
-                               nd_to_nz_2d, nd_to_nz_spec)
+from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, aligned_16,
+                               get_graph_params, is_310p, nd_to_nz_2d,
+                               nd_to_nz_spec)


 def wait_for_kv_layer_from_connector(layer_name: str):
@@ -197,6 +199,12 @@ class AscendMetadata:


 class AscendAttentionMetadataBuilder:
+    # Does this backend/builder support CUDA Graphs for attention (default: no).
+    cudagraph_support: ClassVar[AttentionCGSupport] = \
+        AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
+    # Does this backend/builder reorder the batch?
+    # If not, set this to None. Otherwise set it to the query
+    # length that will be pulled into the front of the batch.
     reorder_batch_threshold: ClassVar[int] = 1

     def __init__(
@@ -221,7 +229,7 @@ def build(
         self,
         common_prefix_len: int,
         common_attn_metadata: AscendCommonAttentionMetadata,
-        model: nn.Module,
+        model: Optional[nn.Module] = None,
     ):
         num_reqs = common_attn_metadata.num_reqs
         num_actual_tokens = common_attn_metadata.num_actual_tokens
@@ -231,11 +239,7 @@
         block_table = common_attn_metadata.block_table_tensor
         query_lens = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
         seq_lens = common_attn_metadata.seq_lens_cpu[:num_reqs]
-        slot_mapping = common_attn_metadata.slot_mapping_cpu[:
-                                                             num_actual_tokens].to(
-                                                                 self.device,
-                                                                 non_blocking=
-                                                                 True)
+        slot_mapping = common_attn_metadata.slot_mapping[:num_actual_tokens]
         attn_mask = common_attn_metadata.attn_mask
         attn_state = common_attn_metadata.attn_state
         query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu[:
@@ -268,6 +272,24 @@
             is_only_prefill=common_attn_metadata.is_only_prefill)
         return attn_metadata

+    def build_for_graph_capture(
+        self,
+        common_attn_metadata: AscendCommonAttentionMetadata,
+        attn_state: AscendAttentionState = AscendAttentionState.DecodeOnly,
+    ):
+        if attn_state == AscendAttentionState.DecodeOnly:
+            attn_metadata = self.build(
+                common_prefix_len=0,
+                common_attn_metadata=common_attn_metadata,
+            )
+        else:
+            raise NotImplementedError(
+                "Currently we only support building dummy metadata for DecodeOnly state"
+            )
+
+        attn_metadata.attn_state = attn_state
+        return attn_metadata
+

 class AscendAttentionBackendImpl(AttentionImpl):

@@ -406,16 +428,53 @@ def _forward_decode_only(

             output = output.view(batch_size, self.num_heads, self.head_size)
         else:
-            torch_npu._npu_paged_attention(
-                query=query,
-                key_cache=self.key_cache,
-                value_cache=self.value_cache,
-                num_kv_heads=self.num_kv_heads,
-                num_heads=self.num_heads,
-                scale_value=self.scale,
-                block_table=attn_metadata.block_tables,
-                context_lens=attn_metadata.seq_lens,
-                out=output)
+            graph_params = get_graph_params()
+            forward_context: ForwardContext = get_forward_context()
+            num_tokens = query.shape[0]
+            if forward_context.capturing:
+                stream = torch_npu.npu.current_stream()
+
+                event = torch.npu.ExternalEvent()
+                event.wait(stream)
+                event.reset(stream)
+                graph_params.events[num_tokens].append(event)
+
+                graph_params.attn_params[num_tokens].append((
+                    query,
+                    self.key_cache,
+                    self.value_cache,
+                    self.num_kv_heads,
+                    self.num_heads,
+                    self.scale,
+                    attn_metadata.block_tables,
+                    attn_metadata.seq_lens,
+                    output,
+                ))
+
+                torch.npu.graph_task_group_begin(stream)
+                torch_npu._npu_paged_attention(
+                    query=query,
+                    key_cache=self.key_cache,
+                    value_cache=self.value_cache,
+                    num_kv_heads=self.num_kv_heads,
+                    num_heads=self.num_heads,
+                    scale_value=self.scale,
+                    block_table=attn_metadata.block_tables,
+                    context_lens=attn_metadata.seq_lens,
+                    out=output)
+                handle = torch.npu.graph_task_group_end(stream)
+                graph_params.handles[num_tokens].append(handle)
+            else:
+                torch_npu._npu_paged_attention(
+                    query=query,
+                    key_cache=self.key_cache,
+                    value_cache=self.value_cache,
+                    num_kv_heads=self.num_kv_heads,
+                    num_heads=self.num_heads,
+                    scale_value=self.scale,
+                    block_table=attn_metadata.block_tables,
+                    context_lens=attn_metadata.seq_lens,
+                    out=output)
         return output

     def _forward_v1_style(
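For orientation, the capture branch above records, per padded decode token count, an `ExternalEvent`, the launch arguments of `_npu_paged_attention`, and the task-group handle returned by `torch.npu.graph_task_group_end`, so the captured kernel's inputs can be refreshed before replay. Below is a hedged sketch of the container that `get_graph_params()` is assumed to return; the real definition lives in `vllm_ascend/utils.py` and is not part of this diff.

```python
# Hedged sketch (not the actual vllm_ascend.utils implementation): the shape of
# the per-process container that get_graph_params() is assumed to return, based
# on how _forward_decode_only uses it above. Each bucket is keyed by the padded
# decode token count, and every attention layer appends one entry per capture.
from collections import defaultdict
from dataclasses import dataclass, field
from typing import Any


@dataclass
class GraphParams:
    # ExternalEvent objects waited on before each captured kernel launch
    events: dict[int, list[Any]] = field(
        default_factory=lambda: defaultdict(list))
    # arguments of each _npu_paged_attention launch, kept so they can be
    # refreshed (new block tables / seq lens) before the graph is replayed
    attn_params: dict[int, list[tuple]] = field(
        default_factory=lambda: defaultdict(list))
    # task-group handles returned by torch.npu.graph_task_group_end(stream)
    handles: dict[int, list[Any]] = field(
        default_factory=lambda: defaultdict(list))


_graph_params = GraphParams()


def get_graph_params() -> GraphParams:
    return _graph_params
```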

vllm_ascend/attention/mla_v1.py (1 addition, 5 deletions)

@@ -292,11 +292,7 @@ def build(
         device = self.device

         block_table = (common_attn_metadata.block_table_tensor[:num_reqs])
-        slot_mapping = common_attn_metadata.slot_mapping_cpu[:
-                                                             num_actual_tokens].to(
-                                                                 device,
-                                                                 non_blocking=
-                                                                 True)
+        slot_mapping = common_attn_metadata.slot_mapping[:num_actual_tokens]
         input_positions = common_attn_metadata.positions[:
                                                          num_actual_tokens].long(
                                                          )

vllm_ascend/attention/utils.py (1 addition, 1 deletion)

@@ -41,7 +41,7 @@ class AscendCommonAttentionMetadata:

     block_table_tensor: torch.Tensor

-    slot_mapping_cpu: torch.Tensor
+    slot_mapping: torch.Tensor

     actual_seq_lengths_q: list[int]

vllm_ascend/compilation/acl_graph.py (1 addition, 0 deletions)

@@ -147,6 +147,7 @@ def __call__(self, *args, **kwargs):
                 patch("torch.npu.empty_cache", lambda: None))

             # mind-exploding: carefully manage the reference and memory.
+            forward_context.capturing = True
             with torch.npu.graph(aclgraph, pool=self.graph_pool):
                 # `output` is managed by pytorch's aclgraph pool
                 output = self.runnable(*args, **kwargs)

vllm_ascend/platform.py (29 additions, 12 deletions)

@@ -179,23 +179,13 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:

             compilation_config.cudagraph_num_of_warmups = 1

-            # TODO: make vllm support oot platform to set `compilation_config.cudagraph_mode`
-            # if cudagraph_mode is not explicitly set by users, set default value
-            if compilation_config.level == CompilationLevel.PIECEWISE:
-                compilation_config.cudagraph_mode = \
-                    CUDAGraphMode.PIECEWISE
-            elif compilation_config.level not in [
+            if compilation_config.level not in [
                     CompilationLevel.NO_COMPILATION, CompilationLevel.PIECEWISE
             ]:
                 logger.warning(
                     "NPU does not support %s compilation level. Setting CUDAGraphMode to NONE",
                     compilation_config.level)
                 compilation_config.cudagraph_mode = CUDAGraphMode.NONE
-            else:
-                logger.warning(
-                    "compilation_config.level = CompilationLevel.NO_COMPILATION is set, Setting CUDAGraphMode to NONE"
-                )
-                compilation_config.cudagraph_mode = CUDAGraphMode.NONE

         # set CUDAGraphMode to None when torchair is enabled, no mather what compilation_config.level is.
         if ascend_config.torchair_graph_config.enabled:
@@ -221,7 +211,12 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:

         if compilation_config.cudagraph_mode == CUDAGraphMode.NONE:
             compilation_config.level = CompilationLevel.NO_COMPILATION
-        elif compilation_config.cudagraph_mode == CUDAGraphMode.PIECEWISE:
+        # TODO: Currently MLA does not support FULL_DECODE_ONLY, remove the second condition
+        # after MLA being supported
+        elif compilation_config.cudagraph_mode == CUDAGraphMode.PIECEWISE or (
+                compilation_config.cudagraph_mode
+                == CUDAGraphMode.FULL_DECODE_ONLY and model_config is not None
+                and model_config.use_mla):
             logger.info(
                 "PIECEWISE compilation enabled on NPU. use_inductor not supported - "
                 "using only ACL Graph mode")
@@ -233,6 +228,24 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
                 "vllm.unified_ascend_attention_with_output", "vllm.mla_forward"
             ])
             update_aclgraph_sizes(vllm_config)
+        elif compilation_config.cudagraph_mode == CUDAGraphMode.FULL_DECODE_ONLY:
+            logger.info(
+                "FULL_DECODE_ONLY compilation enabled on NPU. use_inductor not supported - "
+                "using only ACL Graph mode")
+            compilation_config.use_inductor = False
+            warning_message = """\033[91m
+    **********************************************************************************
+    * WARNING: You have enabled the *full graph* feature.
+    * This is an early experimental stage and may involve various unknown issues.
+    * A known problem is that capturing too many batch sizes can lead to OOM
+    * (Out of Memory) errors or inference hangs. If you encounter such issues,
+    * consider reducing `gpu_memory_utilization` or manually specifying a smaller
+    * batch size for graph capture.
+    * For more details, please refer to:
+    * https://docs.vllm.ai/en/stable/configuration/conserving_memory.html#reduce-cuda-graphs
+    **********************************************************************************\033[0m
+    """
+            logger.warning(warning_message)
         else:
             logger.info(
                 "%s cudagraph_mode is not support on NPU. falling back to NONE",
@@ -379,3 +392,7 @@ def stateless_init_device_torch_dist_pg(
     @classmethod
     def support_hybrid_kv_cache(cls) -> bool:
         return True
+
+    @classmethod
+    def support_static_graph_mode(cls) -> bool:
+        return True
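To make the branching in `check_and_update_config` easier to follow, here is a hedged sketch that reduces the graph-mode selection to a pure function; the real code mutates vLLM's `CompilationConfig` in place, and the string values below stand in for the `CUDAGraphMode` enum members.

```python
# Hedged sketch of the dispatch logic above, reduced to a pure function for
# readability. It only reports which graph mode the NPU platform ends up with.
def resolve_npu_graph_mode(cudagraph_mode: str, use_mla: bool) -> str:
    if cudagraph_mode == "NONE":
        return "NONE"                      # eager execution, no compilation
    if cudagraph_mode == "PIECEWISE":
        return "PIECEWISE"                 # piecewise ACL graph capture
    if cudagraph_mode == "FULL_DECODE_ONLY":
        # MLA models (e.g. DeepSeek V3/R1) are not supported yet and fall back
        # to the piecewise path; GQA/MHA models get the full decode-only graph.
        return "PIECEWISE" if use_mla else "FULL_DECODE_ONLY"
    return "NONE"                          # unsupported modes fall back to NONE


assert resolve_npu_graph_mode("FULL_DECODE_ONLY", use_mla=False) == "FULL_DECODE_ONLY"
assert resolve_npu_graph_mode("FULL_DECODE_ONLY", use_mla=True) == "PIECEWISE"
```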

vllm_ascend/spec_decode/eagle_proposer.py (2 additions, 2 deletions)

@@ -347,7 +347,7 @@ def _get_eagle_atten_dict(
             actual_seq_lengths_q=self.runner.actual_seq_lengths_q,
             block_table_tensor=self.runner.input_batch.block_table[0].
             get_device_tensor(),
-            slot_mapping_cpu=self.runner.slot_mapping_cpu,
+            slot_mapping=self.runner.slot_mapping,
             positions=self.runner.positions,
             attn_mask=self.runner.attn_mask,
             spec_attn_mask=self.runner.spec_attn_mask,
@@ -434,7 +434,7 @@ def _propose(
             actual_seq_lengths_q=self.runner.actual_seq_lengths_q,
             block_table_tensor=self.runner.input_batch.block_table[0].
             get_device_tensor(),
-            slot_mapping_cpu=target_slot_mapping,
+            slot_mapping=target_slot_mapping,
             positions=target_positions,
             attn_mask=self.runner.attn_mask,
             spec_attn_mask=self.runner.spec_attn_mask,

vllm_ascend/spec_decode/mtp_proposer.py (1 addition, 1 deletion)

@@ -385,7 +385,7 @@ def _propose(
             actual_seq_lengths_q=self.runner.actual_seq_lengths_q,
             block_table_tensor=self.runner.input_batch.block_table[0].
             get_device_tensor(),
-            slot_mapping_cpu=target_slot_mapping,
+            slot_mapping=target_slot_mapping,
             positions=target_positions,
             attn_mask=self.runner.attn_mask,
             spec_attn_mask=self.runner.spec_attn_mask,

vllm_ascend/torchair/torchair_attention.py (2 additions, 6 deletions)

@@ -175,7 +175,7 @@ def build(
         self,
         common_prefix_len: int,
         common_attn_metadata: AscendCommonAttentionMetadata,
-        model: nn.Module,
+        model: Optional[nn.Module] = None,
     ):
         num_reqs = common_attn_metadata.num_reqs
         num_actual_tokens = common_attn_metadata.num_actual_tokens
@@ -185,11 +185,7 @@
             block_table[:num_reqs])

         seq_lens = common_attn_metadata.seq_lens_cpu[:num_reqs]
-        slot_mapping = common_attn_metadata.slot_mapping_cpu[:
-                                                             num_actual_tokens].to(
-                                                                 self.device,
-                                                                 non_blocking=
-                                                                 True)
+        slot_mapping = common_attn_metadata.slot_mapping[:num_actual_tokens]
         attn_mask = common_attn_metadata.attn_mask

         attn_state = common_attn_metadata.attn_state
