13 changes: 13 additions & 0 deletions vllm_ascend/patch/__init__.py
@@ -75,6 +75,19 @@
# ** File: worker/patch_common/patch_distributed.py **
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# 1. `vllm.distributed.parallel_state.GroupCoordinator`
# (1) __init__()
# Why:
# The original GroupCoordinator initialization lacks pg_options, so new process
# groups cannot be created with customized options.
# How:
# Inject HCCL options during process group initialization.
# Related PR (if no, explain why):
# A PR to vllm is needed to accept a dictionary of per-group options when initializing
# the distributed environment (e.g., Dict[str, torch.distributed.ProcessGroupHCCL.Options])
# https://github.com/vllm-project/vllm/pull/25417
# Future Plan:
# Remove this patch when vllm merges this PR.
# (2) all_to_all()
# Why:
# vllm doesn't support all_to_all for GroupCoordinator.
# How:
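For context, here is a minimal sketch (not part of this diff) of the mechanism the patch relies on, as implemented in the patch_distributed.py and utils.py changes below: torch_npu's ProcessGroupHCCL.Options carries an hccl_config dict, and passing it to torch.distributed.new_group() via pg_options sets the per-group HCCL buffer size. The function name new_hccl_group and the 200 MB default are illustrative only.

```python
import torch
import torch_npu  # provides the HCCL backend and ProcessGroupHCCL.Options


def new_hccl_group(ranks: list[int], buffer_size_mb: int = 200):
    """Create a process group with a customized HCCL buffer size.

    Assumes torch.distributed is already initialized with the "hccl" backend;
    buffer_size_mb mirrors the hccl_buffer_size values computed in utils.py.
    """
    options = torch_npu._C._distributed_c10d.ProcessGroupHCCL.Options()
    options.hccl_config = {"hccl_buffer_size": buffer_size_mb}
    return torch.distributed.new_group(ranks, backend="hccl",
                                       pg_options=options)
```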
76 changes: 71 additions & 5 deletions vllm_ascend/patch/worker/patch_common/patch_distributed.py
@@ -15,17 +15,82 @@
# limitations under the License.
#

from typing import List, Optional
from typing import List, Optional, Union

import torch
import vllm
from vllm.distributed.parallel_state import GroupCoordinator
from torch.distributed import Backend
from vllm.distributed.parallel_state import (GroupCoordinator,
_get_unique_name, _register_group)

from vllm_ascend.distributed.communicator import NPUCommunicator
from vllm_ascend.utils import create_hccl_pg_options


class GroupCoordinatorPatch(GroupCoordinator):

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
def __init__(
self,
group_ranks: list[list[int]],
local_rank: int,
torch_distributed_backend: Union[str, Backend],
use_device_communicator: bool, # whether to use device communicator
use_message_queue_broadcaster: bool = False,
group_name: Optional[str] = None,
):
group_name = group_name or "anonymous"
self.unique_name = _get_unique_name(group_name)
_register_group(self)

self.rank = torch.distributed.get_rank()
self.local_rank = local_rank

self_device_group = None
self_cpu_group = None
hccl_pg_options = create_hccl_pg_options(group_name)

for ranks in group_ranks:
device_group = torch.distributed.new_group(
ranks,
backend=torch_distributed_backend,
pg_options=hccl_pg_options)

# a group with `gloo` backend, to allow direct coordination between
# processes through the CPU.
cpu_group = torch.distributed.new_group(ranks, backend="gloo")
if self.rank in ranks:
self.ranks = ranks
self.world_size = len(ranks)
self.rank_in_group = ranks.index(self.rank)
self_device_group = device_group
self_cpu_group = cpu_group

assert self_cpu_group is not None
assert self_device_group is not None

self.cpu_group = self_cpu_group
self.device_group = self_device_group
self.device = torch.npu.current_device()

self.use_device_communicator = use_device_communicator
self.device_communicator = None
if use_device_communicator and self.world_size > 1:
self.device_communicator = NPUCommunicator(
cpu_group=self.cpu_group,
device=self.device,
device_group=self.device_group,
unique_name=self.unique_name,
)

from vllm.distributed.device_communicators.shm_broadcast import \
MessageQueue
self.mq_broadcaster: Optional[MessageQueue] = None
if use_message_queue_broadcaster and self.world_size > 1:
self.mq_broadcaster = MessageQueue.create_from_process_group(
self.cpu_group, 1 << 22, 6)

self.use_custom_op_call = False
self.use_cpu_custom_send_recv = False

def all_to_all(self,
input_: torch.Tensor,
@@ -41,9 +106,10 @@ def all_to_all(self,
assert -input_.dim() <= gather_dim < input_.dim(), (
f"Invalid gather dim ({gather_dim}) for input tensor with shape {input_.size()}"
)
assert self.device_communicator is not None, "device_communicator should be initialized when world_size > 1"
return self.device_communicator.all_to_all(input_, scatter_dim,
gather_dim, scatter_sizes,
gather_sizes)


vllm.distributed.parallel_state.GroupCoordinator = GroupCoordinatorPatch # Note: check the GroupCoordinator with online serving
vllm.distributed.parallel_state.GroupCoordinator = GroupCoordinatorPatch
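As a usage note, the patch takes effect through import order: once this module is imported, every GroupCoordinator that vllm constructs is a GroupCoordinatorPatch, so its device group is created with the HCCL options from create_hccl_pg_options() and all_to_all() dispatches to the NPUCommunicator. A minimal sketch, with assumptions flagged in the comments:

```python
# Minimal sketch, not part of the diff: importing the patch module rebinds
# vllm.distributed.parallel_state.GroupCoordinator before any groups are built.
import vllm_ascend.patch.worker.patch_common.patch_distributed  # noqa: F401

import vllm.distributed.parallel_state as ps

assert ps.GroupCoordinator.__name__ == "GroupCoordinatorPatch"

# After vllm's usual distributed setup (init_distributed_environment followed by
# initialize_model_parallel), coordinators such as the TP group expose the patched
# all_to_all, e.g. (call shape assumed from the assertions in the diff above):
#   tp = ps.get_tp_group()
#   out = tp.all_to_all(hidden_states, scatter_dim=0, gather_dim=1)
```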
88 changes: 88 additions & 0 deletions vllm_ascend/utils.py
@@ -60,6 +60,9 @@
_SLEEP_MODE_ENABLED = None
_CURRENT_STREAM = None
_ASCEND_CUSTOMOP_IS_REIGISTERED = False
_DEFAULT_BUFFER_SIZE = 200
_MIN_DP_BUFFER_SIZE = 50
_MIN_MC2_BUFFER_SIZE = 1024


def is_310p():
@@ -634,3 +637,88 @@ def npu_stream_switch(target_stream: torch.npu.Stream,
return nullcontext()
assert target_stream is not None
return torch.npu.stream(target_stream)


def create_hccl_pg_options(group_name: str):
options = torch_npu._C._distributed_c10d.ProcessGroupHCCL.Options()
hccl_config = get_hccl_config_for_pg_options(group_name)
if hccl_config is not None:
options.hccl_config = hccl_config
return options


def get_hccl_config_for_pg_options(group_name: str) -> Optional[dict]:
hccl_config_map = {
"dp": {
"hccl_buffer_size": calculate_dp_buffer_size()
},
"mc2": {
"hccl_buffer_size": calculate_mc2_buffer_size()
},
}
return hccl_config_map.get(group_name, get_default_buffer_config())


def get_default_buffer_config() -> dict:
return {"hccl_buffer_size": _DEFAULT_BUFFER_SIZE}


def calculate_dp_buffer_size() -> int:
"""
Formula for the DP buffer size (in MB): (dp_size + 2) int32 values,
where the two extra slots hold the with_prefill and enable_dbo flags.
"""
from vllm.config import get_current_vllm_config
vllm_config = get_current_vllm_config()
dp_size = vllm_config.parallel_config.data_parallel_size
int32_size = torch.iinfo(torch.int32).bits // 8
dp_buffer_size = math.ceil((dp_size + 2) * int32_size / (1024 * 1024))
return max(dp_buffer_size, _MIN_DP_BUFFER_SIZE)


def calculate_mc2_buffer_size() -> int:
"""
Formula for the MC2 buffer size in bytes (the result is rounded up to MB):
2 * (local_routed_expert_num * max_bs_per_rank * ep_world_size * align512(align32(2 * H) + 64) +
(K + shared_expert_num) * max_bs_per_rank * align512(2 * H))
"""
from vllm.config import get_current_vllm_config
vllm_config = get_current_vllm_config()
dp_size = vllm_config.parallel_config.data_parallel_size
tp_size = vllm_config.parallel_config.tensor_parallel_size
# since the mc2 process group is only used with enable_expert_parallel,
# the non-EP case does not need to be considered.
ep_world_size = dp_size * tp_size
# FIXME: try to get local_routed_expert_num and redundant_expert_num (default: 2) elegantly
global_routed_expert_num = getattr(vllm_config.model_config.hf_config,
'n_routed_experts', 1)
local_routed_expert_num = math.ceil(
global_routed_expert_num / ep_world_size) + 2
# tokens are passed to shared experts without mc2 ops.
shared_expert_num = 0

max_bs_per_rank = math.ceil(vllm_config.scheduler_config.max_num_seqs /
tp_size)
# take MTP into consideration; aligning this with vllm's speculative config handling is left for later.
mtp_spec_token_num = (
vllm_config.speculative_config.num_speculative_tokens + 1 if
(vllm_config.speculative_config is not None
and vllm_config.speculative_config.method == 'deepseek_mtp') else 1)
max_bs_per_rank *= mtp_spec_token_num
H = vllm_config.model_config.hf_config.hidden_size
K = getattr(vllm_config.model_config.hf_config, 'num_experts_per_tok', 1)

aligned_2H_32 = _round_up(2 * H, 32)
aligned_2H_512 = _round_up(2 * H, 512)

# local_routed_expert_num * max_bs_per_rank * ep_world_size * Align512(Align32(2 * H) + 64)
part1_base = aligned_2H_32 + 64
aligned_part1 = _round_up(part1_base, 512)
local_expert_component = local_routed_expert_num * max_bs_per_rank * ep_world_size * aligned_part1

# (K + shared_expert_num) * max_bs_per_rank * Align512(2 * H)
shared_expert_component = (
K + shared_expert_num) * max_bs_per_rank * aligned_2H_512
mc2_buffer_size = math.ceil(
2 * (local_expert_component + shared_expert_component) / (1024 * 1024))
return max(mc2_buffer_size, _MIN_MC2_BUFFER_SIZE)
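To make the two buffer-size formulas concrete, here is a standalone worked example. All model and parallelism parameters (hidden_size=7168, 256 routed experts, top-k=8, dp=4, tp=4, max_num_seqs=1024, no MTP) are assumptions chosen for illustration, not values taken from this PR, and the local round_up helper mirrors the assumed semantics of _round_up.

```python
import math


def round_up(value: int, align: int) -> int:
    # assumed semantics of _round_up: round value up to the next multiple of align
    return math.ceil(value / align) * align


# Assumed example configuration (illustration only).
dp_size, tp_size = 4, 4
hidden_size, n_routed_experts, top_k = 7168, 256, 8
max_num_seqs = 1024

# DP buffer: (dp_size + 2) int32 values, rounded up to MB, floored at 50 MB.
dp_bytes = (dp_size + 2) * 4
dp_buffer = max(math.ceil(dp_bytes / (1024 * 1024)), 50)
print(f"dp hccl_buffer_size  = {dp_buffer} MB")   # 50: the minimum dominates

# MC2 buffer, following calculate_mc2_buffer_size() above.
ep_world_size = dp_size * tp_size
local_experts = math.ceil(n_routed_experts / ep_world_size) + 2  # +2 redundant experts
max_bs_per_rank = math.ceil(max_num_seqs / tp_size)              # no MTP factor
part1 = (local_experts * max_bs_per_rank * ep_world_size *
         round_up(round_up(2 * hidden_size, 32) + 64, 512))
part2 = top_k * max_bs_per_rank * round_up(2 * hidden_size, 512)  # shared_expert_num = 0
mc2_buffer = max(math.ceil(2 * (part1 + part2) / (1024 * 1024)), 1024)
print(f"mc2 hccl_buffer_size = {mc2_buffer} MB")  # 2144 for these assumptions
```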