Skip to content

Commit f128429

Browse files
wangxiaoxin-sherieShangwei-Li
andcommitted
add weight loader e2e test.
Signed-off-by: wangxiaoxin-sherie <wangxiaoxin7@huawei.com> Co-authored-by: Shangwei-Li <lishangwei2@huawei.com>
1 parent 3e60aa5 commit f128429

File tree

3 files changed

+524
-0
lines changed

3 files changed

+524
-0
lines changed

.github/workflows/vllm_ascend_test_full.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -210,6 +210,7 @@ jobs:
210210
VLLM_WORKER_MULTIPROC_METHOD: spawn
211211
VLLM_USE_MODELSCOPE: True
212212
run: |
213+
pytest -sv tests/e2e/multicard/test_weight_loader.py
213214
pytest -sv tests/e2e/multicard/test_data_parallel.py
214215
pytest -sv tests/e2e/multicard/test_expert_parallel.py
215216
# external_launcher test is not stable enough. Fix it later

examples/offline_weight_load.py

Lines changed: 326 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,326 @@
1+
#
2+
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
3+
# Copyright 2023 The vLLM team.
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
# This file is a part of the vllm-ascend project.
17+
# Adapted from vllm-project/vllm/examples/offline_inference/data_parallel.py
18+
19+
# Note: This script is designed to run with e2e test,
20+
# please be careful to modify it.
21+
"""
22+
Usage:
23+
Single node:
24+
Dense models:
25+
python examples/offline_weight_load.py \
26+
--model="Qwen/Qwen2.5-0.5B-Instruct" \
27+
--tp-size=1 \
28+
--proc-per-node=2
29+
MOE models:
30+
python examples/offline_weight_load.py \
31+
--model="Qwen/Qwen3-30B-A3B" \
32+
--tp-size=2 \
33+
--proc-per-node=2 \
34+
--enable-expert-parallel
35+
36+
Multi-node:
37+
Node 0 (assume the node has ip of 10.99.48.128):
38+
python examples/offline_weight_load.py \
39+
--model="Qwen/Qwen3-30B-A3B" \
40+
--tp-size=2 \
41+
--node-size=2 \
42+
--node-rank=0 \
43+
--proc-per-node=2 \
44+
--enable-expert-parallel \
45+
--master-addr=10.99.48.128 \
46+
--master-port=13345
47+
Node 1:
48+
python examples/offline_weight_load.py \
49+
--model="Qwen/Qwen3-30B-A3B" \
50+
--tp-size=2 \
51+
--node-size=2 \
52+
--node-rank=1 \
53+
--enable-expert-parallel \
54+
--master-addr=10.99.48.128 \
55+
--master-port=13345
56+
"""
57+
58+
import argparse
59+
import contextlib
60+
import gc
61+
import os
62+
from multiprocessing import Process
63+
from time import sleep
64+
65+
import torch
66+
from vllm import LLM, SamplingParams
67+
from vllm.distributed.parallel_state import ( # noqa E402
68+
destroy_distributed_environment, destroy_model_parallel, get_tp_group)
69+
from vllm.utils import get_open_port, GiB_bytes
70+
from safetensors.torch import load_file
71+
72+
os.environ["VLLM_USE_MODELSCOPE"] = "True"
73+
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
74+
75+
def patch_vllm_moe_model_weight_loader(model):
76+
# Define MLP attribute mapping for different model types
77+
78+
model = getattr(model, "model", None) or getattr(model, "language_model", None)
79+
if model is None:
80+
raise ValueError("The provided model does not have a valid 'model' or 'language_model' attribute.")
81+
82+
for layer in model.layers:
83+
mlp_attr = "mlp"
84+
mlp = getattr(layer, mlp_attr)
85+
86+
param_dict = dict(mlp.named_parameters())
87+
for name, param in param_dict.items():
88+
if "w13_weight" in name or "w2_weight" in name:
89+
param.weight_loader = mlp.experts.weight_loader
90+
91+
def load_and_merge_safetensors(directory):
92+
merged_dict = {}
93+
94+
if not os.path.isdir(directory):
95+
raise ValueError(f"directory is not exist : {directory}")
96+
97+
for filename in os.listdir(directory):
98+
if filename.endswith('.safetensors'):
99+
file_path = os.path.join(directory, filename)
100+
print(f"loading file: {file_path}")
101+
102+
f = load_file(file_path)
103+
merged_dict.update(f)
104+
105+
return merged_dict
106+
107+
def parse_args():
108+
109+
parser = argparse.ArgumentParser(description="External launcher Inference")
110+
parser.add_argument(
111+
"--model",
112+
type=str,
113+
default="Qwen/Qwen3-0.6B",
114+
help="Model name or path",
115+
)
116+
parser.add_argument("--tp-size",
117+
type=int,
118+
default=1,
119+
help="Tensor parallel size")
120+
parser.add_argument("--node-size",
121+
type=int,
122+
default=1,
123+
help="Total number of nodes")
124+
parser.add_argument("--node-rank",
125+
type=int,
126+
default=0,
127+
help="Rank of the current node")
128+
parser.add_argument("--proc-per-node",
129+
type=int,
130+
default=1,
131+
help="Number of processes per node")
132+
parser.add_argument("--master-addr",
133+
type=str,
134+
default="",
135+
help="Master node IP address")
136+
parser.add_argument("--master-port",
137+
type=int,
138+
default=0,
139+
help="Master node port")
140+
parser.add_argument("--enforce-eager",
141+
action="store_true",
142+
help="Enforce eager mode execution.")
143+
parser.add_argument("--trust-remote-code",
144+
action="store_true",
145+
help="Trust remote code.")
146+
parser.add_argument("--enable-expert-parallel",
147+
action="store_true",
148+
help="Enable expert parallel, used in MOE models.")
149+
parser.add_argument("--enable-sleep-mode",
150+
action="store_true",
151+
help="Enable sleep mode for the engine.")
152+
parser.add_argument("--temperature",
153+
type=float,
154+
default=0.8,
155+
help="Float that controls the randomness of the sampling.")
156+
parser.add_argument("--model-weight-gib",
157+
type=float,
158+
default=None,
159+
help="Model weight memory usage in GiB (e.g., 1.0 for 0.5B model).")
160+
161+
args = parser.parse_args()
162+
if args.enable_sleep_mode:
163+
if args.model_weight_gib is None or args.temperature != 0:
164+
parser.error("model-weight-gib must be provided, and temperature must be zero when enable-sleep-mode is set.")
165+
if args.model_weight_gib <= 0:
166+
parser.error("model-weight-gib must be greater than 0 when enable-sleep-mode is set.")
167+
if args.model == parser.get_default("model") and args.model_weight_gib is None:
168+
parser.error("model-weight-gib must be provided for default model when enable-sleep-mode is set.")
169+
170+
return args
171+
172+
173+
def main(
174+
local_rank: int,
175+
rank: int,
176+
master_addr: str,
177+
master_port: int,
178+
model_weight_gib: float,
179+
model: str = "Qwen/Qwen3-30B-A3B",
180+
world_size: int = 4,
181+
tensor_parallel_size: int = 2,
182+
enable_expert_parallel: bool = False,
183+
enforce_eager: bool = True,
184+
trust_remote_code: bool = True,
185+
enable_sleep_mode: bool = False,
186+
temperature: float = 0.8,
187+
):
188+
os.environ["MASTER_ADDR"] = master_addr
189+
os.environ["MASTER_PORT"] = str(master_port)
190+
os.environ["RANK"] = str(rank)
191+
os.environ["LOCAL_RANK"] = str(local_rank)
192+
os.environ["WORLD_SIZE"] = str(world_size)
193+
if not torch.distributed.is_initialized():
194+
torch.distributed.init_process_group(
195+
backend="cpu:gloo,npu:hccl",
196+
world_size=world_size,
197+
rank=rank,
198+
)
199+
prompts = [
200+
"Hello, my name is",
201+
"The president of the United States is",
202+
"The capital of France is",
203+
"The future of AI is",
204+
] * 10
205+
sampling_params = SamplingParams(
206+
temperature=temperature,
207+
top_p=0.95,
208+
max_tokens=10,
209+
)
210+
llm = LLM(
211+
model=model,
212+
tensor_parallel_size=tensor_parallel_size,
213+
enable_expert_parallel=enable_expert_parallel,
214+
enforce_eager=enforce_eager,
215+
trust_remote_code=trust_remote_code,
216+
distributed_executor_backend="external_launcher",
217+
seed=0,
218+
gpu_memory_utilization = 0.95,
219+
enable_sleep_mode=enable_sleep_mode,
220+
)
221+
model_path = model
222+
runmodel = llm.llm_engine.model_executor.driver_worker.worker.model_runner.model
223+
patch_vllm_moe_model_weight_loader(runmodel)
224+
sd = load_and_merge_safetensors(model_path)
225+
runmodel.load_weights(sd.items())
226+
print('load state dict done')
227+
tp_ranks = get_tp_group().ranks
228+
print(f'TP RANKS: {tp_ranks}')
229+
230+
outputs = llm.generate(prompts, sampling_params)
231+
232+
if enable_sleep_mode:
233+
if rank == 0:
234+
free_bytes_before_sleep, total = torch.npu.mem_get_info()
235+
llm.sleep(level=1)
236+
if rank == 0:
237+
free_bytes_after_sleep, total = torch.npu.mem_get_info()
238+
freed_bytes = free_bytes_after_sleep - free_bytes_before_sleep
239+
print(f"Freed memory: {freed_bytes / 1024 ** 3:.2f} GiB")
240+
# now the freed memory should be larger than the model weights
241+
assert freed_bytes >= model_weight_gib / tensor_parallel_size * GiB_bytes
242+
243+
llm.wake_up()
244+
outputs_after_wakeup = llm.generate(prompts, sampling_params)
245+
if rank == 0:
246+
# cmp output
247+
assert outputs[0].outputs[0].text == outputs_after_wakeup[0].outputs[0].text
248+
print("Sleep and wake up successfully!!")
249+
250+
for i, output in enumerate(outputs):
251+
if i >= 5:
252+
# print only 5 outputs
253+
break
254+
prompt = output.prompt
255+
generated_text = output.outputs[0].text
256+
print(f"Global rank: {rank}, Prompt: {prompt!r}, "
257+
f"Generated text: {generated_text!r}")
258+
259+
# Give engines time to pause their processing loops before exiting.
260+
sleep(5)
261+
del llm
262+
cleanup_env_and_memory()
263+
264+
265+
def cleanup_env_and_memory():
266+
destroy_model_parallel()
267+
destroy_distributed_environment()
268+
with contextlib.suppress(AssertionError):
269+
torch.distributed.destroy_process_group()
270+
gc.collect()
271+
torch.npu.empty_cache()
272+
torch.npu.reset_peak_memory_stats()
273+
274+
275+
if __name__ == "__main__":
276+
args = parse_args()
277+
278+
tp_size = args.tp_size
279+
node_size = args.node_size
280+
proc_per_node = args.proc_per_node
281+
node_rank = args.node_rank
282+
283+
if node_size == 1:
284+
master_addr = "127.0.0.1"
285+
master_port = get_open_port()
286+
else:
287+
master_addr = args.master_addr
288+
master_port = args.master_port
289+
290+
world_size = node_size * proc_per_node
291+
292+
procs = []
293+
for local_rank, rank in enumerate(
294+
range(proc_per_node * node_rank, proc_per_node * (node_rank + 1))):
295+
proc = Process(target=main,
296+
args=(
297+
local_rank,
298+
rank,
299+
master_addr,
300+
master_port,
301+
args.model_weight_gib,
302+
args.model,
303+
world_size,
304+
tp_size,
305+
args.enable_expert_parallel,
306+
args.enforce_eager,
307+
args.trust_remote_code,
308+
args.enable_sleep_mode,
309+
args.temperature,
310+
))
311+
312+
proc.start()
313+
procs.append(proc)
314+
exit_code = 0
315+
for proc in procs:
316+
proc.join(timeout=600)
317+
if proc.exitcode is None:
318+
print(
319+
f"Killing process {proc.pid} that didn't stop within 30 minutes."
320+
)
321+
proc.kill()
322+
exit_code = 1
323+
elif proc.exitcode:
324+
exit_code = proc.exitcode
325+
326+
exit(exit_code)

0 commit comments

Comments
 (0)