Commit dd00b8f

leo-pony committed:

1. Fixed the issue that the pyhccl e2e test cannot run continuously with other tests.
2. Cleaned up the resources occupied by the dynamic_npugraph_batchsize e2e test.

Signed-off-by: leo-pony <nengjunma@outlook.com>

1 parent: 9a3bdf2

File tree: 2 files changed, +54 −41 lines


tests/e2e/multicard/test_dynamic_npugraph_batchsize.py

Lines changed: 19 additions & 17 deletions
@@ -16,7 +16,9 @@
 #
 import pytest
 import torch
-from vllm import LLM, SamplingParams
+from vllm import SamplingParams
+
+from tests.e2e.conftest import VllmRunner
 
 MODELS = [
     "Qwen/Qwen2.5-0.5B-Instruct",
@@ -38,20 +40,20 @@
 def test_models(model: str, tp_size: int, max_tokens: int, temperature: int,
                 ignore_eos: bool) -> None:
     # Create an LLM.
-    llm = LLM(
-        model=model,
-        tensor_parallel_size=tp_size,
-    )
-    # Prepare sampling_parames
-    sampling_params = SamplingParams(
-        max_tokens=max_tokens,
-        temperature=temperature,
-        ignore_eos=ignore_eos,
-    )
+    with VllmRunner(
+            model_name=model,
+            tensor_parallel_size=tp_size,
+    ) as vllm_model:
+        # Prepare sampling_parames
+        sampling_params = SamplingParams(
+            max_tokens=max_tokens,
+            temperature=temperature,
+            ignore_eos=ignore_eos,
+        )
 
-    # Generate texts from the prompts.
-    # The output is a list of RequestOutput objects
-    outputs = llm.generate(prompts, sampling_params)
-    torch.npu.synchronize()
-    # The output length should be equal to prompts length.
-    assert len(outputs) == len(prompts)
+        # Generate texts from the prompts.
+        # The output is a list of RequestOutput objects
+        outputs = vllm_model.generate(prompts, sampling_params)
+        torch.npu.synchronize()
+        # The output length should be equal to prompts length.
+        assert len(outputs) == len(prompts)
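The change above swaps the bare LLM(...) construction for the VllmRunner context manager from tests/e2e/conftest, so the engine created for the test is torn down when the with block exits, even if an assertion fails, and the next test starts with the NPU memory released. Below is a minimal, hypothetical sketch of why a context manager guarantees that cleanup; it is not the real VllmRunner implementation, and the _ToyRunner name, its placeholder engine, and the cleanup steps are assumptions made only for illustration.

import contextlib
import gc


class _ToyRunner(contextlib.AbstractContextManager):
    """Hypothetical stand-in for a runner that owns an engine-like resource."""

    def __init__(self, model_name: str, tensor_parallel_size: int = 1):
        self.model_name = model_name
        self.tensor_parallel_size = tensor_parallel_size
        self.engine = object()  # placeholder for the real inference engine

    def generate(self, prompts, sampling_params=None):
        # Placeholder: a real runner would forward to the engine here.
        return [f"{self.model_name}: {p}" for p in prompts]

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Runs on normal exit *and* when the test body raises, so the next
        # test never starts with the previous engine still resident.
        del self.engine
        gc.collect()
        return False  # never swallow the test's exception


with _ToyRunner(model_name="Qwen/Qwen2.5-0.5B-Instruct") as runner:
    outputs = runner.generate(["Hello, my name is"])
    assert len(outputs) == 1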

tests/e2e/multicard/test_pyhccl_distributed.py

Lines changed: 35 additions & 24 deletions
@@ -24,9 +24,39 @@
                                              init_distributed_environment)
 from vllm.utils import update_environment_variables
 
+from tests.e2e.conftest import cleanup_dist_env_and_memory
 from vllm_ascend.distributed.device_communicators.pyhccl import \
     PyHcclCommunicator
 
+os.environ["TOKENIZERS_PARALLELISM"] = "true"
+
+multiprocessing.set_start_method("spawn", force=True)
+
+
+def _worker_entry(env, fn):
+    # `multiprocessing.Process` cannot accept environment variables directly
+    # so we need to pass the environment variables as arguments
+    # and update the environment variables in the function
+    update_environment_variables(env)
+
+    rank = int(os.environ['RANK'])
+    local_rank = int(os.environ['LOCAL_RANK'])
+    word_size = int(os.environ['WORLD_SIZE'])
+
+    distributed_init_method = "tcp://localhost:12345"
+
+    device = torch.device(f"npu:{local_rank}")
+    torch.npu.set_device(device)
+
+    init_distributed_environment(
+        world_size=word_size,
+        rank=rank,
+        distributed_init_method=distributed_init_method,
+        local_rank=local_rank,
+        backend="hccl")
+    fn()
+    cleanup_dist_env_and_memory()
+
 
 def distributed_run(fn, world_size):
     number_of_processes = world_size
3767
env['LOCAL_RANK'] = str(i)
3868
env['WORLD_SIZE'] = str(number_of_processes)
3969
env['LOCAL_WORLD_SIZE'] = str(number_of_processes)
40-
env['MASTER_ADDR'] = 'localhost'
41-
env['MASTER_PORT'] = '12345'
42-
p = multiprocessing.Process(target=fn, args=(env, ))
70+
p = multiprocessing.Process(target=_worker_entry, args=(env, fn))
4371
processes.append(p)
4472
p.start()
4573
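This hunk, together with the new _worker_entry above, changes how the workers are launched: the start method is forced to "spawn", the per-worker setup moves from the removed worker_fn_wrapper decorator into the module-level _worker_entry(env, fn), and the rendezvous uses an explicit distributed_init_method instead of the MASTER_ADDR/MASTER_PORT environment variables. With "spawn", the Process target and its arguments are pickled into a fresh interpreter, so they must be module-level objects rather than closures, and the child no longer inherits whatever device or distributed state earlier tests left behind in the parent. The sketch below is a self-contained illustration of that pattern with no NPU or HCCL involved; the _entry and _hello names are invented for the example.

import multiprocessing
import os


def _entry(env, fn):
    # Environment variables cannot be handed to multiprocessing.Process
    # directly, so they travel as a plain dict and are applied in the child.
    os.environ.update(env)
    fn()


def _hello():
    print(f"rank {os.environ['RANK']} is ready")


if __name__ == "__main__":
    # "spawn" starts each worker in a fresh interpreter; the target and its
    # arguments are pickled, so both functions must live at module level.
    multiprocessing.set_start_method("spawn", force=True)
    procs = []
    for i in range(2):
        p = multiprocessing.Process(target=_entry,
                                    args=({"RANK": str(i)}, _hello))
        procs.append(p)
        p.start()
    for p in procs:
        p.join()
        assert p.exitcode == 0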

@@ -50,22 +78,6 @@ def distributed_run(fn, world_size):
         assert p.exitcode == 0
 
 
-def worker_fn_wrapper(fn):
-    # `multiprocessing.Process` cannot accept environment variables directly
-    # so we need to pass the environment variables as arguments
-    # and update the environment variables in the function
-    def wrapped_fn(env):
-        update_environment_variables(env)
-        local_rank = os.environ['LOCAL_RANK']
-        device = torch.device(f"npu:{local_rank}")
-        torch.npu.set_device(device)
-        init_distributed_environment(backend="hccl")
-        fn()
-
-    return wrapped_fn
-
-
-@worker_fn_wrapper
 def worker_fn():
     pynccl_comm = PyHcclCommunicator(get_world_group().cpu_group,
                                      device=get_world_group().device)
@@ -76,11 +88,10 @@ def worker_fn():
     assert torch.all(tensor == pynccl_comm.world_size).cpu().item()
 
 
-# def test_pyhccl():
-#     distributed_run(worker_fn, 2)
+def test_pyhccl():
+    distributed_run(worker_fn, 4)
 
 
-@worker_fn_wrapper
 def broadcast_worker_fn():
     # Test broadcast for every root rank.
     # Essentially this is an all-gather operation.
@@ -106,5 +117,5 @@ def broadcast_worker_fn():
         assert torch.all(recv_tensors[i] == i).cpu().item()
 
 
-# def test_pyhccl_broadcast():
-#     distributed_run(broadcast_worker_fn, 4)
+def test_pyhccl_broadcast():
+    distributed_run(broadcast_worker_fn, 4)
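Re-enabling test_pyhccl and test_pyhccl_broadcast works now because every worker ends by calling cleanup_dist_env_and_memory() from tests/e2e/conftest, so the distributed state and device memory a worker sets up do not leak into whichever test runs next. The helper's real implementation lives in the conftest and may differ; the sketch below only illustrates the kind of teardown such a helper typically performs, and the _cleanup_sketch name and every step in it are assumptions made for illustration.

import gc

import torch
import torch.distributed as dist


def _cleanup_sketch():
    # Tear down the default process group if one was initialized,
    # so the next test can call init_distributed_environment() afresh.
    if dist.is_initialized():
        dist.destroy_process_group()
    # Drop lingering Python references (communicators, cached tensors) ...
    gc.collect()
    # ... and release cached device memory if an NPU backend is loaded.
    if hasattr(torch, "npu") and torch.npu.is_available():
        torch.npu.empty_cache()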
