8
8
from vllm .attention .layer import Attention
9
9
from vllm .config import (VllmConfig , get_layers_from_vllm_config ,
10
10
set_current_vllm_config )
11
- from vllm .forward_context import get_forward_context
11
+ from vllm .forward_context import BatchDescriptor , get_forward_context
12
12
from vllm .model_executor .model_loader import get_model_loader
13
13
from vllm .model_executor .model_loader .utils import (
14
14
process_weights_after_loading , set_default_torch_dtype )
@@ -363,8 +363,14 @@ def _propose(
363
363
not self .runner .with_prefill
364
364
365
365
if is_running_torchair :
366
+ # Torchair graph mode, padding is same as the main model
366
367
num_input_tokens = self .runner .graph_pad_size
368
+ elif (self .runner .use_aclgraph
369
+ and num_tokens <= self .runner .aclgraph_batch_sizes [- 1 ]):
370
+ # Acl graph mode, add padding to the batch size
371
+ num_input_tokens = self .vllm_config .pad_for_cudagraph (num_tokens )
367
372
else :
373
+ # Eager mode, no padding needed
368
374
num_input_tokens = num_tokens
369
375
370
376
seq_lens = target_positions [last_token_indices ] + 1
@@ -410,14 +416,18 @@ def _propose(
410
416
# TODO: adapt enable_dbo later
411
417
(num_input_tokens , num_tokens_across_dp , with_prefill ,
412
418
_ ) = self .runner ._sync_metadata_across_dp (
413
- num_tokens , self .runner .with_prefill , False )
419
+ num_input_tokens , self .runner .with_prefill , False )
414
420
else :
415
421
# torchair mode can reuse self.runner.num_tokens_across_dp
416
422
num_tokens_across_dp = self .runner .num_tokens_across_dp
417
423
with_prefill = self .runner .with_prefill
418
424
419
425
moe_comm_method = self .runner ._select_moe_comm_method (
420
426
num_input_tokens , with_prefill )
427
+ batch_descriptor = BatchDescriptor (num_tokens = num_input_tokens ,
428
+ uniform_decode = False )
429
+ aclgraph_runtime_mode , batch_descriptor = \
430
+ self .runner .aclgraph_dispatcher .dispatch (batch_descriptor )
421
431
422
432
for step in range (self .num_speculative_tokens ):
423
433
with set_ascend_forward_context (
@@ -428,6 +438,7 @@ def _propose(
428
438
num_tokens_across_dp = num_tokens_across_dp ,
429
439
reserved_mc2_mask = self .runner .reserved_mc2_mask ,
430
440
moe_comm_method = moe_comm_method ,
441
+ aclgraph_runtime_mode = aclgraph_runtime_mode ,
431
442
in_profile_run = self .runner .in_profile_run ,
432
443
num_actual_tokens = num_tokens ):
433
444
with ProfileExecuteDuration ().capture_async ('mtp_forward' ):
0 commit comments