@@ -29,7 +29,6 @@
 import numpy.typing as npt
 import torch
 import torch._dynamo.cache_size
-import torch.distributed as dist
 import torch.nn as nn
 from tqdm import tqdm  # type: ignore
 from vllm.attention import AttentionType, get_attn_backend
@@ -596,18 +595,15 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
     def _get_forward_metadata_across_dp(
             self, num_tokens: int, with_prefill: bool,
             enable_dbo: bool) -> tuple[torch.Tensor, bool, bool]:
-
-        # Compose: all_reduce metadata (num_tokens of each rank, with_prefill, enable_dbo)
-        num_tokens_across_dp = torch.zeros(self.dp_size + 2,
-                                           dtype=torch.int32,
-                                           device="cpu")
-        num_tokens_across_dp[self.dp_rank] = num_tokens
-        num_tokens_across_dp[-2] = int(with_prefill)
-        num_tokens_across_dp[-1] = int(not enable_dbo)
-        dist.all_reduce(num_tokens_across_dp, group=get_dp_group().cpu_group)
-        with_prefill = bool(num_tokens_across_dp[-2])
-        enable_dbo = not bool(num_tokens_across_dp[-1])
-        num_tokens_across_dp = num_tokens_across_dp[:-2]
+        local_forward_metadata = torch.tensor(
+            [[num_tokens, with_prefill, enable_dbo]],
+            device="npu",
+            dtype=torch.int32)
+        global_forward_metadata = get_dp_group().all_gather(
+            local_forward_metadata, dim=0)
+        num_tokens_across_dp = global_forward_metadata[:, 0].cpu()
+        with_prefill = bool(global_forward_metadata[:, 1].any())
+        enable_dbo = bool(global_forward_metadata[:, 2].any())
         return num_tokens_across_dp, with_prefill, enable_dbo
 
     def _get_forward_metadata_across_dp_and_pad(
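For reference, here is a minimal, self-contained sketch of the all-gather metadata exchange this hunk introduces, written against plain `torch.distributed` with the gloo backend so it runs on CPU without an NPU or vLLM's `get_dp_group()` wrapper. The function name `exchange_forward_metadata` and the single-process demo are illustrative assumptions, not part of the PR.

```python
import torch
import torch.distributed as dist


def exchange_forward_metadata(
        num_tokens: int, with_prefill: bool,
        enable_dbo: bool) -> tuple[torch.Tensor, bool, bool]:
    """Share per-rank forward metadata with every other DP rank."""
    # Pack this rank's metadata into one 1x3 int32 row so a single
    # collective moves the token count and both flags together.
    local = torch.tensor([[num_tokens, with_prefill, enable_dbo]],
                         dtype=torch.int32)
    # all_gather every rank's row, then stack into (world_size, 3).
    gathered = [torch.empty_like(local) for _ in range(dist.get_world_size())]
    dist.all_gather(gathered, local)
    global_meta = torch.cat(gathered, dim=0)
    # Column 0: each rank's token count; columns 1-2: flags that become
    # globally true as soon as any rank sets them.
    num_tokens_across_dp = global_meta[:, 0]
    with_prefill = bool(global_meta[:, 1].any())
    enable_dbo = bool(global_meta[:, 2].any())
    return num_tokens_across_dp, with_prefill, enable_dbo


if __name__ == "__main__":
    # Single-process demo; real DP ranks all call this at the same point.
    dist.init_process_group(backend="gloo",
                            init_method="tcp://127.0.0.1:29500",
                            rank=0,
                            world_size=1)
    print(exchange_forward_metadata(128, True, False))
    dist.destroy_process_group()
```

Note the change in reduction semantics: the removed all_reduce path stored `int(not enable_dbo)` and inverted the summed result, so DBO stayed enabled only when every rank enabled it, whereas the all_gather version (and this sketch) uses `any()`, enabling it when any rank does.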