@@ -28,10 +28,12 @@
 import numpy as np
 import numpy.typing as npt
 import torch
+import torch._dynamo.cache_size
 import torch.nn as nn
 from vllm.attention import AttentionType, get_attn_backend
 from vllm.attention.layer import Attention
 from vllm.config import CompilationLevel, VllmConfig
+from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.distributed.parallel_state import get_pp_group
 from vllm.forward_context import set_forward_context
 from vllm.inputs import INPUT_REGISTRY
@@ -69,7 +71,9 @@
 else:
     xgr = LazyLoader("xgr", globals(), "xgrammar")
 
-import vllm.envs as envs
+import vllm.envs as envs_vllm
+
+import vllm_ascend.envs as envs_ascend
 
 
 @dataclass
@@ -321,13 +325,39 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):
         self.sampler = Sampler()
         self.enable_torchair_graph_mode = False
         self.use_cached_npu_graph = False
+        self.torchair_graph_batch_sizes = []
         additional_config = vllm_config.additional_config
         if additional_config:
             self.enable_torchair_graph_mode = additional_config.get(
                 "enable_graph_mode",
                 False) and self.vllm_config.model_config.use_mla
             self.use_cached_npu_graph = additional_config.get(
                 "use_cached_npu_graph", False)
+            if additional_config.get("trace_recompiles", False):
+                torch._logging.set_logs(recompiles=True)
+            self.torchair_graph_batch_sizes = additional_config.get(
+                "torchair_graph_batch_sizes", [])
+            if not isinstance(self.torchair_graph_batch_sizes, list):
+                logger.warning("torchair_graph_batch_sizes must be list[int]")
+                self.torchair_graph_batch_sizes = []
+            if len(self.torchair_graph_batch_sizes
+                   ) == 0 and additional_config.get(
+                       "init_torchair_graph_batch_sizes", False):
+                self.init_torchair_graph_batch_sizes()
+
+        if len(self.torchair_graph_batch_sizes) == 0:
+            # If MC2 is enabled, torchair_graph_batch_size should pad to tp_size
+            if envs_ascend.VLLM_ENABLE_MC2:
+                self.torchair_graph_batch_sizes = [
+                    self.scheduler_config.max_num_seqs
+                ]
+            else:
+                self.torchair_graph_batch_sizes = [
+                    1, self.scheduler_config.max_num_seqs
+                ]
+
+        torch._dynamo.cache_size.config.cache_size_limit += len(
+            self.torchair_graph_batch_sizes)
 
     def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
         """Update the cached states and the persistent batch with the scheduler
@@ -605,7 +635,10 @@ def _process_reqs(
 
         # Add graph_pad_size here
         if self.enable_torchair_graph_mode:
-            graph_pad_size = self.scheduler_config.max_num_seqs - len(seq_lens)
+            batchsize = len(seq_lens)
+            padded_batch_size = self.select_torchair_padded_batchsize(
+                batchsize)
+            graph_pad_size = padded_batch_size - batchsize
             extra_builder_kwargs['graph_pad_size'] = graph_pad_size
 
         attn_metadata = self.attn_metadata_builder.build(  # type: ignore
@@ -630,11 +663,8 @@ def _process_reqs(
         input_ids = self.input_ids[:num_input_tokens]
 
         if self.enable_torchair_graph_mode and attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
-            padding = torch.zeros(graph_pad_size,
-                                  dtype=input_ids.dtype,
-                                  device=input_ids.device)
-            input_ids = torch.cat([input_ids, padding])
-            positions = torch.cat([positions, padding])
+            input_ids = self.input_ids[:padded_batch_size]
+            positions = self.positions[:padded_batch_size]
 
         # Run forward pass
         with set_forward_context(attn_metadata,
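Note: the slicing above relies on the runner keeping persistent, max-sized self.input_ids and self.positions buffers, so padding a decode batch up to a captured torchair batch size becomes a longer view of those buffers instead of a torch.cat with a zero tensor; the extra slots are dummy entries whose outputs are discarded. A minimal standalone sketch of the idea, with illustrative names and sizes only (not the runner's actual API):

import torch

# Illustration only: a persistent, max-sized buffer stands in for self.input_ids.
max_num_reqs = 32
input_ids_buf = torch.zeros(max_num_reqs, dtype=torch.int64)

batch_size = 5          # decode requests actually scheduled this step
padded_batch_size = 8   # nearest captured graph batch size (see the helpers added below)
graph_pad_size = padded_batch_size - batch_size  # 3 dummy decode slots

# Old approach: materialize the padding and concatenate.
padded_old = torch.cat([input_ids_buf[:batch_size],
                        torch.zeros(graph_pad_size, dtype=torch.int64)])

# New approach: take a longer view of the preallocated buffer; whatever sits in
# the extra slots is ignored because those positions' outputs are dropped.
padded_new = input_ids_buf[:padded_batch_size]
assert padded_old.shape == padded_new.shape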
@@ -1039,7 +1069,11 @@ def _profile_multimodal(self) -> None:
         self.encoder_cache["tmp"] = dict(enumerate(dummy_encoder_outputs))
 
     @torch.inference_mode()
-    def _dummy_run(self, num_tokens: int) -> torch.Tensor:
+    def _dummy_run(
+        self,
+        num_tokens: int,
+        attn_state: AscendAttentionState = AscendAttentionState.ChunkedPrefill,
+    ) -> torch.Tensor:
         # Set num_scheduled_tokens based on num_tokens and max_num_seqs
         # for dummy run with LoRA so that the num_reqs collectively
         # has num_tokens in total.
@@ -1083,12 +1117,35 @@ def _dummy_run(self, num_tokens: int) -> torch.Tensor:
             })
 
         with set_forward_context(None, self.vllm_config):
-            hidden_states = model(
-                input_ids=input_ids,
-                positions=positions,
-                intermediate_tensors=intermediate_tensors,
-                inputs_embeds=inputs_embeds)
-        return hidden_states
+            if self.enable_torchair_graph_mode and attn_state == AscendAttentionState.DecodeOnly:
+                attn_metadata = self.attn_metadata_builder.build_dummy(
+                    num_reqs=num_tokens, num_actual_tokens=1)
+                torch._dynamo.mark_static(input_ids)
+                torch._dynamo.mark_static(positions)
+                torch._dynamo.mark_static(attn_metadata.decode.block_table)
+                torch._dynamo.mark_static(
+                    attn_metadata.decode.input_positions)
+                torch._dynamo.mark_static(attn_metadata.slot_mapping)
+                for kv in self.kv_caches:
+                    assert isinstance(kv,
+                                      tuple), "kv_cache must be a tuple"
+                    torch._dynamo.mark_static(kv[0])
+                    torch._dynamo.mark_static(kv[1])
+                hidden_states = self.compile_model(
+                    input_ids=input_ids,
+                    positions=positions,
+                    intermediate_tensors=intermediate_tensors,
+                    inputs_embeds=None,
+                    kv_caches=self.kv_caches,
+                    attn_metadata=attn_metadata,
+                )
+            else:
+                hidden_states = model(
+                    input_ids=input_ids,
+                    positions=positions,
+                    intermediate_tensors=intermediate_tensors,
+                    inputs_embeds=inputs_embeds)
+            return hidden_states
 
     def profile_run(self) -> None:
         # Profile with multimodal encoder & encoder cache.
@@ -1163,13 +1220,13 @@ def load_model(self) -> None:
                 self.compile_model = torch.compile(
                     self.model,
                     dynamic=True,
-                    fullgraph=envs.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE,
+                    fullgraph=envs_vllm.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE,
                     backend=npu_backend)
             else:
                 self.compile_model = torchair.inference.cache_compile(
                     self.model.forward,
                     dynamic=True,
-                    fullgraph=envs.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE,
+                    fullgraph=envs_vllm.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE,
                     config=config,
                     ge_cache=False)
 
@@ -1287,25 +1344,45 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
         return kv_cache_spec
 
     def capture_model(self) -> None:
-        if not self.use_aclgraph:
-            logger.warning(
-                "Skipping NPU graph capture. Please add "
-                "-O %s to use NPU graphs.", CompilationLevel.PIECEWISE)
-            return
-
         start_time = time.perf_counter()
         start_free_npu_memory = torch.npu.mem_get_info()[0]
-
-        # Trigger ACL graph capture for specific shapes.
-        # Capture the large shapes first so that the smaller shapes
-        # can reuse the memory pool allocated for the large shapes.
-        with graph_capture(device=self.device):
-            for num_tokens in reversed(self.aclgraph_batch_sizes):
+        # TODO(NeverRaR): Calling graph_capture(device=self.device) in
+        # torchair graph capture can cause some issues, so now we just
+        # temporarily split the codepath for the two different graph patterns.
+        if self.enable_torchair_graph_mode:
+            torchair_graph_batch_sizes = self.torchair_graph_batch_sizes
+            graph_num = len(torchair_graph_batch_sizes)
+            logger.info(
+                "Capturing torchair graph, this usually takes %.1f~%.1f mins.",
+                0.5 * graph_num, 1.5 * graph_num)
+            attn_state = AscendAttentionState.DecodeOnly
+            # Trigger torchair graph capture for specific shapes.
+            # Capture the large shapes first so that the smaller shapes
+            # can reuse the memory pool allocated for the large shapes.
+            for idx, num_tokens in enumerate(
+                    reversed(torchair_graph_batch_sizes)):
                 for _ in range(self.vllm_config.compilation_config.
                                cudagraph_num_of_warmups):
+                    self._dummy_run(num_tokens, attn_state)
+                self._dummy_run(num_tokens, attn_state)
+                logger.info("Batchsize %d is compiled successfully: %d/%d.",
+                            num_tokens, idx + 1, graph_num)
+        elif self.use_aclgraph:
+            # Trigger ACL graph capture for specific shapes.
+            # Capture the large shapes first so that the smaller shapes
+            # can reuse the memory pool allocated for the large shapes.
+            with graph_capture(device=self.device):
+                for num_tokens in reversed(self.aclgraph_batch_sizes):
+                    for _ in range(self.vllm_config.compilation_config.
+                                   cudagraph_num_of_warmups):
+                        self._dummy_run(num_tokens)
                     self._dummy_run(num_tokens)
-                self._dummy_run(num_tokens)
-
+        else:
+            logger.warning(
+                "Skipping NPU graph capture. Please add -O %s to use ACL graphs. "
+                "Or add --additional_config={'enable_graph_mode': True} to use torchair graphs",
+                CompilationLevel.PIECEWISE)
+            return
         end_time = time.perf_counter()
         end_free_npu_memory = torch.npu.mem_get_info()[0]
         elapsed_time = end_time - start_time
@@ -1345,3 +1422,26 @@ def _generate_draft_token_ids(
             else:
                 draft_token_ids.append(drafter_output.tolist())
         return draft_token_ids
+
+    def init_torchair_graph_batch_sizes(self):
+        tp_size = get_tensor_model_parallel_world_size()
+        batch_size_step = 8
+        largest_batch_size = 1
+
+        if envs_ascend.VLLM_ENABLE_MC2:
+            batch_size_step = max(batch_size_step, tp_size)
+            largest_batch_size = batch_size_step
+        while (largest_batch_size < 8):
+            self.torchair_graph_batch_sizes.append(largest_batch_size)
+            largest_batch_size *= 2
+
+        while (largest_batch_size <= self.scheduler_config.max_num_seqs):
+            self.torchair_graph_batch_sizes.append(largest_batch_size)
+            largest_batch_size += batch_size_step
+
+    def select_torchair_padded_batchsize(self, batchsize: int):
+        selected_batchsize = self.max_num_reqs
+        for padded_batchsize in self.torchair_graph_batch_sizes:
+            if batchsize <= padded_batchsize < selected_batchsize:
+                selected_batchsize = padded_batchsize
+        return selected_batchsize
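For reference, a standalone sketch of how the two helpers above fit together, using assumed values only (max_num_seqs=64, MC2 disabled so batch_size_step stays 8, and max_num_reqs reused as the fallback):

# Sketch of the bucket construction and selection performed by the new methods.
def init_batch_sizes(max_num_seqs: int, batch_size_step: int = 8) -> list[int]:
    sizes, size = [], 1
    while size < 8:               # doubling phase: 1, 2, 4
        sizes.append(size)
        size *= 2
    while size <= max_num_seqs:   # linear phase: 8, 16, 24, ... by batch_size_step
        sizes.append(size)
        size += batch_size_step
    return sizes

def select_padded(batchsize: int, sizes: list[int], max_num_reqs: int) -> int:
    # Smallest captured batch size that fits the request count; fall back to max_num_reqs.
    return min((s for s in sizes if s >= batchsize), default=max_num_reqs)

sizes = init_batch_sizes(64)          # [1, 2, 4, 8, 16, 24, 32, 40, 48, 56, 64]
print(select_padded(5, sizes, 64))    # 8 -> a 5-request decode batch runs the 8-batch graph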