@@ -30,9 +30,7 @@
 import numpy.typing as npt
 import torch
 import torch._dynamo.cache_size
-import torch.distributed as dist
 import torch.nn as nn
-from torch.distributed import ReduceOp
 from vllm.attention import AttentionType, get_attn_backend
 from vllm.attention.layer import Attention
 from vllm.config import CompilationLevel, VllmConfig
@@ -562,16 +560,29 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
         self.input_batch.refresh_sampling_metadata()
 
     def _get_forward_metadata_across_dp(
-            self, total_num_scheduled_tokens: int,
-            with_prefill: bool) -> tuple[int, bool]:
-        forward_metadata = torch.tensor(
-            [total_num_scheduled_tokens, with_prefill],
-            device="cpu",
-            dtype=torch.int32)
-        dist.all_reduce(forward_metadata,
-                        op=ReduceOp.MAX,
-                        group=get_dp_group().cpu_group)
-        return int(forward_metadata[0]), bool(forward_metadata[1] > 0)
+            self, num_tokens: int,
+            with_prefill: bool) -> tuple[torch.Tensor, bool]:
+        local_forward_metadata = torch.tensor([num_tokens, with_prefill],
+                                              device="npu",
+                                              dtype=torch.int32).unsqueeze(0)
+        global_forward_metadata = get_dp_group().all_gather(
+            local_forward_metadata, dim=0)
+        num_tokens_across_dp = global_forward_metadata[:, 0].cpu()
+        with_prefill = bool(global_forward_metadata[:, 1].any())
+
+        if self.torchair_graph_enabled and not with_prefill:
+            max_num_tokens = int(num_tokens_across_dp.max().item())
+            dummy_num_tokens = self.select_torchair_padded_batch_size(
+                max_num_tokens)
+        else:
+            dummy_num_tokens = 1
+
+        # If num_tokens is -1, this indicates a dummy batch and we need to reset
+        # num_tokens accordingly.
+        num_tokens = dummy_num_tokens if num_tokens == -1 else num_tokens
+        num_tokens_across_dp.masked_fill_(num_tokens_across_dp == -1,
+                                          dummy_num_tokens)
+        return num_tokens, num_tokens_across_dp, with_prefill
 
     def get_eagle_atten_dict(
         self,
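A hedged, standalone rendition of the sentinel handling above, useful for seeing the numbers flow. The candidate graph sizes and the round-up rule stand in for select_torchair_padded_batch_size and are assumptions, not the exact vllm-ascend implementation:

import torch

def pick_padded_batch_size(n: int, candidates=(8, 16, 32, 64)) -> int:
    # Assume torchair graphs were captured for these batch sizes; round up.
    return next(c for c in candidates if c >= n)

num_tokens_across_dp = torch.tensor([48, -1, 17, -1])   # -1 marks idle ranks
with_prefill = False
torchair_graph_enabled = True

if torchair_graph_enabled and not with_prefill:
    dummy_num_tokens = pick_padded_batch_size(int(num_tokens_across_dp.max()))
else:
    dummy_num_tokens = 1

# Replace the -1 sentinels so idle ranks run a dummy batch of the padded size.
num_tokens_across_dp.masked_fill_(num_tokens_across_dp == -1, dummy_num_tokens)
print(num_tokens_across_dp)   # tensor([48, 64, 17, 64]) under these assumptions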
@@ -1033,22 +1044,22 @@ def _process_reqs(
             AscendAttentionState.DecodeOnly, AscendAttentionState.SpecDecoding
         ]
 
+        num_tokens_across_dp = None
         if self.dp_size > 1:
-            max_num_tokens, with_prefill = self._get_forward_metadata_across_dp(
-                total_num_scheduled_tokens, with_prefill)
+            _, num_tokens_across_dp, with_prefill = \
+                self._get_forward_metadata_across_dp(num_input_tokens,
+                                                     with_prefill)
+            max_num_tokens = int(num_tokens_across_dp.max().item())
             extra_builder_kwargs['max_num_tokens_across_dp'] = max_num_tokens
             extra_builder_kwargs['with_prefill_across_dp'] = with_prefill
 
         # Add graph_pad_size here
         if self.torchair_graph_enabled and not with_prefill:
-            if self.dp_size > 1:
-                padded_batch_size = self.select_torchair_padded_batch_size(
-                    max_num_tokens)
-            else:
-                padded_batch_size = self.select_torchair_padded_batch_size(
-                    total_num_scheduled_tokens)
+            max_num_tokens = (max_num_tokens
+                              if self.dp_size > 1 else num_input_tokens)
+            padded_batch_size = self.select_torchair_padded_batch_size(
+                max_num_tokens)
             graph_pad_size = padded_batch_size - total_num_scheduled_tokens
-
             extra_builder_kwargs['graph_pad_size'] = graph_pad_size
 
         if self.vllm_config.model_config.use_mla:
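The padding arithmetic in this hunk, worked through with hypothetical numbers (reusing the illustrative round-up helper from the previous sketch): under DP every rank pads to the group-wide maximum so all ranks replay a graph of the same shape, while a single rank pads only its own token count.

# Worked example of the padding arithmetic above (values are hypothetical).
dp_size = 2
num_input_tokens = 17                  # this rank's token count
max_num_tokens = 48                    # max across the DP group, from the gather
total_num_scheduled_tokens = 17

target = max_num_tokens if dp_size > 1 else num_input_tokens
padded_batch_size = next(c for c in (8, 16, 32, 64) if c >= target)   # -> 64
graph_pad_size = padded_batch_size - total_num_scheduled_tokens       # -> 47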
@@ -1126,7 +1137,8 @@ def _process_reqs(
         # Run forward pass
         with set_forward_context(attn_metadata,
                                  self.vllm_config,
-                                 num_tokens=num_input_tokens):
+                                 num_tokens=num_input_tokens,
+                                 num_tokens_across_dp=num_tokens_across_dp):
             with ProfileExecuteDuration().capture_async("forward"):
                 model_kwargs = {}
                 if self.torchair_graph_enabled:
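Threading num_tokens_across_dp into set_forward_context makes the per-rank counts available to whatever runs inside the forward context. One illustrative use, not a claim about vLLM's internals, is turning the counts into cumulative offsets so each rank knows its slice of a concatenated cross-DP buffer:

import torch

num_tokens_across_dp = torch.tensor([48, 64, 17, 64])
cu_tokens_across_dp = torch.cumsum(num_tokens_across_dp, dim=0)
# tensor([ 48, 112, 129, 193]); rank r would own rows cu[r-1]:cu[r].
print(cu_tokens_across_dp)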
@@ -1605,6 +1617,7 @@ def _dummy_run(
         num_tokens: int,
         is_compile: bool = False,
         with_prefill: bool = True,
+        num_tokens_across_dp: Optional[int] = None,
     ) -> torch.Tensor:
         # Set num_scheduled_tokens based on num_tokens and max_num_seqs
         # for dummy run with LoRA so that the num_reqs collectively
@@ -1646,9 +1659,11 @@ def _dummy_run(
                 for k, v in self.intermediate_tensors.items()
             })
 
-        with set_forward_context(None,
-                                 self.vllm_config,
-                                 num_tokens=num_tokens):
+        with set_forward_context(
+                None,
+                self.vllm_config,
+                num_tokens=num_tokens,
+                num_tokens_across_dp=num_tokens_across_dp):
             if self.torchair_graph_enabled and not with_prefill:
                 attn_metadata = self.attn_metadata_builder.build_dummy(
                     num_reqs=num_tokens, num_actual_tokens=1)
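With the new num_tokens_across_dp parameter, a dummy (warm-up) forward on an idle DP rank can run under the same gathered metadata as the real forwards on busy ranks. A hypothetical call site, assuming runner is the model runner instance (not verbatim from this PR):

# The idle rank reports num_tokens=-1; the metadata exchange returns the padded
# dummy size plus the per-rank counts, which are fed into the dummy forward so
# the rank stays in lockstep with the rest of the DP group.
num_tokens, num_tokens_across_dp, with_prefill = \
    runner._get_forward_metadata_across_dp(num_tokens=-1, with_prefill=False)
runner._dummy_run(num_tokens,
                  with_prefill=with_prefill,
                  num_tokens_across_dp=num_tokens_across_dp)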