 # Adapted from vllm-project/vllm/vllm/worker/gpu_model_runner.py
 #

+import types
 from typing import Optional

 import torch
+import torch.distributed as dist
+import torch.nn as nn
 import torch_npu
+import vllm.envs as envs_vllm
 from vllm.config import VllmConfig
+from vllm.distributed import get_tensor_model_parallel_world_size
+from vllm.distributed.parallel_state import get_dp_group
 from vllm.forward_context import get_forward_context
 from vllm.logger import logger

+import vllm_ascend.envs as envs_ascend
+from vllm_ascend.ascend_config import get_ascend_config
 from vllm_ascend.platform import NPUPlatform
 from vllm_ascend.torchair.utils import (TorchairCommonAttentionMetadata,
                                         check_torchair_cache_exist,
                                         register_torchair_model,
                                         write_kv_cache_bytes_to_file)
 from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, ACL_FORMAT_FRACTAL_NZ,
-                               maybe_converting_weight_acl_format)
+                               is_310p, maybe_converting_weight_acl_format)
 from vllm_ascend.worker.model_runner_v1 import NPUModelRunner


 class NPUTorchairModelRunner(NPUModelRunner):

     def __init__(self, vllm_config: VllmConfig, device: torch.device):
         super().__init__(vllm_config, device)
+        ascend_config = get_ascend_config()
+        self.new_kv_cache_bytes = -1
+        self.torchair_compiled_model = None  # type: ignore
+        self.torchair_compiled_models = {}  # type: ignore
+        self.use_cached_npu_graph = ascend_config.torchair_graph_config.use_cached_graph
+        self.torchair_graph_batch_sizes = ascend_config.torchair_graph_config.graph_batch_sizes
+        if ascend_config.torchair_graph_config.graph_batch_sizes_init:
+            self.init_torchair_graph_batch_sizes()
+
+        self.check_torchair_graph_batch_sizes()
+
+        torch._dynamo.cache_size.config.cache_size_limit += len(
+            self.torchair_graph_batch_sizes)
+        torch._dynamo.config.capture_dynamic_output_shape_ops = True
+        torch._logging.set_logs(
+            recompiles=envs_ascend.VLLM_ASCEND_TRACE_RECOMPILES)
+
+        self._check_batch_sizes_consistency()
         register_torchair_model()

     def _get_forward_metadata_across_dp_and_pad(
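The attributes read in `__init__` above (`use_cached_graph`, `graph_batch_sizes`, `graph_batch_sizes_init`) come from the Ascend torchair graph config. As a hedged illustration of how those options are typically supplied, the sketch below passes them through vLLM's `additional_config`; the key names mirror the attributes accessed in `__init__` and are an assumption, not something this diff defines.

```python
# Hedged sketch: supplying the torchair graph options that __init__ reads.
# The additional_config keys mirror the attributes accessed above and are
# assumptions about the config surface, not part of this diff.
from vllm import LLM

llm = LLM(
    model="some/npu-served-model",  # placeholder model name
    additional_config={
        "torchair_graph_config": {
            "enabled": True,
            "use_cached_graph": True,         # -> self.use_cached_npu_graph
            "graph_batch_sizes": [4, 8, 16],  # -> self.torchair_graph_batch_sizes
            "graph_batch_sizes_init": False,  # skip init_torchair_graph_batch_sizes()
        }
    },
)
```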
@@ -180,3 +206,215 @@ def _capture_model(self):
         if self.new_kv_cache_bytes > 0:
             write_kv_cache_bytes_to_file(torch.distributed.get_rank(),
                                          self.new_kv_cache_bytes)
+
+    def _use_aclgraph(self) -> bool:
+        return False
+
+    def _check_batch_sizes_consistency(self) -> None:
+        if not dist.is_initialized():
+            return
+
+        local = torch.tensor(self.torchair_graph_batch_sizes,
+                             device="cpu",
+                             dtype=torch.int32)
+        gathered_graph_batch_size = local.clone()
+        dist.all_reduce(gathered_graph_batch_size,
+                        group=get_dp_group().cpu_group)
+        expected = local * self.dp_size
+
+        if not torch.equal(gathered_graph_batch_size, expected):
+            diff_idxs = (gathered_graph_batch_size != expected).nonzero(
+                as_tuple=False).flatten().tolist()
+            raise AssertionError(
+                f"[Graph BatchSize Mismatch] Found mismatches at indices {diff_idxs}.\n"
+                f"Local (rank {self.dp_rank}): {local.tolist()}\n"
+                f"Sum over ranks: {gathered_graph_batch_size.tolist()}\n"
+                f"Expected if all equal: {[v * self.dp_size for v in local.tolist()]}"
+            )
+
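`_check_batch_sizes_consistency` relies on the fact that an all-reduce sum of identical integer vectors across `dp_size` ranks equals `dp_size * local`; any rank configured with a different list breaks that equality. A minimal single-process sketch of the same invariant (plain torch, no process group; the helper name is made up):

```python
import torch

def batch_sizes_consistent(per_rank_batch_sizes: list[list[int]]) -> bool:
    """Emulate the all_reduce check: the element-wise sum across ranks must
    equal dp_size * the local list on every rank."""
    dp_size = len(per_rank_batch_sizes)
    summed = torch.tensor(per_rank_batch_sizes, dtype=torch.int32).sum(dim=0)
    local = torch.tensor(per_rank_batch_sizes[0], dtype=torch.int32)
    return torch.equal(summed, local * dp_size)

assert batch_sizes_consistent([[4, 8, 16], [4, 8, 16]])      # all ranks agree
assert not batch_sizes_consistent([[4, 8, 16], [4, 8, 32]])  # rank 1 differs
```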
+    def _update_graph_pad_size(self, with_prefill, graph_pad_size):
+        if not with_prefill:
+            self.graph_pad_size = graph_pad_size
+        else:
+            super()._update_graph_pad_size(with_prefill, graph_pad_size)
+
+    def _update_input_ids_and_positions(self, input_ids, positions,
+                                        num_input_tokens, with_prefill,
+                                        padded_num_tokens_across_dp):
+        """Override from NPUModelRunner to update input_ids and positions"""
+        input_ids, positions = super()._update_input_ids_and_positions(
+            input_ids, positions, num_input_tokens, with_prefill,
+            padded_num_tokens_across_dp)
+
+        if not with_prefill:
+            input_ids = self.input_ids[:padded_num_tokens_across_dp]
+            positions = self.positions[:padded_num_tokens_across_dp]
+        return input_ids, positions
+
+    def _generate_process_reqs_hidden_states(self, attn_metadata, with_prefill,
+                                             padded_num_tokens_across_dp,
+                                             input_ids, positions,
+                                             intermediate_tensors,
+                                             inputs_embeds):
+        model_kwargs = {
+            "kv_caches": self.kv_caches,
+            "attn_metadata": attn_metadata
+        }
+        if not with_prefill:
+            maybe_converting_weight_acl_format(self.model,
+                                               ACL_FORMAT_FRACTAL_NZ)
+
+            compiled_model = self._get_torchair_lazy_compiled_model(
+                padded_num_tokens_across_dp)
+            hidden_states = compiled_model(
+                input_ids=input_ids,
+                positions=positions,
+                intermediate_tensors=intermediate_tensors,
+                inputs_embeds=inputs_embeds,
+                **model_kwargs,
+            )
+        else:
+            assert self.model is not None
+            maybe_converting_weight_acl_format(self.model,
+                                               ACL_FORMAT_FRACTAL_ND)
+
+            hidden_states = self.model(
+                input_ids=input_ids,
+                positions=positions,
+                intermediate_tensors=intermediate_tensors,
+                inputs_embeds=inputs_embeds,
+                **model_kwargs,
+            )
+        return hidden_states
+
+    def _get_torchair_lazy_compiled_model(self, batch_size: int):
+        if batch_size < 0 or batch_size > self.torchair_graph_batch_sizes[-1]:
+            raise ValueError(
+                f"Bad graph batch size:{batch_size}! max_graph_batch_sizes:{self.torchair_graph_batch_sizes[-1]}"
+            )
+
+        compiled_model = self.torchair_compiled_models.get(
+            batch_size
+        ) if self.use_cached_npu_graph else self.torchair_compiled_model
+
+        if compiled_model:
+            return compiled_model
+
+        import torchair  # type: ignore
+        from torchair import patch_for_hcom  # type: ignore
+
+        patch_for_hcom()
+
+        if is_310p():
+            # On the 300I Duo platform we need to patch broadcast. However, that
+            # patch is overwritten by patch_for_hcom in torchair, so we re-apply
+            # it here.
+            from vllm_ascend.patch.platform.patch_common.patch_distributed import \
+                communication_adaptation_310p
+            communication_adaptation_310p()
+
+        config = torchair.CompilerConfig()
+        config.experimental_config.frozen_parameter = True
+        # Enabling tiling_schedule_optimize on 300I Duo still has bugs, so it is
+        # disabled on that platform for now.
+        config.experimental_config.tiling_schedule_optimize = not is_310p()
+        config.experimental_config.enable_view_optimize = \
+            get_ascend_config().torchair_graph_config.enable_view_optimize
+        torch.npu.set_compile_mode(jit_compile=False)
+        if not self.use_cached_npu_graph:
+            npu_backend = torchair.get_npu_backend(compiler_config=config)
+            self.torchair_compiled_model = torch.compile(
+                self.model,
+                dynamic=True,
+                fullgraph=envs_vllm.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE,
+                backend=npu_backend)
+            return self.torchair_compiled_model
+        else:
+            # Generate a new forward proxy code object to prevent invalidation of
+            # the compilation cache caused by dynamo retracing.
+            forward_proxy_name = f"{self.model.__class__.__name__}_forward_with_batch_size_{batch_size}"
+            forward_fn = self.model.forward
+            code = forward_fn.__code__
+            # Mark the code object with a new proxy name.
+            modified_code = code.replace(co_name=forward_proxy_name, )
+
+            modified_func = types.FunctionType(modified_code,
+                                               forward_fn.__globals__,
+                                               name=forward_proxy_name,
+                                               argdefs=forward_fn.__defaults__)
+
+            self.model.__dict__[forward_proxy_name] = modified_func.__get__(
+                self.model, nn.Module)
+            self.torchair_compiled_models[
+                batch_size] = torchair.inference.cache_compile(
+                    self.model.__dict__[forward_proxy_name],
+                    dynamic=True,
+                    fullgraph=envs_vllm.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE,
+                    config=config,
+                    ge_cache=False)
+            return self.torchair_compiled_models[batch_size]
+
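The cached-graph branch above clones `forward`'s code object under a batch-size-specific name so each padded batch size gets its own cache entry instead of retracing and invalidating a previously compiled one. A standalone sketch of that code-object proxy trick (pure Python, no torchair; the toy `Model` is illustrative):

```python
import types

class Model:
    def forward(self, x):
        return x * 2

model = Model()
batch_size = 16
proxy_name = f"{model.__class__.__name__}_forward_with_batch_size_{batch_size}"

forward_fn = Model.forward
# Clone the code object under a new name so tracers/caches see a distinct function.
proxy_code = forward_fn.__code__.replace(co_name=proxy_name)
proxy_func = types.FunctionType(proxy_code,
                                forward_fn.__globals__,
                                name=proxy_name,
                                argdefs=forward_fn.__defaults__)

# Bind it to the instance, just as the runner stores its per-batch-size proxies.
model.__dict__[proxy_name] = proxy_func.__get__(model, Model)
assert model.__dict__[proxy_name](3) == 6
assert model.__dict__[proxy_name].__func__.__code__.co_name == proxy_name
```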
+    def init_torchair_graph_batch_sizes(self):
+        start_graph_batch_size = 4
+        tp_size = get_tensor_model_parallel_world_size()
+
+        # NOTE: When using all2all or mc2, the `num_tokens` dimension has to be
+        # sliced into `tp_size` blocks, so start from at least `tp_size`.
+        start_graph_batch_size = max(start_graph_batch_size, tp_size)
+
+        while (start_graph_batch_size <= self.max_num_reqs):
+            self.torchair_graph_batch_sizes.append(start_graph_batch_size)
+            start_graph_batch_size *= 2
+
+    def select_torchair_padded_batch_size(self, batch_size: int):
+        for padded_batch_size in self.torchair_graph_batch_sizes:
+            if batch_size <= padded_batch_size:
+                # we treat batch_size as the number of requests
+                return padded_batch_size
+        raise ValueError(
+            f"cur batch_size is invalid, torchair_graph_batch_sizes is "
+            f"{self.torchair_graph_batch_sizes}, but cur batch_size is {batch_size}."
+        )
+
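Taken together, `init_torchair_graph_batch_sizes` builds a doubling ladder of capture sizes starting at `max(4, tp_size)`, and `select_torchair_padded_batch_size` rounds an incoming request count up to the nearest rung. A quick sketch of that behaviour with hypothetical `tp_size`/`max_num_reqs` values:

```python
# Hedged sketch of the capture-size ladder, with hypothetical tp_size/max_num_reqs.
def build_graph_batch_sizes(tp_size: int, max_num_reqs: int) -> list[int]:
    sizes, size = [], max(4, tp_size)
    while size <= max_num_reqs:
        sizes.append(size)
        size *= 2
    return sizes

def select_padded(batch_size: int, sizes: list[int]) -> int:
    for padded in sizes:
        if batch_size <= padded:
            return padded
    raise ValueError(f"batch_size {batch_size} exceeds the largest capture size in {sizes}")

sizes = build_graph_batch_sizes(tp_size=2, max_num_reqs=48)
print(sizes)                     # [4, 8, 16, 32]
print(select_padded(13, sizes))  # 16 -> 13 requests run under the graph captured for 16
```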
+    def check_torchair_graph_batch_sizes(self):
+        # Normalize graph_batch_sizes against the max number of requests/tokens:
+        # first pad according to the number of requests.
+        if len(self.torchair_graph_batch_sizes) == 0:
+            self.torchair_graph_batch_sizes = [1, self.max_num_reqs]
+        else:
+            self.torchair_graph_batch_sizes = sorted(
+                self.torchair_graph_batch_sizes)
+            while self.torchair_graph_batch_sizes[-1] > self.max_num_reqs:
+                self.torchair_graph_batch_sizes.pop()
+                if len(self.torchair_graph_batch_sizes) == 0:
+                    logger.warning(
+                        "torch_graph_batch_sizes is invalid, reset it to [1, max_num_seqs]"
+                    )
+                    self.torchair_graph_batch_sizes = [1, self.max_num_reqs]
+            if self.torchair_graph_batch_sizes[-1] < self.max_num_reqs:
+                self.torchair_graph_batch_sizes.append(self.max_num_reqs)
+
+        # padded max number of tokens = max_num_reqs * decode_token_per_req
+        self.torchair_graph_batch_sizes = [
+            graph_batch_size * self.decode_token_per_req
+            for graph_batch_size in self.torchair_graph_batch_sizes
+        ]
+
+        # NOTE: with enable_expert_parallel, every `graph_batch_size` must be
+        # divisible by `tp_size`, so round up and drop oversized entries.
+        tp_size = self.parallel_config.tensor_parallel_size
+        if self.parallel_config.enable_expert_parallel:
+            new_graph_batch_sizes = []
+            for graph_batch_size in self.torchair_graph_batch_sizes:
+                cur_graph_batch_size = (graph_batch_size + tp_size -
+                                        1) // tp_size * tp_size
+                if cur_graph_batch_size not in new_graph_batch_sizes and \
+                        cur_graph_batch_size <= self.scheduler_config.max_num_batched_tokens:
+                    new_graph_batch_sizes.append(cur_graph_batch_size)
+                elif cur_graph_batch_size > self.scheduler_config.max_num_batched_tokens \
+                        and self.decode_token_per_req > 1:
+                    logger.warning(
+                        f"torchair_graph_batch_sizes {cur_graph_batch_size} is bigger than "
+                        f"max_num_batched_tokens {self.scheduler_config.max_num_batched_tokens}, "
+                        f"will skip this batch size.")
+            self.torchair_graph_batch_sizes = new_graph_batch_sizes
+
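With expert parallelism enabled, `check_torchair_graph_batch_sizes` rounds each token-scaled size up to a multiple of `tp_size`, deduplicates, and drops anything above `max_num_batched_tokens`. A small sketch of that adjustment with hypothetical limits:

```python
# Hedged sketch of the expert-parallel adjustment with hypothetical limits.
def round_up_to_tp(sizes: list[int], tp_size: int,
                   max_num_batched_tokens: int) -> list[int]:
    rounded: list[int] = []
    for size in sizes:
        cur = (size + tp_size - 1) // tp_size * tp_size  # ceil to a multiple of tp_size
        if cur <= max_num_batched_tokens and cur not in rounded:
            rounded.append(cur)
    return rounded

# Sizes already scaled by decode_token_per_req (e.g. 2 tokens per request).
print(round_up_to_tp([6, 16, 32, 96], tp_size=4, max_num_batched_tokens=64))
# -> [8, 16, 32]; 96 exceeds max_num_batched_tokens and is skipped
```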
+    def _build_drafter_prepare_inputs_torchair_param(self):
+        return True