Commit 0dca4c6

refactor model runner v1 (#2461)

### What this PR does / why we need it?
1. Separate the execute-model logic from the prepare-input logic.
2. Disassemble the torchair logic in model runner v1.

- vLLM version: v0.10.0
- vLLM main: vllm-project/vllm@68fcd3f

Signed-off-by: weiguihua2 <weiguihua2@huawei.com>
1 parent 1de16ea commit 0dca4c6
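
As a rough sketch of the split described above (only the class names and the `_generate_process_reqs_hidden_states` hook are taken from the diff below; `execute_model`, `_prepare_inputs`, and all method bodies are simplified placeholders, not code from this commit):

```python
# Sketch of the override pattern: the base runner owns input preparation and
# calls overridable hooks; the torchair runner overrides only the execute hooks.
class NPUModelRunner:
    def execute_model(self, scheduler_output):
        inputs = self._prepare_inputs(scheduler_output)            # shared prepare-input logic
        return self._generate_process_reqs_hidden_states(**inputs)  # overridable execute hook

    def _prepare_inputs(self, scheduler_output):
        return {}  # placeholder for the real input-building code

    def _generate_process_reqs_hidden_states(self, **inputs):
        raise NotImplementedError  # eager / ACL-graph path in the real base class


class NPUTorchairModelRunner(NPUModelRunner):
    def _generate_process_reqs_hidden_states(self, **inputs):
        # torchair path: run a lazily compiled graph instead of the eager model
        return None  # placeholder
```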

File tree

3 files changed: +368 -307 lines


vllm_ascend/torchair/torchair_model_runner.py

Lines changed: 239 additions & 1 deletion
@@ -17,28 +17,54 @@
 # Adapted from vllm-project/vllm/vllm/worker/gpu_model_runner.py
 #
 
+import types
 from typing import Optional
 
 import torch
+import torch.distributed as dist
+import torch.nn as nn
 import torch_npu
+import vllm.envs as envs_vllm
 from vllm.config import VllmConfig
+from vllm.distributed import get_tensor_model_parallel_world_size
+from vllm.distributed.parallel_state import get_dp_group
 from vllm.forward_context import get_forward_context
 from vllm.logger import logger
 
+import vllm_ascend.envs as envs_ascend
+from vllm_ascend.ascend_config import get_ascend_config
 from vllm_ascend.platform import NPUPlatform
 from vllm_ascend.torchair.utils import (TorchairCommonAttentionMetadata,
                                         check_torchair_cache_exist,
                                         register_torchair_model,
                                         write_kv_cache_bytes_to_file)
 from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, ACL_FORMAT_FRACTAL_NZ,
-                               maybe_converting_weight_acl_format)
+                               is_310p, maybe_converting_weight_acl_format)
 from vllm_ascend.worker.model_runner_v1 import NPUModelRunner
 
 
 class NPUTorchairModelRunner(NPUModelRunner):
 
     def __init__(self, vllm_config: VllmConfig, device: torch.device):
         super().__init__(vllm_config, device)
+        ascend_config = get_ascend_config()
+        self.new_kv_cache_bytes = -1
+        self.torchair_compiled_model = None  # type: ignore
+        self.torchair_compiled_models = {}  # type: ignore
+        self.use_cached_npu_graph = ascend_config.torchair_graph_config.use_cached_graph
+        self.torchair_graph_batch_sizes = ascend_config.torchair_graph_config.graph_batch_sizes
+        if ascend_config.torchair_graph_config.graph_batch_sizes_init:
+            self.init_torchair_graph_batch_sizes()
+
+        self.check_torchair_graph_batch_sizes()
+
+        torch._dynamo.cache_size.config.cache_size_limit += len(
+            self.torchair_graph_batch_sizes)
+        torch._dynamo.config.capture_dynamic_output_shape_ops = True
+        torch._logging.set_logs(
+            recompiles=envs_ascend.VLLM_ASCEND_TRACE_RECOMPILES)
+
+        self._check_batch_sizes_consistency()
         register_torchair_model()
 
     def _get_forward_metadata_across_dp_and_pad(
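
The `__init__` additions above build the torchair graph batch-size buckets and then raise Dynamo's cache size limit by one entry per bucket. A standalone sketch of the bucket construction, simplified from `init_torchair_graph_batch_sizes` in the next hunk (the concrete `max_num_reqs` and `tp_size` values are made up for illustration):

```python
def build_graph_batch_sizes(max_num_reqs: int, tp_size: int) -> list:
    # Double from max(4, tp_size) up to max_num_reqs, mirroring
    # init_torchair_graph_batch_sizes below (simplified, standalone).
    sizes = []
    size = max(4, tp_size)  # all2all / mc2 slice num_tokens into tp_size blocks
    while size <= max_num_reqs:
        sizes.append(size)
        size *= 2
    return sizes

# e.g. max_num_reqs=24, tp_size=2 -> [4, 8, 16]: one compiled graph per bucket,
# which is why __init__ bumps torch._dynamo's cache_size_limit by len(sizes).
print(build_graph_batch_sizes(24, 2))
```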
@@ -180,3 +206,215 @@ def _capture_model(self):
         if self.new_kv_cache_bytes > 0:
             write_kv_cache_bytes_to_file(torch.distributed.get_rank(),
                                          self.new_kv_cache_bytes)
+
+    def _use_aclgraph(self) -> bool:
+        return False
+
+    def _check_batch_sizes_consistency(self) -> None:
+        if not dist.is_initialized():
+            return
+
+        local = torch.tensor(self.torchair_graph_batch_sizes,
+                             device="cpu",
+                             dtype=torch.int32)
+        gathered_graph_batch_size = local.clone()
+        dist.all_reduce(gathered_graph_batch_size,
+                        group=get_dp_group().cpu_group)
+        expected = local * self.dp_size
+
+        if not torch.equal(gathered_graph_batch_size, expected):
+            diff_idxs = (gathered_graph_batch_size != expected).nonzero(
+                as_tuple=False).flatten().tolist()
+            raise AssertionError(
+                f"[Graph BatchSize Mismatch] Found mismatches at indices {diff_idxs}.\n"
+                f"Local (rank {self.dp_rank}): {local.tolist()}\n"
+                f"Sum over ranks: {gathered_graph_batch_size.tolist()}\n"
+                f"Expected if all equal: {[v * self.dp_size for v in local.tolist()]}"
+            )
+
+    def _update_graph_pad_size(self, with_prefill, graph_pad_size):
+        if not with_prefill:
+            self.graph_pad_size = graph_pad_size
+        else:
+            super()._update_graph_pad_size(with_prefill, graph_pad_size)
+
+    def _update_input_ids_and_positions(self, input_ids, positions,
+                                        num_input_tokens, with_prefill,
+                                        padded_num_tokens_across_dp):
+        """Override from NPUModelRunner to update input_ids and positions"""
+        input_ids, positions = super()._update_input_ids_and_positions(
+            input_ids, positions, num_input_tokens, with_prefill,
+            padded_num_tokens_across_dp)
+
+        if not with_prefill:
+            input_ids = self.input_ids[:padded_num_tokens_across_dp]
+            positions = self.positions[:padded_num_tokens_across_dp]
+        return input_ids, positions
+
+    def _generate_process_reqs_hidden_states(self, attn_metadata, with_prefill,
+                                             padded_num_tokens_across_dp,
+                                             input_ids, positions,
+                                             intermediate_tensors,
+                                             inputs_embeds):
+        model_kwargs = {
+            "kv_caches": self.kv_caches,
+            "attn_metadata": attn_metadata
+        }
+        if not with_prefill:
+            maybe_converting_weight_acl_format(self.model,
+                                               ACL_FORMAT_FRACTAL_NZ)
+
+            compiled_model = self._get_torchair_lazy_compiled_model(
+                padded_num_tokens_across_dp)
+            hidden_states = compiled_model(
+                input_ids=input_ids,
+                positions=positions,
+                intermediate_tensors=intermediate_tensors,
+                inputs_embeds=inputs_embeds,
+                **model_kwargs,
+            )
+        else:
+            assert self.model is not None
+            maybe_converting_weight_acl_format(self.model,
+                                               ACL_FORMAT_FRACTAL_ND)
+
+            hidden_states = self.model(
+                input_ids=input_ids,
+                positions=positions,
+                intermediate_tensors=intermediate_tensors,
+                inputs_embeds=inputs_embeds,
+                **model_kwargs,
+            )
+        return hidden_states
+
+    def _get_torchair_lazy_compiled_model(self, batch_size: int):
+        if batch_size < 0 or batch_size > self.torchair_graph_batch_sizes[-1]:
+            raise ValueError(
+                f"Bad graph batch size:{batch_size}! max_graph_batch_sizes:{self.torchair_graph_batch_sizes[-1]}"
+            )
+
+        compiled_model = self.torchair_compiled_models.get(
+            batch_size
+        ) if self.use_cached_npu_graph else self.torchair_compiled_model
+
+        if compiled_model:
+            return compiled_model
+
+        import torchair  # type: ignore
+        from torchair import patch_for_hcom  # type: ignore
+
+        patch_for_hcom()
+
+        if is_310p():
+            # on 300I Duo platform, we need to patch broadcast. however, this patch will be
+            # overwritten by patch_for_hcom in torchair. so we need to re-patch it here.
+            from vllm_ascend.patch.platform.patch_common.patch_distributed import \
+                communication_adaptation_310p
+            communication_adaptation_310p()
+
+        config = torchair.CompilerConfig()
+        config.experimental_config.frozen_parameter = True
+        # enabling tiling_schedule_optimize on 300I Duo has some bugs, so we have to
+        # disable it on 300I Duo platform now.
+        config.experimental_config.tiling_schedule_optimize = not is_310p()
+        config.experimental_config.enable_view_optimize = \
+            get_ascend_config().torchair_graph_config.enable_view_optimize
+        torch.npu.set_compile_mode(jit_compile=False)
+        if not self.use_cached_npu_graph:
+            npu_backend = torchair.get_npu_backend(compiler_config=config)
+            self.torchair_compiled_model = torch.compile(
+                self.model,
+                dynamic=True,
+                fullgraph=envs_vllm.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE,
+                backend=npu_backend)
+            return self.torchair_compiled_model
+        else:
+            # Generate a new forward proxy code object to prevent the invalidation of
+            # compilation cache caused by dynamo retracing
+            forward_proxy_name = f"{self.model.__class__.__name__}_forward_with_batch_size_{batch_size}"
+            forward_fn = self.model.forward
+            code = forward_fn.__code__
+            # Mark code object with a new proxy name
+            modified_code = code.replace(co_name=forward_proxy_name, )
+
+            modified_func = types.FunctionType(modified_code,
+                                               forward_fn.__globals__,
+                                               name=forward_proxy_name,
+                                               argdefs=forward_fn.__defaults__)
+
+            self.model.__dict__[forward_proxy_name] = modified_func.__get__(
+                self.model, nn.Module)
+            self.torchair_compiled_models[
+                batch_size] = torchair.inference.cache_compile(
+                    self.model.__dict__[forward_proxy_name],
+                    dynamic=True,
+                    fullgraph=envs_vllm.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE,
+                    config=config,
+                    ge_cache=False)
+            return self.torchair_compiled_models[batch_size]
+
+    def init_torchair_graph_batch_sizes(self):
+        start_graph_batch_size = 4
+        tp_size = get_tensor_model_parallel_world_size()
+
+        # NOTE: When use all2all | mc2, We need to slice the `num_tokens` dimension into `tp_size` blocks
+        start_graph_batch_size = max(start_graph_batch_size, tp_size)
+
+        while (start_graph_batch_size <= self.max_num_reqs):
+            self.torchair_graph_batch_sizes.append(start_graph_batch_size)
+            start_graph_batch_size *= 2
+
+    def select_torchair_padded_batch_size(self, batch_size: int):
+        for padded_batch_size in self.torchair_graph_batch_sizes:
+            if batch_size <= padded_batch_size:
+                # we treat batch_size as num of requests
+                return padded_batch_size
+        raise ValueError(
+            f"cur batch_size is invalid, torchair_graph_batch_sizes is "
+            f"{self.torchair_graph_batch_sizes}, but cur batch_size is {batch_size}."
+        )
+
+    def check_torchair_graph_batch_sizes(self):
+        # return graph_batch_sizes according to the max number of tokens
+        # first pad according to the number of requests
+        if len(self.torchair_graph_batch_sizes) == 0:
+            self.torchair_graph_batch_sizes = [1, self.max_num_reqs]
+        else:
+            self.torchair_graph_batch_sizes = sorted(
+                self.torchair_graph_batch_sizes)
+            while self.torchair_graph_batch_sizes[-1] > self.max_num_reqs:
+                self.torchair_graph_batch_sizes.pop()
+                if len(self.torchair_graph_batch_sizes) == 0:
+                    logger.warning(
+                        "torch_graph_batch_sizes is invalid, reset it to [1, max_num_seqs]"
+                    )
+                    self.torchair_graph_batch_sizes = [1, self.max_num_reqs]
+            if self.torchair_graph_batch_sizes[-1] < self.max_num_reqs:
+                self.torchair_graph_batch_sizes.append(self.max_num_reqs)
+
+        # padded max number tokens = max_num_req * decode_token_per_req
+        self.torchair_graph_batch_sizes = [
+            graph_batch_size * self.decode_token_per_req
+            for graph_batch_size in self.torchair_graph_batch_sizes
+        ]
+
+        # NOTE: when enable_expert_parallel, we need to check if `graph_batch_size` is divisible by `tp_size`
+        tp_size = self.parallel_config.tensor_parallel_size
+        if self.parallel_config.enable_expert_parallel:
+            new_graph_batch_sizes = []
+            for graph_batch_size in self.torchair_graph_batch_sizes:
+                cur_graph_batch_size = (graph_batch_size + tp_size -
+                                        1) // tp_size * tp_size
+                if cur_graph_batch_size not in new_graph_batch_sizes and \
+                        cur_graph_batch_size <= self.scheduler_config.max_num_batched_tokens:
+                    new_graph_batch_sizes.append(cur_graph_batch_size)
+                elif cur_graph_batch_size > self.scheduler_config.max_num_batched_tokens \
+                        and self.decode_token_per_req > 1:
+                    logger.warning(
+                        f"torchair_graph_batch_sizes {cur_graph_batch_size} is bigger than max_num_batched_tokens",
+                        f"{self.scheduler_config.max_num_batched_tokens} will skip this batch size."
+                    )
+            self.torchair_graph_batch_sizes = new_graph_batch_sizes
+
+    def _build_drafter_prepare_inputs_torchair_param(self):
+        return True
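
`check_torchair_graph_batch_sizes` above first scales each request-level bucket by `decode_token_per_req` and, when expert parallelism is enabled, rounds each result up to a multiple of `tp_size`, dropping anything above `max_num_batched_tokens`. A small standalone worked example with made-up values:

```python
# Standalone illustration of the rounding in check_torchair_graph_batch_sizes;
# all values here are invented for the example.
graph_batch_sizes = [4, 8, 16]      # request-level buckets
decode_token_per_req = 3            # e.g. multi-token speculative decoding
tp_size = 8
max_num_batched_tokens = 40

token_buckets = [b * decode_token_per_req for b in graph_batch_sizes]  # [12, 24, 48]

new_sizes = []
for b in token_buckets:
    b = (b + tp_size - 1) // tp_size * tp_size   # round up to a tp_size multiple
    if b not in new_sizes and b <= max_num_batched_tokens:
        new_sizes.append(b)                      # 12 -> 16, 24 -> 24
    # 48 exceeds max_num_batched_tokens and is skipped (with a warning in the diff)

print(new_sizes)  # [16, 24]
```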

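`_get_torchair_lazy_compiled_model` keeps one compiled graph per batch size when `use_cached_npu_graph` is set; to give each of them a distinct Dynamo cache entry it clones `model.forward` under a batch-size-specific code-object name before handing it to `torchair.inference.cache_compile`. A minimal standalone sketch of that renaming trick, with the torchair call itself left out:

```python
import types


class TinyModel:
    def forward(self, x):
        return x + 1


model = TinyModel()

# Clone forward() under a new code-object name, as the diff does before calling
# torchair.inference.cache_compile (omitted here to keep the sketch dependency-free).
proxy_name = f"{model.__class__.__name__}_forward_with_batch_size_8"
forward_fn = model.forward                       # bound method
modified_code = forward_fn.__code__.replace(co_name=proxy_name)
proxy_fn = types.FunctionType(modified_code,
                              forward_fn.__globals__,
                              name=proxy_name,
                              argdefs=forward_fn.__defaults__)

# Bind the proxy to the instance so it behaves like a normal method.
model.__dict__[proxy_name] = proxy_fn.__get__(model, TinyModel)

print(model.__dict__[proxy_name](3))                 # 4
print(model.__dict__[proxy_name].__code__.co_name)   # TinyModel_forward_with_batch_size_8
```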