[Feature] Optimize forward metadata collection across dp ranks #1857

Open · wants to merge 1 commit into base: main
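At a high level, the previous implementation all-reduced a two-element CPU tensor with `ReduceOp.MAX`, so each DP rank only learned the global maximum token count and whether any rank had a prefill; the new implementation all-gathers the per-rank `(num_tokens, with_prefill)` pairs on the NPU, which also preserves the per-rank counts that `set_forward_context` now receives as `num_tokens_across_dp`. The sketch below simulates the two collectives with plain tensors (no process group is created, and the rank values are illustrative):

```python
import torch

# Per-rank (num_tokens, with_prefill) pairs as each dp rank would report them.
# Illustrative values only; no process group is created in this simulation.
per_rank = torch.tensor([[96, 0], [512, 1], [64, 0]], dtype=torch.int32)

# Old scheme: all_reduce(ReduceOp.MAX) over the two-element metadata, so every
# rank only sees the elementwise maximum.
reduced = per_rank.max(dim=0).values
old_max_num_tokens, old_with_prefill = int(reduced[0]), bool(reduced[1] > 0)

# New scheme: all_gather keeps the full (dp_size, 2) table, so the per-rank
# token counts survive and can be passed to set_forward_context.
num_tokens_across_dp = per_rank[:, 0].clone()
with_prefill = bool(per_rank[:, 1].any())

print(old_max_num_tokens, old_with_prefill)          # 512 True
print(num_tokens_across_dp.tolist(), with_prefill)   # [96, 512, 64] True
```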
vllm_ascend/torchair/torchair_worker.py (0 additions, 10 deletions)
@@ -52,13 +52,3 @@ def determine_available_memory(self) -> int:
self.model_runner.new_kv_cache_bytes = available_kv_cache_memory

return available_kv_cache_memory

def _get_max_num_tokens_and_with_prefill(self):
"""Override _get_max_num_tokens_and_with_prefill to update max_num_tokens."""

max_num_tokens, with_prefill = super(
)._get_max_num_tokens_and_with_prefill()
if not with_prefill:
max_num_tokens = self.model_runner.select_torchair_padded_batch_size(
max_num_tokens)
return max_num_tokens, with_prefill
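The deleted override padded the decode-only token count to a torchair graph batch size via `select_torchair_padded_batch_size`; that padding now happens inside the model runner (see the next file). As a rough, hedged sketch of that kind of padding, assuming the candidate sizes come from a pre-captured list (the list below is made up and this is not the actual `select_torchair_padded_batch_size` implementation):

```python
def pad_to_captured_batch_size(num_tokens: int, captured_sizes: list[int]) -> int:
    """Round num_tokens up to the smallest captured graph batch size that can
    hold it, falling back to the largest size. Hypothetical helper, not the
    actual select_torchair_padded_batch_size implementation."""
    for size in sorted(captured_sizes):
        if num_tokens <= size:
            return size
    return max(captured_sizes)


# Example with made-up capture sizes:
print(pad_to_captured_batch_size(7, [1, 2, 4, 8, 16]))  # prints 8
```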
vllm_ascend/worker/model_runner_v1.py (40 additions, 25 deletions)
@@ -30,9 +30,7 @@
import numpy.typing as npt
import torch
import torch._dynamo.cache_size
import torch.distributed as dist
import torch.nn as nn
from torch.distributed import ReduceOp
from vllm.attention import AttentionType, get_attn_backend
from vllm.attention.layer import Attention
from vllm.config import CompilationLevel, VllmConfig
@@ -562,16 +560,29 @@
self.input_batch.refresh_sampling_metadata()

def _get_forward_metadata_across_dp(
self, total_num_scheduled_tokens: int,
with_prefill: bool) -> tuple[int, bool]:
forward_metadata = torch.tensor(
[total_num_scheduled_tokens, with_prefill],
device="cpu",
dtype=torch.int32)
dist.all_reduce(forward_metadata,
op=ReduceOp.MAX,
group=get_dp_group().cpu_group)
return int(forward_metadata[0]), bool(forward_metadata[1] > 0)
self, num_tokens: int,
with_prefill: bool) -> tuple[torch.Tensor, bool]:
local_forward_metadata = torch.tensor([num_tokens, with_prefill],
device="npu",
dtype=torch.int32).unsqueeze(0)
global_forward_metadata = get_dp_group().all_gather(
local_forward_metadata, dim=0)
num_tokens_across_dp = global_forward_metadata[:, 0].cpu()
with_prefill = bool(global_forward_metadata[:, 1].any())

if self.torchair_graph_enabled and not with_prefill:
max_num_tokens = int(num_tokens_across_dp.max().item())
dummy_num_tokens = self.select_torchair_padded_batch_size(
max_num_tokens)
else:
dummy_num_tokens = 1

# If num_tokens is -1, this indicates a dummy batch and we need to reset
# num_tokens accordingly.
num_tokens = dummy_num_tokens if num_tokens == -1 else num_tokens
num_tokens_across_dp.masked_fill_(num_tokens_across_dp == -1,
dummy_num_tokens)
return num_tokens, num_tokens_across_dp, with_prefill

Check failure on line 585 in vllm_ascend/worker/model_runner_v1.py (GitHub Actions / lint / pre-commit): Incompatible return value type (got "tuple[int, Any, bool]", expected "tuple[Any, bool]") [return-value]
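The pre-commit failure above reports that the function now returns three values while its annotation still promises two; the same mismatch drives the unpacking errors flagged further down. A minimal sketch of an annotation consistent with the new return, assuming the types shown in the diff (the enclosing class here is just an illustrative stub):

```python
import torch


class ModelRunnerStub:  # illustrative stub, not the real runner class
    def _get_forward_metadata_across_dp(
            self, num_tokens: int,
            with_prefill: bool) -> tuple[int, torch.Tensor, bool]:
        """Return (num_tokens, num_tokens_across_dp, with_prefill)."""
        # Body elided; only the three-element return annotation matters here.
        raise NotImplementedError
```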

def get_eagle_atten_dict(
self,
@@ -1033,22 +1044,22 @@
AscendAttentionState.DecodeOnly, AscendAttentionState.SpecDecoding
]

num_tokens_across_dp = None
if self.dp_size > 1:
max_num_tokens, with_prefill = self._get_forward_metadata_across_dp(
total_num_scheduled_tokens, with_prefill)
_, num_tokens_across_dp, with_prefill = \
self._get_forward_metadata_across_dp(num_input_tokens,

Check failure on line 1050 in vllm_ascend/worker/model_runner_v1.py (GitHub Actions / lint / pre-commit): Need more than 2 values to unpack (3 expected) [misc]
with_prefill)
max_num_tokens = int(num_tokens_across_dp.max().item())

Check failure on line 1052 in vllm_ascend/worker/model_runner_v1.py (GitHub Actions / lint / pre-commit): "None" has no attribute "max" [attr-defined]
extra_builder_kwargs['max_num_tokens_across_dp'] = max_num_tokens
extra_builder_kwargs['with_prefill_across_dp'] = with_prefill

# Add graph_pad_size here
if self.torchair_graph_enabled and not with_prefill:
if self.dp_size > 1:
padded_batch_size = self.select_torchair_padded_batch_size(
max_num_tokens)
else:
padded_batch_size = self.select_torchair_padded_batch_size(
total_num_scheduled_tokens)
max_num_tokens = (max_num_tokens
if self.dp_size > 1 else num_input_tokens)
padded_batch_size = self.select_torchair_padded_batch_size(
max_num_tokens)
graph_pad_size = padded_batch_size - total_num_scheduled_tokens

extra_builder_kwargs['graph_pad_size'] = graph_pad_size

if self.vllm_config.model_config.use_mla:
@@ -1126,7 +1137,8 @@
# Run forward pass
with set_forward_context(attn_metadata,
self.vllm_config,
num_tokens=num_input_tokens):
num_tokens=num_input_tokens,
num_tokens_across_dp=num_tokens_across_dp):
with ProfileExecuteDuration().capture_async("forward"):
model_kwargs = {}
if self.torchair_graph_enabled:
@@ -1605,6 +1617,7 @@
num_tokens: int,
is_compile: bool = False,
with_prefill: bool = True,
num_tokens_across_dp: Optional[int] = None,
) -> torch.Tensor:
# Set num_scheduled_tokens based on num_tokens and max_num_seqs
# for dummy run with LoRA so that the num_reqs collectively
@@ -1646,9 +1659,11 @@
for k, v in self.intermediate_tensors.items()
})

with set_forward_context(None,
self.vllm_config,
num_tokens=num_tokens):
with set_forward_context(
None,
self.vllm_config,
num_tokens=num_tokens,
num_tokens_across_dp=num_tokens_across_dp):
if self.torchair_graph_enabled and not with_prefill:
attn_metadata = self.attn_metadata_builder.build_dummy(
num_reqs=num_tokens, num_actual_tokens=1)
vllm_ascend/worker/worker_v1.py (16 additions, 13 deletions)
@@ -265,20 +265,23 @@
def pin_lora(self, lora_id: int) -> bool:
return self.model_runner.pin_lora(lora_id)

def _get_max_num_tokens_and_with_prefill(self):
max_num_tokens = 1
with_prefill = False
if self.model_runner.dp_size > 1:
max_num_tokens, with_prefill = self.model_runner._get_forward_metadata_across_dp(
max_num_tokens, with_prefill)
return max_num_tokens, with_prefill

def execute_dummy_batch(self) -> None:
max_num_tokens, with_prefill = self._get_max_num_tokens_and_with_prefill(
)
self.model_runner._dummy_run(max_num_tokens,
is_compile=False,
with_prefill=with_prefill)
if self.runner.dp_size <= 1:
raise ValueError(
"Dummy batch execution should only be "
"performed with data parallelism enabled, but got "
f"dp_size={self.runner.dp_size}.")

# Indicate to other data parallel (DP) ranks that this is a dummy run by
# using '-1' as the num_tokens flag. The actual batch size will be
# determined and set within the model runner after synchronization
# across DP ranks.
num_tokens, num_tokens_across_dp, with_prefill = \
self.model_runner._get_forward_metadata_across_dp(-1, False)

Check failure on line 280 in vllm_ascend/worker/worker_v1.py (GitHub Actions / lint / pre-commit): Need more than 2 values to unpack (3 expected) [misc]
self.runner._dummy_run(num_tokens,
is_compile=False,
num_tokens_across_dp=num_tokens_across_dp,
with_prefill=with_prefill)
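A hedged sketch of how the '-1' dummy-batch flag is intended to behave after the gather, with the collective replaced by an already-gathered tensor and an illustrative padded size standing in for the torchair-selected batch size:

```python
import torch


def resolve_dummy_batch(gathered: torch.Tensor, padded_size: int):
    """Simulate what a dummy-run rank computes after the all_gather.

    gathered is the (dp_size, 2) tensor of per-rank (num_tokens, with_prefill)
    rows; padded_size stands in for the torchair-selected padded batch size.
    """
    num_tokens_across_dp = gathered[:, 0].clone()
    with_prefill = bool(gathered[:, 1].any())
    # Ranks that contributed the -1 dummy flag pick up the agreed padded size.
    num_tokens_across_dp.masked_fill_(num_tokens_across_dp == -1, padded_size)
    num_tokens = padded_size  # this rank's own -1 is replaced the same way
    return num_tokens, num_tokens_across_dp, with_prefill


# Two real decode ranks plus one dummy rank (-1), no prefill anywhere:
gathered = torch.tensor([[128, 0], [-1, 0], [64, 0]], dtype=torch.int32)
print(resolve_dummy_batch(gathered, padded_size=128))
# (128, tensor([128, 128,  64], dtype=torch.int32), False)
```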

def _init_worker_distributed_environment(self) -> None:
"""Initialize the distributed environment."""