     init_distributed_environment)
 from vllm.utils import update_environment_variables
 
+from tests.e2e.conftest import cleanup_dist_env_and_memory
 from vllm_ascend.distributed.device_communicators.pyhccl import \
     PyHcclCommunicator
 
+os.environ["TOKENIZERS_PARALLELISM"] = "true"
+
+multiprocessing.set_start_method("spawn", force=True)
+
+
+def _worker_entry(env, fn):
+    # `multiprocessing.Process` cannot accept environment variables directly
+    # so we need to pass the environment variables as arguments
+    # and update the environment variables in the function
+    update_environment_variables(env)
+
+    rank = int(os.environ['RANK'])
+    local_rank = int(os.environ['LOCAL_RANK'])
+    word_size = int(os.environ['WORLD_SIZE'])
+
+    distributed_init_method = "tcp://localhost:12345"
+
+    device = torch.device(f"npu:{local_rank}")
+    torch.npu.set_device(device)
+
+    init_distributed_environment(
+        world_size=word_size,
+        rank=rank,
+        distributed_init_method=distributed_init_method,
+        local_rank=local_rank,
+        backend="hccl")
+    fn()
+    cleanup_dist_env_and_memory()
+
 
 def distributed_run(fn, world_size):
     number_of_processes = world_size
@@ -37,9 +67,7 @@ def distributed_run(fn, world_size):
         env['LOCAL_RANK'] = str(i)
         env['WORLD_SIZE'] = str(number_of_processes)
         env['LOCAL_WORLD_SIZE'] = str(number_of_processes)
-        env['MASTER_ADDR'] = 'localhost'
-        env['MASTER_PORT'] = '12345'
-        p = multiprocessing.Process(target=fn, args=(env, ))
+        p = multiprocessing.Process(target=_worker_entry, args=(env, fn))
         processes.append(p)
         p.start()
 
@@ -50,22 +78,6 @@ def distributed_run(fn, world_size):
         assert p.exitcode == 0
 
 
-def worker_fn_wrapper(fn):
-    # `multiprocessing.Process` cannot accept environment variables directly
-    # so we need to pass the environment variables as arguments
-    # and update the environment variables in the function
-    def wrapped_fn(env):
-        update_environment_variables(env)
-        local_rank = os.environ['LOCAL_RANK']
-        device = torch.device(f"npu:{local_rank}")
-        torch.npu.set_device(device)
-        init_distributed_environment(backend="hccl")
-        fn()
-
-    return wrapped_fn
-
-
-@worker_fn_wrapper
 def worker_fn():
     pynccl_comm = PyHcclCommunicator(get_world_group().cpu_group,
                                      device=get_world_group().device)
@@ -76,11 +88,10 @@ def worker_fn():
     assert torch.all(tensor == pynccl_comm.world_size).cpu().item()
 
 
-# def test_pyhccl():
-#     distributed_run(worker_fn, 2)
+def test_pyhccl():
+    distributed_run(worker_fn, 4)
 
 
-@worker_fn_wrapper
def broadcast_worker_fn():
     # Test broadcast for every root rank.
     # Essentially this is an all-gather operation.
@@ -106,5 +117,5 @@ def broadcast_worker_fn():
     assert torch.all(recv_tensors[i] == i).cpu().item()
 
 
-# def test_pyhccl_broadcast():
-#     distributed_run(broadcast_worker_fn, 4)
+def test_pyhccl_broadcast():
+    distributed_run(broadcast_worker_fn, 4)
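For readers without NPU hardware, here is a minimal, self-contained sketch of the two ideas the new test code relies on: spawn-per-rank worker processes that receive their environment as a plain argument (since `multiprocessing.Process` has no environment parameter), and the broadcast-from-every-root loop that behaves like an all-gather. It substitutes the standard `torch.distributed` gloo backend for HCCL; the helper names (`_entry`, `_run`, `_broadcast_every_root`) and the CPU backend are illustrative and not part of this test file.

```python
# Sketch only: gloo stands in for HCCL, helper names are illustrative.
import multiprocessing
import os

import torch
import torch.distributed as dist

multiprocessing.set_start_method("spawn", force=True)


def _entry(env, fn):
    # Process() cannot accept environment variables directly, so the per-rank
    # values travel as a plain dict and are applied inside the child.
    os.environ.update(env)
    dist.init_process_group(
        backend="gloo",
        # Mirrors the test's distributed_init_method; change if the port is busy.
        init_method="tcp://localhost:12345",
        rank=int(env["RANK"]),
        world_size=int(env["WORLD_SIZE"]),
    )
    fn()
    dist.destroy_process_group()


def _run(fn, world_size):
    procs = []
    for i in range(world_size):
        env = {"RANK": str(i), "LOCAL_RANK": str(i), "WORLD_SIZE": str(world_size)}
        p = multiprocessing.Process(target=_entry, args=(env, fn))
        p.start()
        procs.append(p)
    for p in procs:
        p.join()
        assert p.exitcode == 0


def _broadcast_every_root():
    # Broadcasting once per root rank spreads every rank's tensor everywhere,
    # i.e. the loop behaves like an all-gather.
    world_size = dist.get_world_size()
    rank = dist.get_rank()
    tensors = [torch.empty(16) for _ in range(world_size)]
    tensors[rank] = torch.ones(16) * rank
    for root in range(world_size):
        dist.broadcast(tensors[root], src=root)
    for i in range(world_size):
        assert torch.all(tensors[i] == i).item()


if __name__ == "__main__":
    _run(_broadcast_every_root, 2)
```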