43
43
from vllm .compilation .counter import compilation_counter
44
44
from vllm .compilation .monitor import set_cudagraph_capturing_enabled
45
45
from vllm .config import (CompilationLevel , CUDAGraphMode , VllmConfig ,
46
- get_layers_from_vllm_config )
46
+ get_layers_from_vllm_config , update_config )
47
47
from vllm .distributed import tensor_model_parallel_all_gather
48
48
from vllm .distributed .kv_transfer import (get_kv_transfer_group ,
49
49
has_kv_transfer_group )
56
56
from vllm .model_executor .layers .attention_layer_base import AttentionLayerBase
57
57
from vllm .model_executor .layers .mamba .abstract import MambaBase
58
58
from vllm .model_executor .layers .rotary_embedding import MRotaryEmbedding
59
- from vllm .model_executor .model_loader import get_model
59
+ from vllm .model_executor .model_loader import get_model_loader
60
60
from vllm .model_executor .models .interfaces import supports_transcription
61
61
from vllm .model_executor .models .interfaces_base import (
62
62
VllmModelForPooling , is_pooling_model , is_text_generation_model )
@@ -852,6 +852,24 @@ def get_model(self) -> nn.Module:
852
852
return self .model .unwrap ()
853
853
return self .model
854
854
855
def update_config(self, overrides: dict[str, Any]) -> None:
    """Apply field overrides to this runner's config objects.

    Args:
        overrides: Mapping from config attribute name to a dict of field
            overrides for that config. Only ``"load_config"`` and
            ``"model_config"`` are supported.

    Raises:
        ValueError: If a key of ``overrides`` is not a supported config name.
    """
    allowed_config_names = {"load_config", "model_config"}
    for config_name, config_overrides in overrides.items():
        # Explicit raise instead of `assert`: asserts are stripped when
        # Python runs with -O, which would silently disable this validation.
        if config_name not in allowed_config_names:
            raise ValueError(f"Config `{config_name}` not supported. "
                             f"Allowed configs: {allowed_config_names}")
        config = getattr(self, config_name)
        # NOTE: this calls the module-level `update_config` helper imported
        # from vllm.config, not this method (methods are not in local scope).
        new_config = update_config(config, config_overrides)
        setattr(self, config_name, new_config)
865
def reload_weights(self) -> None:
    """Reload model weights in place via the configured model loader.

    Requires that the model has already been loaded (``load_model``); the
    existing module structure is kept and only the weights are re-read from
    the source described by ``self.load_config``.

    Raises:
        RuntimeError: If called before the model has been loaded.
    """
    # Explicit raise instead of `assert`: the precondition check must
    # survive `python -O`, where assert statements are stripped.
    if getattr(self, "model", None) is None:
        raise RuntimeError("Cannot reload weights before model is loaded.")
    model_loader = get_model_loader(self.load_config)
    logger.info("Reloading weights inplace...")
    model = self.get_model()
    model_loader.load_weights(model, model_config=self.model_config)
855
873
def get_supported_generation_tasks (self ) -> "list[GenerationTask]" :
856
874
model = self .get_model ()
857
875
supported_tasks = list [GenerationTask ]()
@@ -2593,9 +2611,23 @@ def load_model(self) -> None:
2593
2611
logger .info ("Starting to load model %s..." , self .model_config .model )
2594
2612
2595
2613
with DeviceMemoryProfiler () as m : # noqa: SIM117
2596
- self .model = get_model (vllm_config = self .vllm_config )
2614
+ model_loader = get_model_loader (self .load_config )
2615
+ logger .info ("Loading model from scratch..." )
2616
+ self .model = model_loader .load_model (
2617
+ vllm_config = self .vllm_config , model_config = self .model_config )
2597
2618
if self .dynamic_eplb :
2598
2619
model_register (self .model , self .model_config )
2620
+ if self .lora_config :
2621
+ if vllm_version_is ("0.10.2" ):
2622
+ self .model = self .load_lora_model (self .model ,
2623
+ self .model_config ,
2624
+ self .scheduler_config ,
2625
+ self .lora_config ,
2626
+ self .device )
2627
+ else :
2628
+ self .model = self .load_lora_model (self .model ,
2629
+ self .vllm_config ,
2630
+ self .device )
2599
2631
if is_310p ():
2600
2632
from vllm .model_executor .layers .linear import (
2601
2633
MergedColumnParallelLinear , QKVParallelLinear ,
@@ -2613,17 +2645,6 @@ def load_model(self) -> None:
2613
2645
self .model .set_aux_hidden_state_layers (
2614
2646
self .model .get_eagle3_aux_hidden_state_layers ())
2615
2647
2616
- if self .lora_config :
2617
- if vllm_version_is ("0.10.2" ):
2618
- self .model = self .load_lora_model (self .model ,
2619
- self .model_config ,
2620
- self .scheduler_config ,
2621
- self .lora_config ,
2622
- self .device )
2623
- else :
2624
- self .model = self .load_lora_model (self .model ,
2625
- self .vllm_config ,
2626
- self .device )
2627
2648
logger .info ("Loading model weights took %.4f GB" ,
2628
2649
m .consumed_memory / float (2 ** 30 ))
2629
2650
0 commit comments