Commit 1cadcf3

[feat] support customized and separated hccl_buffer_size for process group initialization
Signed-off-by: linfeng-yuan <1102311262@qq.com>
1 parent 12bcbd0 commit 1cadcf3

File tree

3 files changed: +166 -6 lines changed

vllm_ascend/patch/__init__.py

Lines changed: 12 additions & 0 deletions
@@ -75,6 +75,18 @@
 #   ** File: worker/patch_common/patch_distributed.py **
 #   ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 #   1. `vllm.distributed.parallel_state.GroupCoordinator`
+#      (1) `__init__()`
+#       Why:
+#           The original GroupCoordinator initialization lacks pg_options to generate a new
+#           process group with customized options.
+#       How:
+#           Inject HCCL options during process group initialization.
+#       Related PR (if no, explain why):
+#           Need a PR to vllm to support a dictionary as input when initializing the distributed
+#           environment (e.g., Dict[str, torch.distributed.ProcessGroupHCCL.Options]).
+#       Future Plan:
+#           Remove this patch when vllm merges this PR.
+#      (2) `all_to_all()`
 #   Why:
 #       vllm doesn't support all_to_all for GroupCoordinator.
 #   How:
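The "Related PR" note above can be made concrete with a small sketch. The snippet below shows what a per-group options dictionary could look like and how it would reach `torch.distributed.new_group`. The `build_pg_options_map` helper and the idea of vllm accepting such a dictionary are hypothetical; only `ProcessGroupHCCL.Options()` and its `hccl_config` field are taken from the patch itself (see `create_hccl_pg_options` in `vllm_ascend/utils.py` below).

```python
# Illustrative sketch only: a per-group Dict[str, ProcessGroupHCCL.Options]
# of the kind the comment proposes passing into vllm's distributed init.
# The helper name is hypothetical; the Options/hccl_config usage mirrors
# create_hccl_pg_options() in vllm_ascend/utils.py.
import torch_npu  # registers the HCCL backend and exposes ProcessGroupHCCL


def build_pg_options_map(dp_buffer_mb: int, mc2_buffer_mb: int):
    """Build per-group HCCL options keyed by process-group name."""
    options_map = {}
    for name, buffer_mb in (("dp", dp_buffer_mb), ("mc2", mc2_buffer_mb)):
        opts = torch_npu._C._distributed_c10d.ProcessGroupHCCL.Options()
        opts.hccl_config = {"hccl_buffer_size": buffer_mb}
        options_map[name] = opts
    return options_map


# A group would then pick up its own buffer size at creation time, e.g.:
# torch.distributed.new_group(ranks, backend="hccl",
#                             pg_options=build_pg_options_map(50, 512)["dp"])
```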

vllm_ascend/patch/worker/patch_common/patch_distributed.py

Lines changed: 70 additions & 5 deletions
@@ -15,17 +15,82 @@
 # limitations under the License.
 #
 
-from typing import List, Optional
+from typing import List, Optional, Union
 
 import torch
 import vllm
-from vllm.distributed.parallel_state import GroupCoordinator
+from torch.distributed import Backend
+from vllm.distributed.parallel_state import (GroupCoordinator,
+                                             _get_unique_name, _register_group)
+
+from vllm_ascend.distributed.communicator import NPUCommunicator
+from vllm_ascend.utils import create_hccl_pg_options
 
 
 class GroupCoordinatorPatch(GroupCoordinator):
 
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
+    def __init__(
+        self,
+        group_ranks: list[list[int]],
+        local_rank: int,
+        torch_distributed_backend: Union[str, Backend],
+        use_device_communicator: bool,  # whether to use device communicator
+        use_message_queue_broadcaster: bool = False,
+        group_name: Optional[str] = None,
+    ):
+        group_name = group_name or "anonymous"
+        self.unique_name = _get_unique_name(group_name)
+        _register_group(self)
+
+        self.rank = torch.distributed.get_rank()
+        self.local_rank = local_rank
+
+        self_device_group = None
+        self_cpu_group = None
+        hccl_pg_options = create_hccl_pg_options(group_name)
+
+        for ranks in group_ranks:
+            device_group = torch.distributed.new_group(
+                ranks,
+                backend=torch_distributed_backend,
+                pg_options=hccl_pg_options)
+
+            # a group with `gloo` backend, to allow direct coordination between
+            # processes through the CPU.
+            cpu_group = torch.distributed.new_group(ranks, backend="gloo")
+            if self.rank in ranks:
+                self.ranks = ranks
+                self.world_size = len(ranks)
+                self.rank_in_group = ranks.index(self.rank)
+                self_device_group = device_group
+                self_cpu_group = cpu_group
+
+        assert self_cpu_group is not None
+        assert self_device_group is not None
+
+        self.cpu_group = self_cpu_group
+        self.device_group = self_device_group
+        self.device = torch.npu.current_device()
+
+        self.use_device_communicator = use_device_communicator
+        self.device_communicator = None
+        if use_device_communicator and self.world_size > 1:
+            self.device_communicator = NPUCommunicator(
+                cpu_group=self.cpu_group,
+                device=self.device,
+                device_group=self.device_group,
+                unique_name=self.unique_name,
+            )
+
+        from vllm.distributed.device_communicators.shm_broadcast import (
+            MessageQueue)
+        self.mq_broadcaster: Optional[MessageQueue] = None
+        if use_message_queue_broadcaster and self.world_size > 1:
+            self.mq_broadcaster = MessageQueue.create_from_process_group(
+                self.cpu_group, 1 << 22, 6)
+
+        self.use_custom_op_call = False
+        self.use_cpu_custom_send_recv = False
 
     def all_to_all(self,
                    input_: torch.Tensor,
@@ -46,4 +111,4 @@ def all_to_all(self,
                                        gather_sizes)
 
 
-vllm.distributed.parallel_state.GroupCoordinator = GroupCoordinatorPatch  # Note: check the GroupCoordinator with online serving
+vllm.distributed.parallel_state.GroupCoordinator = GroupCoordinatorPatch
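For context, a minimal usage sketch of the patched coordinator follows, assuming torch.distributed has already been initialized with the HCCL backend on Ascend NPUs. The ranks, world size, and the explicit patch import are illustrative; in real deployments vllm drives this through its own distributed initialization, not by hand.

```python
# Minimal usage sketch (assumptions: 4 NPU ranks, torch.distributed already
# initialized with the "hccl" backend; group ranks below are illustrative).
import vllm_ascend.patch.worker.patch_common.patch_distributed  # noqa: F401  (applies the patch)
from vllm.distributed.parallel_state import GroupCoordinator

# After the patch module runs, GroupCoordinator is GroupCoordinatorPatch, so
# building the "dp" coordinator routes through create_hccl_pg_options("dp")
# and the underlying HCCL process group gets the enlarged per-group buffer.
dp_coordinator = GroupCoordinator(
    group_ranks=[[0, 1, 2, 3]],
    local_rank=0,
    torch_distributed_backend="hccl",
    use_device_communicator=True,
    group_name="dp",
)
```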

vllm_ascend/utils.py

Lines changed: 84 additions & 1 deletion
@@ -24,7 +24,7 @@
 from contextlib import contextmanager, nullcontext
 from enum import Enum
 from threading import Lock
-from typing import TYPE_CHECKING, Any, List, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
 
 import torch
 import torch_npu  # noqa: F401
@@ -60,6 +60,8 @@
 _SLEEP_MODE_ENABLED = None
 _CURRENT_STREAM = None
 _ASCEND_CUSTOMOP_IS_REIGISTERED = False
+_MIN_DP_BUFFER_SIZE = 50
+_MIN_MC2_BUFFER_SIZE = 512
 
 
 def is_310p():
@@ -634,3 +636,84 @@ def npu_stream_switch(target_stream: torch.npu.Stream,
         return nullcontext()
     assert target_stream is not None
     return torch.npu.stream(target_stream)
+
+
+def create_hccl_pg_options(group_name: str):
+    options = torch_npu._C._distributed_c10d.ProcessGroupHCCL.Options()
+    hccl_config = get_hccl_config_for_pg_options(group_name)
+    if hccl_config is not None:
+        options.hccl_config = hccl_config
+    return options
+
+
+def get_hccl_config_for_pg_options(group_name: str) -> Optional[Dict]:
+    hccl_config_map = {
+        "dp": {
+            "hccl_buffer_size": calculate_dp_buffer_size()
+        },
+        "mc2": {
+            "hccl_buffer_size": calculate_mc2_buffer_size()
+        },
+    }
+    return hccl_config_map.get(group_name)
+
+
+def calculate_dp_buffer_size() -> int:
+    """
+    Formula of the dp buffer size:
+        dp_size + 2 (flags: with_prefill and enable_dbo)
+    """
+    from vllm.config import get_current_vllm_config
+    vllm_config = get_current_vllm_config()
+    dp_size = vllm_config.parallel_config.data_parallel_size
+    int32_size = torch.iinfo(torch.int32).bits // 8
+    dp_buffer_size = math.ceil((dp_size + 2) * int32_size / (1024 * 1024))
+    return max(dp_buffer_size, _MIN_DP_BUFFER_SIZE)
+
+
+def calculate_mc2_buffer_size() -> int:
+    """
+    Formula of the mc2 buffer size:
+        2 * (local_routed_expert_num * max_bs_per_rank * ep_world_size * align512(align32(2 * H) + 64) +
+             (K + shared_expert_num) * max_bs_per_rank * align512(2 * H))
+    """
+    from vllm.config import get_current_vllm_config
+    vllm_config = get_current_vllm_config()
+    dp_size = vllm_config.parallel_config.data_parallel_size
+    tp_size = vllm_config.parallel_config.tensor_parallel_size
+    # since the mc2 process group is only used with enable_expert_parallel,
+    # it is unnecessary to consider the non-ep case.
+    ep_world_size = dp_size * tp_size
+    # FIXME: try to get local_routed_expert_num and redundant_expert_num (default: 2) elegantly
+    global_routed_expert_num = getattr(vllm_config.model_config.hf_config,
+                                       'n_routed_experts', 1)
+    local_routed_expert_num = math.ceil(
+        global_routed_expert_num / ep_world_size) + 2
+    # tokens are passed to shared experts without mc2 ops.
+    shared_expert_num = 0
+
+    max_bs_per_rank = math.ceil(vllm_config.scheduler_config.max_num_seqs /
+                                tp_size)
+    # take MTP into consideration; better to align with vllm's speculative method later.
+    mtp_spec_token_num = (
+        vllm_config.speculative_config.num_speculative_tokens + 1 if
+        (vllm_config.speculative_config is not None
+         and vllm_config.speculative_config.method == 'deepseek_mtp') else 1)
+    max_bs_per_rank *= mtp_spec_token_num
+    H = vllm_config.model_config.hf_config.hidden_size
+    K = getattr(vllm_config.model_config.hf_config, 'num_experts_per_tok', 1)
+
+    aligned_2H_32 = _round_up(2 * H, 32)
+    aligned_2H_512 = _round_up(2 * H, 512)
+
+    # local_routed_expert_num * max_bs_per_rank * ep_world_size * Align512(Align32(2 * H) + 64)
+    part1_base = aligned_2H_32 + 64
+    aligned_part1 = _round_up(part1_base, 512)
+    local_expert_component = local_routed_expert_num * max_bs_per_rank * ep_world_size * aligned_part1
+
+    # (K + shared_expert_num) * max_bs_per_rank * Align512(2 * H)
+    shared_expert_component = (
+        K + shared_expert_num) * max_bs_per_rank * aligned_2H_512
+    mc2_buffer_size = math.ceil(
+        2 * (local_expert_component + shared_expert_component) / (1024 * 1024))
+    return max(mc2_buffer_size, _MIN_MC2_BUFFER_SIZE)
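To make the two sizing formulas easier to follow, the sketch below re-evaluates them with hypothetical shapes (hidden size 7168, 256 routed experts, top-k 8, DP 4 x TP 8, max_num_seqs 256, no MTP). It is a standalone example, not part of the patch; `_round_up` is re-implemented locally on the assumption that it matches the helper already used in vllm_ascend/utils.py.

```python
# Worked example of the sizing formulas above with hypothetical values; the
# model/parallel shapes are illustrative, not taken from any real config.
import math


def _round_up(value: int, align: int) -> int:
    # assumed to match the _round_up helper in vllm_ascend/utils.py
    return (value + align - 1) // align * align


# DP buffer: (dp_size + 2) int32 flags, rounded up to whole MB, 50 MB floor.
dp_size = 32
dp_buffer_mb = max(math.ceil((dp_size + 2) * 4 / (1024 * 1024)), 50)
print(dp_buffer_mb)  # -> 50 (the flag payload is tiny, so the floor dominates)

# MC2 buffer: H = 7168, 256 routed experts, top-k K = 8, ep_world_size = 4 * 8,
# max_num_seqs = 256, no shared experts over mc2, no MTP.
H, K = 7168, 8
dp, tp = 4, 8
ep_world_size = dp * tp
local_experts = math.ceil(256 / ep_world_size) + 2  # plus redundant experts
max_bs_per_rank = math.ceil(256 / tp)                # 32 tokens per rank
part1 = local_experts * max_bs_per_rank * ep_world_size * _round_up(
    _round_up(2 * H, 32) + 64, 512)
part2 = K * max_bs_per_rank * _round_up(2 * H, 512)
mc2_buffer_mb = max(math.ceil(2 * (part1 + part2) / (1024 * 1024)), 512)
print(mc2_buffer_mb)  # -> 512 (formula gives 297 MB, clamped to the 512 MB floor)
```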
