
Commit 88fd7b8

[feat] support customized and separated hccl_buffer_size for process group initialization
Signed-off-by: linfeng-yuan <1102311262@qq.com>
1 parent 12bcbd0 commit 88fd7b8

File tree

3 files changed: +161 -5 lines changed

vllm_ascend/patch/__init__.py
vllm_ascend/patch/worker/patch_common/patch_distributed.py
vllm_ascend/utils.py


vllm_ascend/patch/__init__.py

Lines changed: 12 additions & 0 deletions
@@ -75,6 +75,18 @@
 # ** File: worker/patch_common/patch_distributed.py **
 # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 # 1. `vllm.distributed.parallel_state.GroupCoordinator`
+#    (1) __init__()
+#    Why:
+#    The original GroupCoordinator initialization lacks pg_options to generate a new
+#    process group with customized options.
+#    How:
+#    Inject HCCL options during process group initialization.
+#    Related PR (if no, explain why):
+#    Need a PR to vllm to support a dictionary as input while initializing the distributed
+#    environment (e.g., Dict[str, torch.distributed.ProcessGroupHCCL.Options]).
+#    Future Plan:
+#    Remove this patch when vllm merges this PR.
+#    (2) all_to_all()
 #    Why:
 #    vllm doesn't support all_to_all for GroupCoordinator.
 #    How:
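For illustration, here is a minimal sketch of the interface the "Related PR" note above asks for: group creation that accepts a mapping from group name to backend options, so each HCCL group can carry its own hccl_buffer_size without patching GroupCoordinator. This is not part of the commit and not an existing vllm API; `create_group` and `pg_options_map` are hypothetical names.

```python
# Hypothetical sketch only: not the vllm API and not part of this commit.
from typing import Any, Dict, Optional

import torch.distributed as dist


def create_group(ranks: list[int],
                 backend: str,
                 group_name: str,
                 pg_options_map: Optional[Dict[str, Any]] = None):
    """Create a process group, applying per-group backend options
    (e.g. ProcessGroupHCCL.Options with a custom hccl_buffer_size)
    when the map has an entry for this group name."""
    options = (pg_options_map or {}).get(group_name)
    return dist.new_group(ranks, backend=backend, pg_options=options)


# e.g. create_group([0, 1, 2, 3], "hccl", "mc2",
#                   {"mc2": create_hccl_pg_options("mc2")})
```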

vllm_ascend/patch/worker/patch_common/patch_distributed.py

Lines changed: 68 additions & 4 deletions
@@ -15,17 +15,81 @@
 # limitations under the License.
 #
 
-from typing import List, Optional
+from typing import List, Optional, Union
 
 import torch
 import vllm
-from vllm.distributed.parallel_state import GroupCoordinator
+from torch.distributed import Backend
+from vllm.distributed.parallel_state import (GroupCoordinator,
+                                             _get_unique_name, _register_group)
+
+from vllm_ascend.distributed.communicator import NPUCommunicator
 
 
 class GroupCoordinatorPatch(GroupCoordinator):
 
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
+    def __init__(
+        self,
+        group_ranks: list[list[int]],
+        local_rank: int,
+        torch_distributed_backend: Union[str, Backend],
+        use_device_communicator: bool,  # whether to use device communicator
+        use_message_queue_broadcaster: bool = False,
+        group_name: Optional[str] = None,
+    ):
+        group_name = group_name or "anonymous"
+        self.unique_name = _get_unique_name(group_name)
+        _register_group(self)
+
+        self.rank = torch.distributed.get_rank()
+        self.local_rank = local_rank
+
+        self_device_group = None
+        self_cpu_group = None
+        hccl_pg_options = create_hccl_pg_options(group_name)
+
+        for ranks in group_ranks:
+            device_group = torch.distributed.new_group(
+                ranks,
+                backend=torch_distributed_backend,
+                pg_options=hccl_pg_options)
+
+            # a group with `gloo` backend, to allow direct coordination between
+            # processes through the CPU.
+            cpu_group = torch.distributed.new_group(ranks, backend="gloo")
+            if self.rank in ranks:
+                self.ranks = ranks
+                self.world_size = len(ranks)
+                self.rank_in_group = ranks.index(self.rank)
+                self_device_group = device_group
+                self_cpu_group = cpu_group
+
+        assert self_cpu_group is not None
+        assert self_device_group is not None
+
+        self.cpu_group = self_cpu_group
+        self.device_group = self_device_group
+        self.device = torch.npu.current_device()
+
+        self.use_device_communicator = use_device_communicator
+        self.device_communicator = None
+        if use_device_communicator and self.world_size > 1:
+            self.device_communicator = NPUCommunicator(
+                cpu_group=self.cpu_group,
+                device=self.device,
+                device_group=self.device_group,
+                unique_name=self.unique_name,
+            )
+
+        from vllm.distributed.device_communicators.shm_broadcast import \
+            MessageQueue
+        self.mq_broadcaster: Optional[MessageQueue] = None
+        if use_message_queue_broadcaster and self.world_size > 1:
+            self.mq_broadcaster = MessageQueue.create_from_process_group(
+                self.cpu_group, 1 << 22, 6)
+
+        self.use_custom_op_call = False
+        self.use_cpu_custom_send_recv = False
 
     def all_to_all(self,
                    input_: torch.Tensor,
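The rewritten constructor follows vllm's upstream GroupCoordinator initialization step by step; the functional change is the pg_options=hccl_pg_options argument to torch.distributed.new_group, which applies a per-group HCCL configuration keyed by group_name. A minimal usage sketch follows; it assumes vllm-ascend's usual patch style of replacing the upstream class (not shown in this diff) and an already-initialized torch.distributed environment with four ranks, so it is illustrative rather than exact.

```python
# Sketch only: assumes torch.distributed is already initialized ("hccl" backend,
# world size 4) and that vllm-ascend swaps in the patched class as shown below.
import vllm
from vllm_ascend.patch.worker.patch_common.patch_distributed import (
    GroupCoordinatorPatch)

# Assumed patch wiring: replace the upstream symbol with the patched class.
vllm.distributed.parallel_state.GroupCoordinator = GroupCoordinatorPatch

# A coordinator named "dp" picks up create_hccl_pg_options("dp"), so its HCCL
# process group is created with the DP-specific hccl_buffer_size; any other
# group name falls back to default HCCL options.
dp_group = GroupCoordinatorPatch(
    group_ranks=[[0, 1, 2, 3]],        # one data-parallel group over ranks 0-3
    local_rank=0,
    torch_distributed_backend="hccl",
    use_device_communicator=True,
    group_name="dp",
)
```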

vllm_ascend/utils.py

Lines changed: 81 additions & 1 deletion
@@ -24,7 +24,7 @@
 from contextlib import contextmanager, nullcontext
 from enum import Enum
 from threading import Lock
-from typing import TYPE_CHECKING, Any, List, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
 
 import torch
 import torch_npu  # noqa: F401
@@ -60,6 +60,8 @@
 _SLEEP_MODE_ENABLED = None
 _CURRENT_STREAM = None
 _ASCEND_CUSTOMOP_IS_REIGISTERED = False
+_MIN_DP_BUFFER_SIZE = 50
+_MIN_MC2_BUFFER_SIZE = 512
 
 
 def is_310p():
@@ -634,3 +636,81 @@ def npu_stream_switch(target_stream: torch.npu.Stream,
         return nullcontext()
     assert target_stream is not None
     return torch.npu.stream(target_stream)
+
+def create_hccl_pg_options(group_name: str):
+    options = torch_npu._C._distributed_c10d.ProcessGroupHCCL.Options()
+    hccl_config = get_hccl_config_for_pg_options(group_name)
+    if hccl_config is not None:
+        options.hccl_config = hccl_config
+    return options
+
+def get_hccl_config_for_pg_options(group_name: str) -> Optional[Dict]:
+    hccl_config_map = {
+        "dp": {
+            "hccl_buffer_size": calculate_dp_buffer_size()
+        },
+        "mc2": {
+            "hccl_buffer_size": calculate_mc2_buffer_size()
+        },
+    }
+    return hccl_config_map.get(group_name)
+
+
+def calculate_dp_buffer_size() -> int:
+    """
+    formula of dp buffer size:
+    dp_size + 2 (flags: with_prefill and enable_dbo)
+    """
+    from vllm.config import get_current_vllm_config
+    vllm_config = get_current_vllm_config()
+    dp_size = vllm_config.parallel_config.data_parallel_size
+    int32_size = torch.iinfo(torch.int32).bits // 8
+    dp_buffer_size = math.ceil((dp_size + 2) * int32_size / (1024 * 1024))
+    return max(dp_buffer_size, _MIN_DP_BUFFER_SIZE)
+
+
+def calculate_mc2_buffer_size() -> int:
+    """
+    formula of mc2 buffer size:
+    2 * (local_routed_expert_num * max_bs_per_rank * ep_world_size * align512(align32(2 * H) + 64) +
+         (K + shared_expert_num) * max_bs_per_rank * align512(2 * H))
+    """
+    from vllm.config import get_current_vllm_config
+    vllm_config = get_current_vllm_config()
+    dp_size = vllm_config.parallel_config.data_parallel_size
+    tp_size = vllm_config.parallel_config.tensor_parallel_size
+    # since the mc2 process group is only used with enable_expert_parallel,
+    # it is unnecessary to consider the non-ep case.
+    ep_world_size = dp_size * tp_size
+    # FIXME: try to get local_routed_expert_num and redundant_expert_num (default: 2) elegantly
+    global_routed_expert_num = getattr(vllm_config.model_config.hf_config,
+                                       'n_routed_experts', 1)
+    local_routed_expert_num = math.ceil(
+        global_routed_expert_num / ep_world_size) + 2
+    # tokens are passed to shared experts without mc2 ops.
+    shared_expert_num = 0
+
+    max_bs_per_rank = math.ceil(vllm_config.scheduler_config.max_num_seqs /
+                                tp_size)
+    # take MTP into consideration; it is better to align with vllm's speculative method later.
+    mtp_spec_token_num = (
+        vllm_config.speculative_config.num_speculative_tokens + 1 if
+        (vllm_config.speculative_config is not None
+         and vllm_config.speculative_config.method == 'deepseek_mtp') else 1)
+    max_bs_per_rank *= mtp_spec_token_num
+    H = vllm_config.model_config.hf_config.hidden_size
+    K = getattr(vllm_config.model_config.hf_config, 'num_experts_per_tok', 1)
+
+    aligned_2H_32 = round_up(2 * H, 32)
+    aligned_2H_512 = round_up(2 * H, 512)
+
+    # local_routed_expert_num * max_bs_per_rank * ep_world_size * Align512(Align32(2 * H) + 64)
+    part1_base = aligned_2H_32 + 64
+    aligned_part1 = round_up(part1_base, 512)
+    local_expert_component = local_routed_expert_num * max_bs_per_rank * ep_world_size * aligned_part1
+
+    # (K + shared_expert_num) * max_bs_per_rank * Align512(2 * H)
+    shared_expert_component = (
+        K + shared_expert_num) * max_bs_per_rank * aligned_2H_512
+    mc2_buffer_size = math.ceil(
+        2 * (local_expert_component + shared_expert_component) / (1024 * 1024))
+    return max(mc2_buffer_size, _MIN_MC2_BUFFER_SIZE)
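To make the two buffer-size formulas concrete, here is a standalone worked example with assumed, illustrative values (dp_size=32, tp_size=8, 256 routed experts, hidden_size=7168, top-k 8, max_num_seqs=256, MTP with one speculative token). None of these numbers come from the commit; round_up simply mirrors the alignment helper used in utils.py.

```python
# Standalone worked example of the two formulas; all configuration values
# below are assumptions for illustration, not taken from the commit.
import math


def round_up(x: int, align: int) -> int:
    return math.ceil(x / align) * align


# DP buffer: (dp_size + 2) int32 flags, rounded up to MB, floored at 50 MB.
dp_size = 32
dp_buffer = max(math.ceil((dp_size + 2) * 4 / (1024 * 1024)), 50)
print(dp_buffer)  # 50: the 50 MB floor dominates for any realistic dp_size

# MC2 buffer for an expert-parallel group spanning all dp_size * tp_size ranks.
tp_size, hidden_size, top_k = 8, 7168, 8
ep_world_size = dp_size * tp_size                    # 256 ranks
local_experts = math.ceil(256 / ep_world_size) + 2   # 1 local + 2 redundant
max_bs_per_rank = math.ceil(256 / tp_size) * 2       # 32 seqs * (1 + 1 MTP token)

part1 = local_experts * max_bs_per_rank * ep_world_size * round_up(
    round_up(2 * hidden_size, 32) + 64, 512)
part2 = (top_k + 0) * max_bs_per_rank * round_up(2 * hidden_size, 512)
mc2_buffer = max(math.ceil(2 * (part1 + part2) / (1024 * 1024)), 512)
print(mc2_buffer)  # 1406 MB for this configuration
```

For small deployments the 50 MB and 512 MB floors dominate, while a large expert-parallel world pushes the mc2 group's buffer well past the 512 MB floor, which is the kind of group this per-group sizing is aimed at.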

0 commit comments
