Commit ca448e1

feat: optimize forward metadata collection across dp ranks
Signed-off-by: Jade Zheng <zheng.shoujian@outlook.com>
1 parent 957b0b6 commit ca448e1

File tree

3 files changed: +56 −48 lines changed

vllm_ascend/torchair/torchair_worker.py
vllm_ascend/worker/model_runner_v1.py
vllm_ascend/worker/worker_v1.py

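At a glance, this commit replaces the all_reduce(MAX)-based exchange of forward metadata with an all_gather, so every data-parallel (DP) rank learns the per-rank token counts instead of only the global maximum. The single-process sketch below illustrates the difference with made-up rank values; plain tensor ops stand in for the real torch.distributed collectives, and none of this code is from the commit itself.

import torch

# Hypothetical metadata for four DP ranks: [num_tokens, with_prefill].
per_rank_metadata = [
    torch.tensor([128, 0], dtype=torch.int32),  # rank 0: 128 decode tokens
    torch.tensor([512, 1], dtype=torch.int32),  # rank 1: 512 tokens, has prefill
    torch.tensor([64, 0], dtype=torch.int32),   # rank 2
    torch.tensor([-1, 0], dtype=torch.int32),   # rank 3: dummy-batch sentinel
]

# Old scheme: an all_reduce(MAX) collapses everything into one pair, so each
# rank only learns the maximum token count and whether any rank has a prefill.
reduced = torch.stack(per_rank_metadata).max(dim=0).values
print("all_reduce(MAX) leaves every rank with:", reduced.tolist())

# New scheme: an all_gather keeps one row per rank, so each rank also sees the
# per-rank token counts, which later feed set_forward_context.
gathered = torch.stack(per_rank_metadata)          # shape [dp_size, 2]
num_tokens_across_dp = gathered[:, 0]
with_prefill = bool(gathered[:, 1].any())
print("all_gather keeps per-rank tokens:", num_tokens_across_dp.tolist(),
      "with_prefill:", with_prefill)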

vllm_ascend/torchair/torchair_worker.py

Lines changed: 0 additions & 10 deletions
@@ -52,13 +52,3 @@ def determine_available_memory(self) -> int:
         self.model_runner.new_kv_cache_bytes = available_kv_cache_memory

         return available_kv_cache_memory
-
-    def _get_max_num_tokens_and_with_prefill(self):
-        """Override _get_max_num_tokens_and_with_prefill to update max_num_tokens."""
-
-        max_num_tokens, with_prefill = super(
-        )._get_max_num_tokens_and_with_prefill()
-        if not with_prefill:
-            max_num_tokens = self.model_runner.select_torchair_padded_batch_size(
-                max_num_tokens)
-        return max_num_tokens, with_prefill

vllm_ascend/worker/model_runner_v1.py

Lines changed: 40 additions & 25 deletions
@@ -30,9 +30,7 @@
 import numpy.typing as npt
 import torch
 import torch._dynamo.cache_size
-import torch.distributed as dist
 import torch.nn as nn
-from torch.distributed import ReduceOp
 from vllm.attention import AttentionType, get_attn_backend
 from vllm.attention.layer import Attention
 from vllm.config import CompilationLevel, VllmConfig
@@ -562,16 +560,29 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
         self.input_batch.refresh_sampling_metadata()

     def _get_forward_metadata_across_dp(
-            self, total_num_scheduled_tokens: int,
-            with_prefill: bool) -> tuple[int, bool]:
-        forward_metadata = torch.tensor(
-            [total_num_scheduled_tokens, with_prefill],
-            device="cpu",
-            dtype=torch.int32)
-        dist.all_reduce(forward_metadata,
-                        op=ReduceOp.MAX,
-                        group=get_dp_group().cpu_group)
-        return int(forward_metadata[0]), bool(forward_metadata[1] > 0)
+            self, num_tokens: int,
+            with_prefill: bool) -> tuple[torch.Tensor, bool]:
+        local_forward_metadata = torch.tensor([num_tokens, with_prefill],
+                                              device="npu",
+                                              dtype=torch.int32).unsqueeze(0)
+        global_forward_metadata = get_dp_group().all_gather(
+            local_forward_metadata, dim=0)
+        num_tokens_across_dp = global_forward_metadata[:, 0].cpu()
+        with_prefill = bool(global_forward_metadata[:, 1].any())
+
+        if self.torchair_graph_enabled and not with_prefill:
+            max_num_tokens = int(num_tokens_across_dp.max().item())
+            dummy_num_tokens = self.select_torchair_padded_batch_size(
+                max_num_tokens)
+        else:
+            dummy_num_tokens = 1
+
+        # If num_tokens is -1, this indicates a dummy batch and we need to reset
+        # num_tokens accordingly.
+        num_tokens = dummy_num_tokens if num_tokens == -1 else num_tokens
+        num_tokens_across_dp.masked_fill_(num_tokens_across_dp == -1,
+                                          dummy_num_tokens)
+        return num_tokens, num_tokens_across_dp, with_prefill

     def get_eagle_atten_dict(
         self,
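The new helper folds dummy-batch handling into the gather: a rank that reports -1 has no real work, and after the gather its entry is rewritten to a concrete padded batch size. Below is a minimal single-process sketch of that post-gather bookkeeping; select_torchair_padded_batch_size is stubbed with assumed pad buckets, and resolve_dummy_tokens is a hypothetical name, not a function in the codebase.

import torch

TORCHAIR_PAD_SIZES = [1, 4, 8, 16, 32, 64, 128]  # assumed bucket sizes


def select_torchair_padded_batch_size(n: int) -> int:
    # Round up to the next bucket (stand-in for the real helper).
    for size in TORCHAIR_PAD_SIZES:
        if n <= size:
            return size
    return TORCHAIR_PAD_SIZES[-1]


def resolve_dummy_tokens(num_tokens: int,
                         num_tokens_across_dp: torch.Tensor,
                         torchair_graph_enabled: bool,
                         with_prefill: bool):
    if torchair_graph_enabled and not with_prefill:
        max_num_tokens = int(num_tokens_across_dp.max().item())
        dummy_num_tokens = select_torchair_padded_batch_size(max_num_tokens)
    else:
        dummy_num_tokens = 1
    # A rank that reported -1 is running a dummy batch; give it a real size.
    num_tokens = dummy_num_tokens if num_tokens == -1 else num_tokens
    num_tokens_across_dp = num_tokens_across_dp.masked_fill(
        num_tokens_across_dp == -1, dummy_num_tokens)
    return num_tokens, num_tokens_across_dp


# Example: this rank sent -1; with graph mode on and no prefill anywhere,
# the DP-wide maximum of 48 tokens rounds up to the 64-token bucket.
tokens = torch.tensor([48, 33, -1, 12], dtype=torch.int32)
print(resolve_dummy_tokens(-1, tokens, True, False))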
@@ -1033,22 +1044,22 @@ def _process_reqs(
             AscendAttentionState.DecodeOnly, AscendAttentionState.SpecDecoding
         ]

+        num_tokens_across_dp = None
         if self.dp_size > 1:
-            max_num_tokens, with_prefill = self._get_forward_metadata_across_dp(
-                total_num_scheduled_tokens, with_prefill)
+            _, num_tokens_across_dp, with_prefill = \
+                self._get_forward_metadata_across_dp(num_input_tokens,
+                                                     with_prefill)
+            max_num_tokens = int(num_tokens_across_dp.max().item())
             extra_builder_kwargs['max_num_tokens_across_dp'] = max_num_tokens
             extra_builder_kwargs['with_prefill_across_dp'] = with_prefill

         # Add graph_pad_size here
         if self.torchair_graph_enabled and not with_prefill:
-            if self.dp_size > 1:
-                padded_batch_size = self.select_torchair_padded_batch_size(
-                    max_num_tokens)
-            else:
-                padded_batch_size = self.select_torchair_padded_batch_size(
-                    total_num_scheduled_tokens)
+            max_num_tokens = (max_num_tokens
+                              if self.dp_size > 1 else num_input_tokens)
+            padded_batch_size = self.select_torchair_padded_batch_size(
+                max_num_tokens)
             graph_pad_size = padded_batch_size - total_num_scheduled_tokens
-
             extra_builder_kwargs['graph_pad_size'] = graph_pad_size

         if self.vllm_config.model_config.use_mla:
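The refactored padding path derives padded_batch_size from a single max_num_tokens value: the DP-wide maximum when dp_size > 1, otherwise the local num_input_tokens. The sketch below replays that arithmetic with a stubbed bucket helper and illustrative numbers; graph_pad_size_for is a hypothetical wrapper, not a function in the repository.

def select_torchair_padded_batch_size(n: int) -> int:
    # Stand-in for the real helper: round up to an assumed bucket size.
    pad_sizes = [1, 4, 8, 16, 32, 64, 128]
    return next(s for s in pad_sizes if n <= s)


def graph_pad_size_for(total_num_scheduled_tokens: int,
                       num_input_tokens: int,
                       dp_size: int,
                       max_num_tokens_across_dp: int) -> int:
    max_num_tokens = (max_num_tokens_across_dp
                      if dp_size > 1 else num_input_tokens)
    padded_batch_size = select_torchair_padded_batch_size(max_num_tokens)
    # Extra slots appended so the captured torchair graph sees a fixed shape.
    return padded_batch_size - total_num_scheduled_tokens


# With dp_size=2 and a DP-wide maximum of 50 tokens, a 30-token local batch
# is padded up to the 64-token bucket, i.e. 34 pad tokens.
print(graph_pad_size_for(30, 30, 2, 50))  # 34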
@@ -1126,7 +1137,8 @@ def _process_reqs(
         # Run forward pass
         with set_forward_context(attn_metadata,
                                  self.vllm_config,
-                                 num_tokens=num_input_tokens):
+                                 num_tokens=num_input_tokens,
+                                 num_tokens_across_dp=num_tokens_across_dp):
             with ProfileExecuteDuration().capture_async("forward"):
                 model_kwargs = {}
                 if self.torchair_graph_enabled:
@@ -1605,6 +1617,7 @@ def _dummy_run(
         num_tokens: int,
         is_compile: bool = False,
         with_prefill: bool = True,
+        num_tokens_across_dp: Optional[int] = None,
     ) -> torch.Tensor:
         # Set num_scheduled_tokens based on num_tokens and max_num_seqs
         # for dummy run with LoRA so that the num_reqs collectively
@@ -1646,9 +1659,11 @@ def _dummy_run(
                     for k, v in self.intermediate_tensors.items()
                 })

-        with set_forward_context(None,
-                                 self.vllm_config,
-                                 num_tokens=num_tokens):
+        with set_forward_context(
+                None,
+                self.vllm_config,
+                num_tokens=num_tokens,
+                num_tokens_across_dp=num_tokens_across_dp):
             if self.torchair_graph_enabled and not with_prefill:
                 attn_metadata = self.attn_metadata_builder.build_dummy(
                     num_reqs=num_tokens, num_actual_tokens=1)

vllm_ascend/worker/worker_v1.py

Lines changed: 16 additions & 13 deletions
@@ -265,20 +265,23 @@ def list_loras(self) -> set[int]:
     def pin_lora(self, lora_id: int) -> bool:
         return self.model_runner.pin_lora(lora_id)

-    def _get_max_num_tokens_and_with_prefill(self):
-        max_num_tokens = 1
-        with_prefill = False
-        if self.model_runner.dp_size > 1:
-            max_num_tokens, with_prefill = self.model_runner._get_forward_metadata_across_dp(
-                max_num_tokens, with_prefill)
-        return max_num_tokens, with_prefill
-
     def execute_dummy_batch(self) -> None:
-        max_num_tokens, with_prefill = self._get_max_num_tokens_and_with_prefill(
-        )
-        self.model_runner._dummy_run(max_num_tokens,
-                                     is_compile=False,
-                                     with_prefill=with_prefill)
+        if self.runner.dp_size <= 1:
+            raise ValueError(
+                "Dummy batch execution should only be "
+                "performed with data parallelism enabled, but got "
+                f"dp_size={self.runner.dp_size}.")
+
+        # Indicate to other data parallel (DP) ranks that this is a dummy run by
+        # using '-1' as the num_tokens flag. The actual batch size will be
+        # determined and set within the model runner after synchronization
+        # across DP ranks.
+        num_tokens, num_tokens_across_dp, with_prefill = \
+            self.model_runner._get_forward_metadata_across_dp(-1, False)
+        self.runner._dummy_run(num_tokens,
+                               is_compile=False,
+                               num_tokens_across_dp=num_tokens_across_dp,
+                               with_prefill=with_prefill)

     def _init_worker_distributed_environment(self) -> None:
         """Initialize the distributed environment."""
