Commit 551edea

feat: support compile torchair graph while warming up

Signed-off-by: boying <897013703@qq.com>
Parent: a0c3e9b

3 files changed: +182 -41 lines

vllm_ascend/attention/mla_v1.py

Lines changed: 39 additions & 2 deletions

@@ -225,7 +225,44 @@ def _get_graph_runner_block_tables(
                            max_blocks] = block_tables[:num_seqs, :
                                                       max_blocks]

-        return graph_block_tables
+        return graph_block_tables[:num_seqs, :max_blocks]
+
+    def build_dummy(self, num_reqs: int,
+                    num_actual_tokens: int) -> AscendMLAMetadata:
+        device = self.runner.device
+        _, max_blocks = self.runner.graph_block_tables.shape
+        block_table = torch.zeros((num_reqs, max_blocks),
+                                  dtype=torch.int32,
+                                  device=device)
+        block_table = self._get_graph_runner_block_tables(
+            num_reqs, block_table)
+        seq_lens = torch.ones(num_reqs, dtype=torch.int32, device=device)
+        input_positions = torch.zeros(num_reqs,
+                                      dtype=torch.int32,
+                                      device=device).long()
+        slot_mapping = torch.full((num_reqs, ),
+                                  PAD_SLOT_ID,
+                                  dtype=torch.int32,
+                                  device=device)
+        decode_metadata = AscendMLADecodeMetadata(
+            input_positions=input_positions,
+            block_table=block_table,
+            seq_lens=seq_lens,
+            seq_lens_list=seq_lens.tolist(),
+            max_seq_lens=1)
+        return self.metadata_cls(  # type: ignore
+            num_input_tokens=num_actual_tokens,
+            num_actual_tokens=num_actual_tokens,
+            slot_mapping=slot_mapping,
+            head_dim=self.runner.model_config.get_head_size(),
+            num_decodes=1,
+            num_decode_tokens=1,
+            num_prefills=0,
+            attn_mask=self.runner.attn_mask,
+            attn_state=AscendAttentionState.DecodeOnly,
+            prefill=None,
+            decode=decode_metadata,
+        )

     def build(self,
               num_reqs: int,
@@ -307,7 +344,7 @@ def build(self,
             block_table = torch.cat([block_table, block_table_padding],
                                     dim=0)
             block_table = self._get_graph_runner_block_tables(
-                num_seqs, block_table)
+                num_seqs + graph_pad_size, block_table)
             padding_0 = torch.zeros(graph_pad_size,
                                     dtype=input_positions.dtype,
                                     device=input_positions.device)
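
Note: build_dummy fabricates decode-only attention metadata for warm-up passes in which no real requests exist. A minimal, self-contained sketch of the tensors it assembles, with placeholder values (PAD_SLOT_ID, num_reqs, max_blocks, and the CPU device stand in for runner state):

import torch

# Placeholder values standing in for runner state; the real PAD_SLOT_ID comes
# from vLLM's attention utilities and the device would be the NPU.
PAD_SLOT_ID = -1
num_reqs, max_blocks = 8, 128
device = "cpu"

# Shapes mirror build_dummy above: an all-zero block table, unit sequence
# lengths, zeroed positions, and a fully padded slot mapping.
block_table = torch.zeros((num_reqs, max_blocks), dtype=torch.int32, device=device)
seq_lens = torch.ones(num_reqs, dtype=torch.int32, device=device)
input_positions = torch.zeros(num_reqs, dtype=torch.int32, device=device).long()
slot_mapping = torch.full((num_reqs, ), PAD_SLOT_ID, dtype=torch.int32, device=device)

# Every dummy request looks like a single-token decode step, matching the
# DecodeOnly attention state the torchair graph is captured for.
print(block_table.shape, seq_lens.tolist(), slot_mapping.tolist())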

vllm_ascend/models/deepseek_v2.py

Lines changed: 13 additions & 9 deletions

@@ -36,9 +36,10 @@
 from vllm.attention import Attention, AttentionMetadata
 from vllm.config import (CacheConfig, ModelConfig, VllmConfig,
                          get_current_vllm_config)
-from vllm.distributed import (get_dp_group, get_pp_group,
+from vllm.distributed import (get_pp_group,
                               get_tensor_model_parallel_world_size,
                               get_tp_group, tensor_model_parallel_all_reduce)
+from vllm.distributed.parallel_state import get_dp_group
 from vllm.forward_context import get_forward_context
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.layernorm import RMSNorm
@@ -206,15 +207,14 @@ def __init__(
         CustomDeepseekV2MoE.top_k = config.num_experts_per_tok

         self.dp_size = get_dp_group().world_size
-
-        self.tp_group = get_tp_group().device_group
         self.tp_rank = get_tp_group().rank_in_group

-    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
-        attn_metadata = get_forward_context().attn_metadata
-        # when profile runs, force experts to load balanced tokens
-        # to avoid high memory consumption on a single rank.
-        # TODO: need a better flag to indicate whether in profile run or not.
+    def forward(
+            self,
+            hidden_states: torch.Tensor,
+            attn_metadata: Optional[AttentionMetadata] = None) -> torch.Tensor:
+        if attn_metadata is None:
+            attn_metadata = get_forward_context().attn_metadata
         if attn_metadata is None:
             # for profile run
             is_prefill = True
@@ -540,7 +540,11 @@ def forward(
         # Fully Connected
         hidden_states, residual = self.post_attention_layernorm(
             hidden_states, residual)
-        hidden_states = self.mlp(hidden_states)
+
+        if isinstance(self.mlp, CustomDeepseekV2MoE):
+            hidden_states = self.mlp(hidden_states, attn_metadata)
+        else:
+            hidden_states = self.mlp(hidden_states)

         if isinstance(
                 self.mlp,
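
Note: the MoE forward now accepts attn_metadata explicitly so the decoder layer can thread it through as a real graph input when the decode path is compiled with torchair, falling back to the forward context in eager mode. A rough sketch of that fallback pattern with stub classes (the real types are CustomDeepseekV2MoE and vLLM's AttentionMetadata; the real fallback reads get_forward_context()):

from typing import Optional

class _StubAttnMetadata:
    # Illustrative stand-in for the attention metadata consumed by the MoE layer.
    num_prefills = 0

class _StubMoE:
    def forward(self, hidden_states, attn_metadata: Optional[_StubAttnMetadata] = None):
        if attn_metadata is None:
            # Eager path: fall back to the ambient context (stubbed here).
            attn_metadata = _StubAttnMetadata()
        is_prefill = attn_metadata.num_prefills > 0
        return hidden_states, is_prefill

moe = _StubMoE()
print(moe.forward("h", _StubAttnMetadata()))  # compiled path: metadata passed explicitly
print(moe.forward("h"))                       # eager path: metadata resolved internally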

vllm_ascend/worker/model_runner_v1.py

Lines changed: 130 additions & 30 deletions

@@ -28,10 +28,12 @@
 import numpy as np
 import numpy.typing as npt
 import torch
+import torch._dynamo.cache_size
 import torch.nn as nn
 from vllm.attention import AttentionType, get_attn_backend
 from vllm.attention.layer import Attention
 from vllm.config import CompilationLevel, VllmConfig
+from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.distributed.parallel_state import get_pp_group
 from vllm.forward_context import set_forward_context
 from vllm.inputs import INPUT_REGISTRY
@@ -69,7 +71,9 @@
 else:
     xgr = LazyLoader("xgr", globals(), "xgrammar")

-import vllm.envs as envs
+import vllm.envs as envs_vllm
+
+import vllm_ascend.envs as envs_ascend


 @dataclass
@@ -321,13 +325,39 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):
         self.sampler = Sampler()
         self.enable_torchair_graph_mode = False
         self.use_cached_npu_graph = False
+        self.torchair_graph_batch_sizes = []
         additional_config = vllm_config.additional_config
         if additional_config:
             self.enable_torchair_graph_mode = additional_config.get(
                 "enable_graph_mode",
                 False) and self.vllm_config.model_config.use_mla
             self.use_cached_npu_graph = additional_config.get(
                 "use_cached_npu_graph", False)
+            if additional_config.get("trace_recompiles", False):
+                torch._logging.set_logs(recompiles=True)
+            self.torchair_graph_batch_sizes = additional_config.get(
+                "torchair_graph_batch_sizes", [])
+            if not isinstance(self.torchair_graph_batch_sizes, list):
+                logger.warning("torchair_graph_batch_sizes must be list[int]")
+                self.torchair_graph_batch_sizes = []
+            if len(self.torchair_graph_batch_sizes
+                   ) == 0 and additional_config.get(
+                       "init_torchair_graph_batch_sizes", False):
+                self.init_torchair_graph_batch_sizes()
+
+        if len(self.torchair_graph_batch_sizes) == 0:
+            #If MC2 is enabled, torchair_graph_batch_size should pad to tp_size
+            if envs_ascend.VLLM_ENABLE_MC2:
+                self.torchair_graph_batch_sizes = [
+                    self.scheduler_config.max_num_seqs
+                ]
+            else:
+                self.torchair_graph_batch_sizes = [
+                    1, self.scheduler_config.max_num_seqs
+                ]
+
+        torch._dynamo.cache_size.config.cache_size_limit += len(
+            self.torchair_graph_batch_sizes)

     def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
         """Update the cached states and the persistent batch with the scheduler
@@ -605,7 +635,10 @@ def _process_reqs(

         # Add graph_pad_size here
         if self.enable_torchair_graph_mode:
-            graph_pad_size = self.scheduler_config.max_num_seqs - len(seq_lens)
+            batchsize = len(seq_lens)
+            padded_batch_size = self.select_torchair_padded_batchsize(
+                batchsize)
+            graph_pad_size = padded_batch_size - batchsize
             extra_builder_kwargs['graph_pad_size'] = graph_pad_size

         attn_metadata = self.attn_metadata_builder.build(  # type: ignore
@@ -630,11 +663,8 @@
         input_ids = self.input_ids[:num_input_tokens]

         if self.enable_torchair_graph_mode and attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
-            padding = torch.zeros(graph_pad_size,
-                                  dtype=input_ids.dtype,
-                                  device=input_ids.device)
-            input_ids = torch.cat([input_ids, padding])
-            positions = torch.cat([positions, padding])
+            input_ids = self.input_ids[:padded_batch_size]
+            positions = self.positions[:padded_batch_size]

         # Run forward pass
         with set_forward_context(attn_metadata,
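
Note: with this change a decode batch is padded up to the nearest captured torchair graph size instead of always padding to max_num_seqs. A small worked example, assuming captured sizes [1, 8, 16, 32]:

# Illustrative values; the real sizes come from self.torchair_graph_batch_sizes.
torchair_graph_batch_sizes = [1, 8, 16, 32]
batchsize = 5                                   # decode requests in this step
padded_batch_size = min(b for b in torchair_graph_batch_sizes if b >= batchsize)
graph_pad_size = padded_batch_size - batchsize  # 8 - 5 == 3 padded slots
print(padded_batch_size, graph_pad_size)        # 8 3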
@@ -1039,7 +1069,11 @@ def _profile_multimodal(self) -> None:
         self.encoder_cache["tmp"] = dict(enumerate(dummy_encoder_outputs))

     @torch.inference_mode()
-    def _dummy_run(self, num_tokens: int) -> torch.Tensor:
+    def _dummy_run(
+        self,
+        num_tokens: int,
+        attn_state: AscendAttentionState = AscendAttentionState.ChunkedPrefill,
+    ) -> torch.Tensor:
         # Set num_scheduled_tokens based on num_tokens and max_num_seqs
         # for dummy run with LoRA so that the num_reqs collectively
         # has num_tokens in total.
@@ -1083,12 +1117,35 @@ def _dummy_run(self, num_tokens: int) -> torch.Tensor:
                 })

         with set_forward_context(None, self.vllm_config):
-            hidden_states = model(
-                input_ids=input_ids,
-                positions=positions,
-                intermediate_tensors=intermediate_tensors,
-                inputs_embeds=inputs_embeds)
-        return hidden_states
+            if self.enable_torchair_graph_mode and attn_state == AscendAttentionState.DecodeOnly:
+                attn_metadata = self.attn_metadata_builder.build_dummy(
+                    num_reqs=num_tokens, num_actual_tokens=1)
+                torch._dynamo.mark_static(input_ids)
+                torch._dynamo.mark_static(positions)
+                torch._dynamo.mark_static(attn_metadata.decode.block_table)
+                torch._dynamo.mark_static(
+                    attn_metadata.decode.input_positions)
+                torch._dynamo.mark_static(attn_metadata.slot_mapping)
+                for kv in self.kv_caches:
+                    assert isinstance(kv,
+                                      tuple), "kv_cache must be a tuple"
+                    torch._dynamo.mark_static(kv[0])
+                    torch._dynamo.mark_static(kv[1])
+                hidden_states = self.compile_model(
+                    input_ids=input_ids,
+                    positions=positions,
+                    intermediate_tensors=intermediate_tensors,
+                    inputs_embeds=None,
+                    kv_caches=self.kv_caches,
+                    attn_metadata=attn_metadata,
+                )
+            else:
+                hidden_states = model(
+                    input_ids=input_ids,
+                    positions=positions,
+                    intermediate_tensors=intermediate_tensors,
+                    inputs_embeds=inputs_embeds)
+            return hidden_states

     def profile_run(self) -> None:
         # Profile with multimodal encoder & encoder cache.
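
Note: torch._dynamo.mark_static pins the dummy inputs, block table, slot mapping, and KV caches to fixed shapes so the decode graph compiled for each captured batch size is not invalidated by Dynamo recompilation. A minimal standalone illustration of the call; the compiled function here is a toy, not the model:

import torch
import torch._dynamo

# Mark a tensor as static so Dynamo specializes on its exact shape rather than
# treating the dimensions as dynamic.
x = torch.zeros(8, 4)
torch._dynamo.mark_static(x)

@torch.compile(dynamic=True)
def double(t: torch.Tensor) -> torch.Tensor:
    return t * 2

print(double(x).shape)  # torch.Size([8, 4])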
@@ -1163,13 +1220,13 @@ def load_model(self) -> None:
                 self.compile_model = torch.compile(
                     self.model,
                     dynamic=True,
-                    fullgraph=envs.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE,
+                    fullgraph=envs_vllm.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE,
                     backend=npu_backend)
             else:
                 self.compile_model = torchair.inference.cache_compile(
                     self.model.forward,
                     dynamic=True,
-                    fullgraph=envs.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE,
+                    fullgraph=envs_vllm.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE,
                     config=config,
                     ge_cache=False)

@@ -1287,25 +1344,45 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
         return kv_cache_spec

     def capture_model(self) -> None:
-        if not self.use_aclgraph:
-            logger.warning(
-                "Skipping NPU graph capture. Please add "
-                "-O %s to use NPU graphs.", CompilationLevel.PIECEWISE)
-            return
-
         start_time = time.perf_counter()
         start_free_npu_memory = torch.npu.mem_get_info()[0]
-
-        # Trigger ACL graph capture for specific shapes.
-        # Capture the large shapes first so that the smaller shapes
-        # can reuse the memory pool allocated for the large shapes.
-        with graph_capture(device=self.device):
-            for num_tokens in reversed(self.aclgraph_batch_sizes):
+        # TODO(NeverRaR): Calling graph_capture(device=self.device) in
+        # torchair graph capture can cause some issues, so now we just
+        # temporarily split the codepath for the two different graph patterns.
+        if self.enable_torchair_graph_mode:
+            torchair_graph_batch_sizes = self.torchair_graph_batch_sizes
+            graph_num = len(torchair_graph_batch_sizes)
+            logger.info(
+                "Capturing torchair graph, this usually takes %.1f~%.1f mins.",
+                0.5 * graph_num, 1.5 * graph_num)
+            attn_state = AscendAttentionState.DecodeOnly
+            # Trigger torchair graph capture for specific shapes.
+            # Capture the large shapes first so that the smaller shapes
+            # can reuse the memory pool allocated for the large shapes.
+            for idx, num_tokens in enumerate(
+                    reversed(torchair_graph_batch_sizes)):
                 for _ in range(self.vllm_config.compilation_config.
                                cudagraph_num_of_warmups):
+                    self._dummy_run(num_tokens, attn_state)
+                self._dummy_run(num_tokens, attn_state)
+                logger.info("Batchsize %d is compiled successfully: %d/%d.",
+                            num_tokens, idx + 1, graph_num)
+        elif self.use_aclgraph:
+            # Trigger ACL graph capture for specific shapes.
+            # Capture the large shapes first so that the smaller shapes
+            # can reuse the memory pool allocated for the large shapes.
+            with graph_capture(device=self.device):
+                for num_tokens in reversed(self.aclgraph_batch_sizes):
+                    for _ in range(self.vllm_config.compilation_config.
+                                   cudagraph_num_of_warmups):
+                        self._dummy_run(num_tokens)
                     self._dummy_run(num_tokens)
-                self._dummy_run(num_tokens)
-
+        else:
+            logger.warning(
+                "Skipping NPU graph capture. Please add -O %s to use ACL graphs. "
+                "Or add --additional_config={'enable_graph_mode': True} to use torchair graphs",
+                CompilationLevel.PIECEWISE)
+            return
         end_time = time.perf_counter()
         end_free_npu_memory = torch.npu.mem_get_info()[0]
         elapsed_time = end_time - start_time
@@ -1345,3 +1422,26 @@ def _generate_draft_token_ids(
             else:
                 draft_token_ids.append(drafter_output.tolist())
         return draft_token_ids
+
+    def init_torchair_graph_batch_sizes(self):
+        tp_size = get_tensor_model_parallel_world_size()
+        batch_size_step = 8
+        largest_batch_size = 1
+
+        if envs_ascend.VLLM_ENABLE_MC2:
+            batch_size_step = max(batch_size_step, tp_size)
+            largest_batch_size = batch_size_step
+        while (largest_batch_size < 8):
+            self.torchair_graph_batch_sizes.append(largest_batch_size)
+            largest_batch_size *= 2
+
+        while (largest_batch_size <= self.scheduler_config.max_num_seqs):
+            self.torchair_graph_batch_sizes.append(largest_batch_size)
+            largest_batch_size += batch_size_step
+
+    def select_torchair_padded_batchsize(self, batchsize: int):
+        selected_batchsize = self.max_num_reqs
+        for padded_batchsize in self.torchair_graph_batch_sizes:
+            if batchsize <= padded_batchsize < selected_batchsize:
+                selected_batchsize = padded_batchsize
+        return selected_batchsize
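
Note: to make the batch-size derivation concrete, here is a self-contained replica of the same logic as a free function, with made-up values for max_num_seqs, tp_size, and the MC2 flag:

def derive_torchair_graph_batch_sizes(max_num_seqs: int, tp_size: int,
                                      enable_mc2: bool) -> list:
    # Mirrors init_torchair_graph_batch_sizes above.
    sizes = []
    batch_size_step = 8
    largest_batch_size = 1
    if enable_mc2:
        # Pad the step up to the TP world size when MC2 is enabled.
        batch_size_step = max(batch_size_step, tp_size)
        largest_batch_size = batch_size_step
    while largest_batch_size < 8:
        sizes.append(largest_batch_size)
        largest_batch_size *= 2
    while largest_batch_size <= max_num_seqs:
        sizes.append(largest_batch_size)
        largest_batch_size += batch_size_step
    return sizes

print(derive_torchair_graph_batch_sizes(24, 4, False))   # [1, 2, 4, 8, 16, 24]
print(derive_torchair_graph_batch_sizes(32, 16, True))   # [16, 32]

select_torchair_padded_batchsize then picks the smallest captured size that still fits the live decode batch, falling back to max_num_reqs when the batch exceeds every captured size.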
