Commit f0a39a2

[Refactor] Refactor stateless_init_torch_distributed_process_group to platform.py
Signed-off-by: Yizhou Liu <liu_yizhou@outlook.com>
Parent: 61430b2

3 files changed: +161, -110 lines
Lines changed: 115 additions & 0 deletions
@@ -0,0 +1,115 @@
import torch
from torch.distributed import ProcessGroup
from torch.distributed.distributed_c10d import (Backend, PrefixStore,
                                                _get_default_timeout,
                                                is_nccl_available)
from torch.distributed.rendezvous import rendezvous
from vllm.distributed import utils


def stateless_init_torch_distributed_process_group(
        host: str, port: int, rank: int, world_size: int,
        backend: str) -> ProcessGroup:
    """
    A replacement for `torch.distributed.init_process_group` that does not
    pollute the global state. The created ProcessGroup object can be used for
    some operations such as `allreduce`, because it does not depend on the
    global rank. However, some operations such as `broadcast` cannot be used,
    because they depend on the global rank.

    # TODO: ask for help from PyTorch team if we need the `broadcast` operation.

    This function is useful when we are not sure about the total number of
    processes in the process group. For example, we may have processes
    1, 2, ..., 8 that want to communicate, and process 9 might be the same
    process as process 1, or it might be a different process; process 10
    might be the same process as process 5, or it might be a different process.
    In this case, how can we reliably form a communication channel between
    processes 9 and 10, without affecting the communication channel among
    processes 1, 2, ..., 8?

    One possible solution is to figure out beforehand whether processes 9 and
    10 are the same as processes 1 and 5, and then form a communication channel
    based on that information, adjusting the ranks and world_size etc. However,
    figuring out that information is not always easy, and it would interfere
    with the main communication channel.

    Our solution is to always form a communication channel with processes 1, 2,
    ..., 8, and then use this function to form another communication channel
    with processes 9 and 10. This way, regardless of whether processes 9 and 10
    are the same as processes 1 and 5, the main communication channel is
    always formed with processes 1, 2, ..., 8, and the additional communication
    channel is formed with processes 9 and 10.
    """
    init_method = f"tcp://{host}:{port}"
    backend = Backend(backend)  # it is basically a string
    timeout = _get_default_timeout(backend)

    store, rank, world_size = next(
        rendezvous(init_method, rank, world_size, timeout=timeout))
    store.set_timeout(timeout)

    group_rank = rank
    group_size = world_size

    # Use a PrefixStore to avoid accidental overrides of keys used by
    # different systems (e.g. RPC) in case the store is multi-tenant.
    prefix_store = PrefixStore(init_method, store)

    # TODO(Yizhou): The reason we need to set options while vllm does not
    # seems to be related to the version of PyTorch. In the latest version,
    # there is no need to set options. While in the older version, 2.5.1
    # specifically, we need to set options.
    options = ProcessGroup.Options(backend=backend)
    pg: ProcessGroup = ProcessGroup(
        prefix_store,
        group_rank,
        group_size,
        options,
    )
    if backend == "gloo":
        from torch.distributed.distributed_c10d import ProcessGroupGloo
        backend_class = ProcessGroupGloo(prefix_store,
                                         group_rank,
                                         group_size,
                                         timeout=timeout)
        backend_type = ProcessGroup.BackendType.GLOO
        device = torch.device("cpu")
    elif backend == "nccl":
        assert is_nccl_available()
        from torch.distributed.distributed_c10d import ProcessGroupNCCL

        backend_options = ProcessGroupNCCL.Options()
        backend_options._timeout = timeout

        backend_class = ProcessGroupNCCL(prefix_store, group_rank, group_size,
                                         backend_options)
        backend_type = ProcessGroup.BackendType.NCCL
        device = torch.device("cuda")
    elif backend == "hccl":
        from torch.distributed import is_hccl_available
        assert is_hccl_available()
        from torch_npu._C._distributed_c10d import ProcessGroupHCCL
        backend_options = ProcessGroupHCCL.Options()
        backend_options._timeout = timeout
        backend_class = ProcessGroupHCCL(prefix_store, group_rank, group_size,
                                         backend_options)
        device = torch.device("npu")
        backend_class._set_sequence_number_for_group()
        backend_type = ProcessGroup.BackendType.CUSTOM
        pg._register_backend(device, backend_type, backend_class)
        return pg
    else:
        raise RuntimeError(f"Unsupported torch distributed backend: {backend}")

    # TODO(Yizhou): Like we mentioned above, _set_default_backend is not
    # implemented in the 2.5.1 version of PyTorch. But we need to set it
    # after the latest version is released.
    # pg._set_default_backend(backend_type)
    backend_class._set_sequence_number_for_group()

    pg._register_backend(device, backend_type, backend_class)

    return pg


utils.stateless_init_torch_distributed_process_group = stateless_init_torch_distributed_process_group
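For context, a minimal sketch of how a caller might use the helper above, assuming two worker processes that have agreed on a rendezvous host and port out of band; the concrete values below are illustrative and not part of this commit:

# Hypothetical usage (not part of the diff): run with rank=0 in one process
# and rank=1 in the peer process.
pg = stateless_init_torch_distributed_process_group(
    host="127.0.0.1",   # rendezvous host agreed on by both processes
    port=29600,         # any free TCP port agreed on by both processes
    rank=0,             # the peer process passes rank=1
    world_size=2,
    backend="hccl",     # "gloo" and "nccl" take the other branches above
)
# The returned ProcessGroup backs rank-independent collectives such as
# allreduce without touching the default group set up by init_process_group.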

vllm_ascend/patch/platform/patch_common/patch_distributed.py

Lines changed: 1 addition & 110 deletions
@@ -22,12 +22,9 @@
 import vllm.distributed
 import vllm.envs as envs
 from torch.distributed import ProcessGroup
-from torch.distributed.distributed_c10d import (Backend, PrefixStore,
-                                                _get_default_timeout,
-                                                is_nccl_available)
-from torch.distributed.rendezvous import rendezvous
 from vllm.config import ParallelConfig, VllmConfig
 from vllm.v1.engine.core import DPEngineCoreProc
+from vllm.utils import stateless_init_torch_distributed_process_group


 def ascend_destroy_model_parallel():
@@ -49,112 +46,6 @@ def ascend_destroy_model_parallel():
     destory_ascend_model_parallel()


-def stateless_init_torch_distributed_process_group(
-        host: str, port: int, rank: int, world_size: int,
-        backend: str) -> ProcessGroup:
-    ...  # (the remaining ~100 removed lines are identical to the function added in the new file above)


 def parallel_config_get_dp_port(self) -> int:
     """
     We might need to initialize process groups in multiple

vllm_ascend/platform.py

Lines changed: 45 additions & 0 deletions
@@ -17,9 +17,12 @@

 import logging
 import os
+from datetime import timedelta
 from typing import TYPE_CHECKING, Optional, Tuple

 import torch
+from torch.distributed import ProcessGroup
+from torch.distributed.distributed_c10d import PrefixStore
 import vllm.envs as envs
 from vllm.logger import logger
 from vllm.platforms import Platform, PlatformEnum
@@ -249,3 +252,45 @@ def get_piecewise_backend_cls(cls) -> str:
         Get piecewise backend class for piecewise graph.
         """
         return "vllm_ascend.compilation.piecewise_backend.NPUPiecewiseBackend"  # noqa
+
+    @classmethod
+    def stateless_init_device_torch_dist_pg(
+        cls,
+        backend: str,
+        prefix_store: PrefixStore,
+        group_rank: int,
+        group_size: int,
+        timeout: timedelta,
+    ) -> ProcessGroup:
+        from torch.distributed import is_hccl_available
+        from torch_npu._C._distributed_c10d import ProcessGroupHCCL
+
+        assert is_hccl_available()
+
+        # TODO(Yizhou): The reason we need to set options while vllm does not
+        # seems to be related to the version of PyTorch. In the latest version,
+        # there is no need to set options. While in the older version, 2.5.1
+        # specifically, we need to set options.
+        options = ProcessGroup.Options(backend=backend)
+        pg: ProcessGroup = ProcessGroup(
+            prefix_store,
+            group_rank,
+            group_size,
+            options,
+        )
+
+        backend_options = ProcessGroupHCCL.Options()
+        backend_options._timeout = timeout
+
+        backend_class = ProcessGroupHCCL(prefix_store, group_rank, group_size,
+                                         backend_options)
+        device = torch.device("npu")
+        # TODO(Yizhou): Like we mentioned above, _set_default_backend is not
+        # implemented in the 2.5.1 version of PyTorch. But we need to set it
+        # after the latest version is released.
+        # pg._set_default_backend(backend_type)
+        backend_class._set_sequence_number_for_group()
+        backend_type = ProcessGroup.BackendType.CUSTOM
+
+        pg._register_backend(device, backend_type, backend_class)
+        return pg
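The signature of this new classmethod mirrors the per-platform hook that vLLM's generic `stateless_init_torch_distributed_process_group` (imported from `vllm.utils` in the patch above) is expected to call after it has performed the TCP rendezvous itself. A rough sketch of how the hook could be exercised directly follows; the `NPUPlatform` class name, the timeout value, and the host/port handling are assumptions for illustration, not code from this commit.

# Illustrative only: driving the platform hook with a manually built store.
from datetime import timedelta
from torch.distributed import PrefixStore, TCPStore
from vllm_ascend.platform import NPUPlatform  # assumed class name

def build_npu_side_channel(host: str, port: int, rank: int, world_size: int):
    timeout = timedelta(minutes=10)  # arbitrary illustrative timeout
    store = TCPStore(host, port, world_size, is_master=(rank == 0),
                     timeout=timeout)
    prefix_store = PrefixStore(f"tcp://{host}:{port}", store)
    return NPUPlatform.stateless_init_device_torch_dist_pg(
        backend="hccl",
        prefix_store=prefix_store,
        group_rank=rank,
        group_size=world_size,
        timeout=timeout,
    )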
