
Commit 6681dde

[Feat][Graph] Support MTP for ACL Graph (#2932)
### What this PR does / why we need it?

This PR depends on the merge of #2707 and has adapted the aclgraph functionality to support MTP.

### How was this patch tested?

- vLLM version: v0.10.2
- vLLM main: vllm-project/vllm@2b85697

Signed-off-by: xuyexiong <xuyexiong@huawei.com>
1 parent cef43b5 commit 6681dde

File tree

7 files changed: +73 -11 lines changed

tests/e2e/singlecard/spec_decode_v1/test_v1_mtp_correctness.py

Lines changed: 2 additions & 2 deletions

```diff
@@ -39,7 +39,7 @@ def mtp_correctness(
             tensor_parallel_size=1,
             gpu_memory_utilization=0.7,
             max_model_len=256,
-            enforce_eager=True) as ref_llm:
+            enforce_eager=False) as ref_llm:
         ref_outputs = ref_llm.generate(example_prompts, sampling_config)
 
     with VllmRunner(
@@ -53,7 +53,7 @@ def mtp_correctness(
                 "method": "deepseek_mtp",
                 "num_speculative_tokens": num_speculative_tokens,
             },
-            enforce_eager=True,
+            enforce_eager=False,
            max_model_len=2000,
            additional_config={"ascend_scheduler_config": {
                "enabled": False
```

tests/ut/attention/test_mla_v1.py

Lines changed: 30 additions & 0 deletions

```diff
@@ -186,6 +186,34 @@ def test_ascend_mla_metadata_builder_default(self):
         mock_vllm_config.scheduler_config.chunked_prefill_enabled = False
         mock_device = 'cpu'
 
+        mock_vllm_config.speculative_config = None
+
+        ascend_config = MagicMock()
+        with patch("vllm_ascend.attention.mla_v1.get_ascend_config",
+                   return_value=ascend_config):
+            builder = AscendMLAMetadataBuilder(None, None, mock_vllm_config,
+                                               mock_device)
+
+        self.assertEqual(builder.block_size,
+                         mock_vllm_config.cache_config.block_size)
+        self.assertEqual(
+            builder.chunked_prefill_enabled,
+            mock_vllm_config.scheduler_config.chunked_prefill_enabled)
+
+    def test_ascend_mla_metadata_builder_spec_decode(self):
+        mock_vllm_config = MagicMock()
+        mock_vllm_config.model_config.max_model_len = 1024
+        mock_vllm_config.model_config.get_head_size.return_value = 64
+        mock_vllm_config.model_config.dtype = torch.float16
+        mock_vllm_config.cache_config.block_size = 16
+        mock_vllm_config.scheduler_config.max_num_seqs = 4
+        mock_vllm_config.scheduler_config.chunked_prefill_enabled = False
+        mock_device = 'cpu'
+
+        mock_spec_config = MagicMock()
+        mock_spec_config.num_speculative_tokens = 3
+        mock_vllm_config.speculative_config = mock_spec_config
+
         ascend_config = MagicMock()
         with patch("vllm_ascend.attention.mla_v1.get_ascend_config",
                    return_value=ascend_config):
@@ -208,6 +236,8 @@ def test_reorder_batch(self):
         mock_vllm_config.scheduler_config.chunked_prefill_enabled = False
         mock_device = 'cpu'
 
+        mock_vllm_config.speculative_config = None
+
         with patch("vllm_ascend.attention.mla_v1.get_ascend_config",
                    return_value=ascend_config):
             builder = AscendMLAMetadataBuilder(None, None, mock_vllm_config,
```
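The `mock_vllm_config.speculative_config = None` assignments here (and the matching ones added to the torchair tests below) are needed because the metadata builder's `__init__` now reads `vllm_config.speculative_config` (see `mla_v1.py` below); on a bare `MagicMock` that attribute would be an auto-created, truthy mock and the builder would wrongly take the speculative-decoding branch.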

tests/ut/torchair/test_torchair_mla.py

Lines changed: 16 additions & 0 deletions

```diff
@@ -190,6 +190,8 @@ def test_ascend_mla_metadata_builder_default(self):
         mock_vllm_config.scheduler_config.chunked_prefill_enabled = False
         mock_device = 'cpu'
 
+        mock_vllm_config.speculative_config = None
+
         ascend_config = MagicMock()
         ascend_config.torchair_graph_config = MagicMock()
         ascend_config.torchair_graph_config.enabled = True
@@ -217,6 +219,8 @@ def test_reorder_batch_with_torchair_graph(self, ascend_config):
         ascend_config.torchair_graph_config = MagicMock()
         ascend_config.torchair_graph_config.enabled = True
 
+        mock_vllm_config.speculative_config = None
+
         builder = AscendMLATorchairMetadataBuilder(None, None,
                                                    mock_vllm_config,
                                                    mock_device)
@@ -252,6 +256,8 @@ def test_reorder_batch_without_torchair_graph(self):
         mock_vllm_config.scheduler_config.chunked_prefill_enabled = False
         mock_device = 'cpu'
 
+        mock_vllm_config.speculative_config = None
+
         with patch("vllm_ascend.torchair.torchair_mla.get_ascend_config",
                    return_value=ascend_config):
             builder = AscendMLATorchairMetadataBuilder(None, None,
@@ -288,6 +294,8 @@ def test_get_graph_runner_block_tables_normal(self, mock_ascend_config):
         mock_vllm_config.scheduler_config.chunked_prefill_enabled = False
         mock_device = 'cpu'
 
+        mock_vllm_config.speculative_config = None
+
         builder = AscendMLATorchairMetadataBuilder(None, None,
                                                    mock_vllm_config,
                                                    mock_device)
@@ -309,6 +317,8 @@ def test_get_graph_runner_block_tables_truncated(self, mock_ascend_config):
         mock_vllm_config.scheduler_config.chunked_prefill_enabled = False
         mock_device = 'cpu'
 
+        mock_vllm_config.speculative_config = None
+
         builder = AscendMLATorchairMetadataBuilder(None, None,
                                                    mock_vllm_config,
                                                    mock_device)
@@ -331,6 +341,8 @@ def test_get_graph_runner_block_tables_from_numpy(self,
         mock_vllm_config.scheduler_config.chunked_prefill_enabled = False
         mock_device = 'cpu'
 
+        mock_vllm_config.speculative_config = None
+
         builder = AscendMLATorchairMetadataBuilder(None, None,
                                                    mock_vllm_config,
                                                    mock_device)
@@ -357,6 +369,8 @@ def test_build_dummy(self, mock_ascend_config):
         mock_vllm_config.model_config.dtype = torch.float16
         mock_device = 'cpu'
 
+        mock_vllm_config.speculative_config = None
+
         builder = AscendMLATorchairMetadataBuilder(
             None,
             None,
@@ -424,6 +438,8 @@ def test_build_decode(self, mock_ascend_config):
         model = MagicMock(spec=nn.Module)
         model.model = MagicMock(spec=nn.Module)
 
+        mock_vllm_config.speculative_config = None
+
         builder = AscendMLATorchairMetadataBuilder(
             None,
             None,
```

vllm_ascend/attention/mla_v1.py

Lines changed: 7 additions & 1 deletion

```diff
@@ -187,7 +187,14 @@ def __init__(self,
                           self.block_size - 1) // self.block_size
         self.chunked_prefill_enabled = scheduler_config.chunked_prefill_enabled
 
+        self.speculative_config = vllm_config.speculative_config
         self.decode_threshold = 1
+        if self.speculative_config:
+            spec_token_num = self.speculative_config.num_speculative_tokens
+            self.decode_threshold += spec_token_num
+            assert self.decode_threshold <= 16, f"decode_threshold exceeded \
+                npu_fused_infer_attention_score TND layout's limit of 16, \
+                got {self.decode_threshold}"
 
         if self.chunked_prefill_enabled:
             self.chunked_prefill_workspace_size = min(
@@ -275,7 +282,6 @@ def build(
         num_actual_tokens = common_attn_metadata.num_actual_tokens
         query_start_loc = common_attn_metadata.query_start_loc
         query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
-        # TODO(xyx): remove the if condition after mla supports torch mode speculative decoding
         num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = \
             split_decodes_and_prefills(common_attn_metadata, decode_threshold=self.decode_threshold)
         assert num_decodes + num_prefills == num_reqs
```
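The `decode_threshold` change is the heart of the attention-side adaptation: with MTP, a decode step carries one accepted token plus `num_speculative_tokens` draft tokens per request, so a request with up to `1 + num_speculative_tokens` query tokens must still be classified as a decode, and the `npu_fused_infer_attention_score` TND layout caps that at 16. The sketch below is an illustrative reimplementation of what a `decode_threshold`-based split does, not vLLM's actual `split_decodes_and_prefills`; the `query_start_loc` layout and the decodes-before-prefills ordering are assumptions carried over from the metadata handling above.

```python
# Minimal sketch (not the vLLM implementation) of a decode_threshold-based
# split. query_start_loc holds cumulative token offsets per request, as in
# vLLM's common attention metadata.

def split_decodes_and_prefills_sketch(query_start_loc: list[int],
                                      decode_threshold: int = 1):
    """Classify each request as decode or prefill by its query length.

    Assumes the batch is reordered so all decodes precede all prefills.
    With MTP, decode_threshold = 1 + num_speculative_tokens.
    """
    num_decodes = num_decode_tokens = 0
    num_prefills = num_prefill_tokens = 0
    for i in range(len(query_start_loc) - 1):
        query_len = query_start_loc[i + 1] - query_start_loc[i]
        if query_len <= decode_threshold and num_prefills == 0:
            num_decodes += 1
            num_decode_tokens += query_len
        else:
            num_prefills += 1
            num_prefill_tokens += query_len
    return num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens


# Example: two MTP decode requests (4 tokens each, i.e. 3 draft tokens)
# followed by one 128-token prefill.
print(split_decodes_and_prefills_sketch([0, 4, 8, 136], decode_threshold=4))
# -> (2, 1, 8, 128)
```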

vllm_ascend/spec_decode/mtp_proposer.py

Lines changed: 13 additions & 2 deletions

```diff
@@ -8,7 +8,7 @@
 from vllm.attention.layer import Attention
 from vllm.config import (VllmConfig, get_layers_from_vllm_config,
                          set_current_vllm_config)
-from vllm.forward_context import get_forward_context
+from vllm.forward_context import BatchDescriptor, get_forward_context
 from vllm.model_executor.model_loader import get_model_loader
 from vllm.model_executor.model_loader.utils import (
     process_weights_after_loading, set_default_torch_dtype)
@@ -363,8 +363,14 @@ def _propose(
             not self.runner.with_prefill
 
         if is_running_torchair:
+            # Torchair graph mode, padding is same as the main model
             num_input_tokens = self.runner.graph_pad_size
+        elif (self.runner.use_aclgraph
+              and num_tokens <= self.runner.aclgraph_batch_sizes[-1]):
+            # Acl graph mode, add padding to the batch size
+            num_input_tokens = self.vllm_config.pad_for_cudagraph(num_tokens)
         else:
+            # Eager mode, no padding needed
             num_input_tokens = num_tokens
 
         seq_lens = target_positions[last_token_indices] + 1
@@ -410,14 +416,18 @@ def _propose(
             # TODO: adapt enable_dbo later
             (num_input_tokens, num_tokens_across_dp, with_prefill,
              _) = self.runner._sync_metadata_across_dp(
-                 num_tokens, self.runner.with_prefill, False)
+                 num_input_tokens, self.runner.with_prefill, False)
         else:
             # torchair mode can reuse self.runner.num_tokens_across_dp
             num_tokens_across_dp = self.runner.num_tokens_across_dp
             with_prefill = self.runner.with_prefill
 
         moe_comm_method = self.runner._select_moe_comm_method(
             num_input_tokens, with_prefill)
+        batch_descriptor = BatchDescriptor(num_tokens=num_input_tokens,
+                                           uniform_decode=False)
+        aclgraph_runtime_mode, batch_descriptor = \
+            self.runner.aclgraph_dispatcher.dispatch(batch_descriptor)
 
         for step in range(self.num_speculative_tokens):
             with set_ascend_forward_context(
@@ -428,6 +438,7 @@ def _propose(
                     num_tokens_across_dp=num_tokens_across_dp,
                     reserved_mc2_mask=self.runner.reserved_mc2_mask,
                     moe_comm_method=moe_comm_method,
+                    aclgraph_runtime_mode=aclgraph_runtime_mode,
                     in_profile_run=self.runner.in_profile_run,
                     num_actual_tokens=num_tokens):
                 with ProfileExecuteDuration().capture_async('mtp_forward'):
```
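The new three-way branch decides how far the drafter's batch is padded before each speculative step. A hedged sketch of that policy follows; `graph_pad_size`, `use_aclgraph`, and `aclgraph_batch_sizes` are the runner attributes named in the diff, while modeling `pad_for_cudagraph` as a round-up to the next captured batch size is an assumption made for illustration.

```python
# Hedged sketch of the draft-model padding policy, with runner state passed
# in as plain arguments. Assumption: pad_for_cudagraph rounds up to the next
# captured graph batch size (aclgraph_batch_sizes sorted ascending).
def pick_num_input_tokens(num_tokens: int,
                          is_running_torchair: bool,
                          use_aclgraph: bool,
                          aclgraph_batch_sizes: list[int],
                          graph_pad_size: int) -> int:
    if is_running_torchair:
        # Torchair graph mode: reuse the main model's padded size.
        return graph_pad_size
    if use_aclgraph and num_tokens <= aclgraph_batch_sizes[-1]:
        # ACL graph mode: pad up so a captured graph can be replayed.
        return next(s for s in aclgraph_batch_sizes if s >= num_tokens)
    # Eager mode: no padding needed.
    return num_tokens

# With captured sizes [8, 16, 32]: 10 tokens pad to 16; 40 tokens run eager.
assert pick_num_input_tokens(10, False, True, [8, 16, 32], 0) == 16
assert pick_num_input_tokens(40, False, True, [8, 16, 32], 0) == 40
```

Note also the fix in the data-parallel sync: the padded `num_input_tokens`, not the raw `num_tokens`, is now passed to `_sync_metadata_across_dp`, and the `aclgraph_runtime_mode` produced by the dispatcher is threaded into `set_ascend_forward_context` so every speculative step runs under the dispatched graph mode.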

vllm_ascend/torchair/torchair_model_runner.py

Lines changed: 4 additions & 0 deletions

```diff
@@ -52,6 +52,10 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):
         ascend_config = get_ascend_config()
         self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp
         super().__init__(vllm_config, device)
+        if self.speculative_config:
+            self.actual_seq_lengths_q = list(
+                range(self.decode_token_per_req, self.max_num_tokens + 1,
+                      self.decode_token_per_req))
         self.attn_metadata_builder = self.attn_backend.get_builder_cls()(
             None, None, vllm_config, device)
```

vllm_ascend/worker/model_runner_v1.py

Lines changed: 1 addition & 6 deletions

```diff
@@ -306,17 +306,12 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):
         self.spec_attn_mask = None
         self.drafter: Optional[Union[NgramProposer, EagleProposer,
                                      MtpProposer]] = None
-        self.actual_seq_lengths_q = []
+        self.actual_seq_lengths_q: list[int] = []
         self.decode_token_per_req = 1
         if self.speculative_config:
             spec_token_num = self.speculative_config.num_speculative_tokens
             assert spec_token_num > 0
             self.decode_token_per_req = 1 + spec_token_num
-            self.actual_seq_lengths_q = [
-                len for len in
-                range(self.decode_token_per_req, self.max_num_tokens +
-                      1, self.decode_token_per_req)
-            ]
             self.spec_attn_mask = torch.triu(torch.ones(2048,
                                                         2048,
                                                         dtype=torch.bool),
```
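This removal pairs with the `torchair_model_runner.py` change above: the common runner now keeps `actual_seq_lengths_q` as an empty typed list, and only the torchair runner rebuilds the schedule when a speculative config is present, since in this PR the fixed query-length schedule is needed only on the torchair graph path.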
