
Commit 933b42c

add reload weight
Signed-off-by: wangli <wangli858794774@gmail.com>
1 parent: b1380f3

File tree: 3 files changed, 92 insertions(+), 15 deletions(-)

  examples/demo.py
  vllm_ascend/worker/model_runner_v1.py
  vllm_ascend/worker/worker_v1.py

examples/demo.py

Lines changed: 50 additions & 0 deletions
@@ -0,0 +1,50 @@
+from vllm import LLM, RequestOutput, SamplingParams
+
+# Sample prompts.
+prompts = [
+    "Hello, my name is",
+    "The president of the United States is",
+    "The capital of France is",
+    "The future of AI is",
+]
+# Create a sampling params object.
+sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
+
+
+def print_prompts_and_outputs(outputs: list[RequestOutput]) -> None:
+    print("-" * 60)
+    for output in outputs:
+        prompt = output.prompt
+        generated_text = output.outputs[0].text
+        print(f"Prompt: {prompt!r}")
+        print(f"Output: {generated_text!r}")
+        print("-" * 60)
+
+
+def main():
+    # Create an LLM without loading real weights.
+    llm = LLM(
+        model="Qwen/Qwen3-0.6B",
+        load_format="dummy",
+        enforce_eager=True,
+        tensor_parallel_size=4,
+    )
+    outputs = llm.generate(prompts, sampling_params)
+    print("\nOutputs do not make sense:")
+    print_prompts_and_outputs(outputs)
+
+    # Update the load format from `dummy` to `auto`.
+    llm.collective_rpc(
+        "update_config", args=({"load_config": {"load_format": "auto"}},)
+    )
+    # Now reload real weights in place.
+    llm.collective_rpc("reload_weights")
+
+    # Check that the outputs make sense.
+    outputs = llm.generate(prompts, sampling_params)
+    print("\nOutputs make sense after loading real weights:")
+    print_prompts_and_outputs(outputs)
+
+
+if __name__ == "__main__":
+    main()
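
The same RPC pair is not limited to upgrading dummy weights: it can also hot-swap weights from a newly saved checkpoint. A hedged sketch of that pattern (not part of this commit; the checkpoint path is hypothetical, and it assumes `model_config.model` is an accepted override, which the runner's `update_config` below permits via the `model_config` key):

    # Hypothetical checkpoint directory produced by a separate training job.
    NEW_CHECKPOINT = "/path/to/finetuned-qwen3-0.6b"

    # Point the engine at the new checkpoint, then reload weights in place.
    llm.collective_rpc(
        "update_config", args=({"model_config": {"model": NEW_CHECKPOINT}},)
    )
    llm.collective_rpc("reload_weights")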

vllm_ascend/worker/model_runner_v1.py

Lines changed: 35 additions & 14 deletions
@@ -43,7 +43,7 @@
 from vllm.compilation.counter import compilation_counter
 from vllm.compilation.monitor import set_cudagraph_capturing_enabled
 from vllm.config import (CompilationLevel, CUDAGraphMode, VllmConfig,
-                         get_layers_from_vllm_config)
+                         get_layers_from_vllm_config, update_config)
 from vllm.distributed import tensor_model_parallel_all_gather
 from vllm.distributed.kv_transfer import (get_kv_transfer_group,
                                           has_kv_transfer_group)
@@ -56,7 +56,7 @@
 from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
 from vllm.model_executor.layers.mamba.abstract import MambaBase
 from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
-from vllm.model_executor.model_loader import get_model
+from vllm.model_executor.model_loader import get_model_loader
 from vllm.model_executor.models.interfaces import supports_transcription
 from vllm.model_executor.models.interfaces_base import (
     VllmModelForPooling, is_pooling_model, is_text_generation_model)
@@ -852,6 +852,24 @@ def get_model(self) -> nn.Module:
             return self.model.unwrap()
         return self.model

+    def update_config(self, overrides: dict[str, Any]) -> None:
+        allowed_config_names = {"load_config", "model_config"}
+        for config_name, config_overrides in overrides.items():
+            assert config_name in allowed_config_names, \
+                f"Config `{config_name}` not supported. " \
+                f"Allowed configs: {allowed_config_names}"
+            config = getattr(self, config_name)
+            new_config = update_config(config, config_overrides)
+            setattr(self, config_name, new_config)
+
+    def reload_weights(self) -> None:
+        assert getattr(self, "model", None) is not None, \
+            "Cannot reload weights before model is loaded."
+        model_loader = get_model_loader(self.load_config)
+        logger.info("Reloading weights in place...")
+        model = self.get_model()
+        model_loader.load_weights(model, model_config=self.model_config)
+
     def get_supported_generation_tasks(self) -> "list[GenerationTask]":
         model = self.get_model()
         supported_tasks = list[GenerationTask]()
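
For context, the `update_config` helper now imported from `vllm.config` applies a dict of overrides to a config object and returns the updated instance, which is what lets `{"load_config": {"load_format": "auto"}}` from the example flow through. A minimal sketch of that behavior, assuming plain dataclass configs (illustrative only, not vLLM's actual implementation):

    from dataclasses import fields, is_dataclass, replace
    from typing import Any

    def update_config_sketch(config: Any, overrides: dict[str, Any]) -> Any:
        """Return a copy of `config` with `overrides` applied, recursing
        into nested dataclass fields when the override is itself a dict."""
        valid_fields = {f.name for f in fields(config)}
        processed: dict[str, Any] = {}
        for name, value in overrides.items():
            assert name in valid_fields, \
                f"{type(config).__name__} has no field {name!r}"
            current = getattr(config, name)
            if is_dataclass(current) and isinstance(value, dict):
                value = update_config_sketch(current, value)
            processed[name] = value
        # `replace` builds a new instance, so frozen configs also work.
        return replace(config, **processed)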
@@ -2593,9 +2611,23 @@ def load_model(self) -> None:
         logger.info("Starting to load model %s...", self.model_config.model)

         with DeviceMemoryProfiler() as m:  # noqa: SIM117
-            self.model = get_model(vllm_config=self.vllm_config)
+            model_loader = get_model_loader(self.load_config)
+            logger.info("Loading model from scratch...")
+            self.model = model_loader.load_model(
+                vllm_config=self.vllm_config, model_config=self.model_config)
             if self.dynamic_eplb:
                 model_register(self.model, self.model_config)
+            if self.lora_config:
+                if vllm_version_is("0.10.2"):
+                    self.model = self.load_lora_model(self.model,
+                                                      self.model_config,
+                                                      self.scheduler_config,
+                                                      self.lora_config,
+                                                      self.device)
+                else:
+                    self.model = self.load_lora_model(self.model,
+                                                      self.vllm_config,
+                                                      self.device)
             if is_310p():
                 from vllm.model_executor.layers.linear import (
                     MergedColumnParallelLinear, QKVParallelLinear,
@@ -2613,17 +2645,6 @@ def load_model(self) -> None:
                 self.model.set_aux_hidden_state_layers(
                     self.model.get_eagle3_aux_hidden_state_layers())

-        if self.lora_config:
-            if vllm_version_is("0.10.2"):
-                self.model = self.load_lora_model(self.model,
-                                                  self.model_config,
-                                                  self.scheduler_config,
-                                                  self.lora_config,
-                                                  self.device)
-            else:
-                self.model = self.load_lora_model(self.model,
-                                                  self.vllm_config,
-                                                  self.device)
         logger.info("Loading model weights took %.4f GB",
                     m.consumed_memory / float(2**30))

vllm_ascend/worker/worker_v1.py

Lines changed: 7 additions & 1 deletion
@@ -18,7 +18,7 @@
 #

 import copy
-from typing import Optional, Union
+from typing import Any, Optional, Union

 import torch
 import torch.nn as nn
@@ -254,6 +254,12 @@ def execute_model(
             output.kv_connector_output = kv_connector_output
         return output

+    def update_config(self, overrides: dict[str, Any]) -> None:
+        self.model_runner.update_config(overrides)
+
+    def reload_weights(self) -> None:
+        self.model_runner.reload_weights()
+
     def load_model(self) -> None:
         if self.vllm_config.model_config.enable_sleep_mode:
             allocator = CaMemAllocator.get_instance()
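
These two thin wrappers are what `llm.collective_rpc("update_config", ...)` and `llm.collective_rpc("reload_weights")` in the example resolve to: the executor broadcasts a method name plus arguments to every worker and invokes the matching attribute. A simplified sketch of that dispatch (illustrative only, not vLLM's actual executor code):

    from typing import Any

    def run_worker_method(worker: Any, method: str, *args: Any, **kwargs: Any) -> Any:
        # Resolve the RPC target by name (e.g. worker.reload_weights) and
        # call it with the broadcast arguments; unknown names fail loudly.
        func = getattr(worker, method)
        return func(*args, **kwargs)

    # What each worker effectively runs for the demo's two RPCs:
    # run_worker_method(worker, "update_config", {"load_config": {"load_format": "auto"}})
    # run_worker_method(worker, "reload_weights")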
