# limitations under the License.
"""Inference-only LLaMA model compatible with HuggingFace weights."""
from collections.abc import Iterable
-from typing import Any, Optional
+from typing import Any, Optional, Union

import torch
from torch import nn
@@ -335,6 +335,8 @@ def __init__(self,
        self.num_experts = vllm_config.model_config.hf_config.num_local_experts
        self.enable_eplb = vllm_config.parallel_config.enable_eplb if hasattr(
            vllm_config.parallel_config, 'enable_eplb') else False
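+        # Fall back to 0 when the parallel config does not define
+        # num_redundant_experts.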
+        self._num_redundant_experts = getattr(
+            vllm_config.parallel_config, 'num_redundant_experts', 0)

        # Create layers with enable_eplb parameter
        self.vllm_config = vllm_config
@@ -390,53 +392,69 @@ def load_moe_expert_weights(
        fused: bool = True,
    ) -> bool:
        expert_param_loaded = False
+        is_expert_weight = False
+        loaded_weight_list = [loaded_weight]
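+        # "gate_up_proj" checkpoints pack the gate and up projections together;
+        # they are chunked into the two shards loaded as w1 and w3 below.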
        if "experts.gate_up_proj" in name:
-            loaded_weight = loaded_weight.chunk(2, dim=-1)
+            loaded_weight_list = list(loaded_weight.chunk(2, dim=-1))
+
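+        # Each mapping entry is (param_name, weight_name, expert_id, shard_id);
+        # tensors whose name does not contain weight_name are skipped.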
        for (param_name, weight_name, expert_id,
             shard_id) in expert_params_mapping:
-            new_loaded_weight = loaded_weight
+            if weight_name not in name:
+                continue
+
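+            # Record that this tensor is an expert weight even if it is
+            # skipped by the checks below.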
+            is_expert_weight = True
+            new_loaded_weight = loaded_weight_list[0]
+
            if fused:
                e_str, _, proj_str, _ = weight_name.split('.')
                weight_name = f"{e_str}.{proj_str}"
                param_name = f"{param_name}weight"
-            if weight_name not in name:
-                continue
+
            full_param_name = name.replace(weight_name, param_name)
+
            # Skip layers on other devices.
-            if is_pp_missing_parameter(name, self):
+            if is_pp_missing_parameter(full_param_name, self):
                continue
-            if ((name.endswith(".bias") or name.endswith("_bias"))
-                    and name not in params_dict):
+            if ((full_param_name.endswith(".bias")
+                 or full_param_name.endswith("_bias"))
+                    and full_param_name not in params_dict):
                continue
+
            param = params_dict[full_param_name]
-            weight_loader = param.weight_loader
+            weight_loader = getattr(param, "weight_loader",
+                                    default_weight_loader)
+
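+            # Fused checkpoints store expert weights transposed relative to
+            # the FusedMoE parameter layout, hence the transpose below.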
            if fused:
                if "w13" in full_param_name:
                    shard_idx = 0 if shard_id == "w1" else 1
-                    new_loaded_weight = new_loaded_weight[shard_idx]
+                    new_loaded_weight = loaded_weight_list[shard_idx]
                new_loaded_weight = new_loaded_weight.transpose(-1, -2)
-                layer_idx = extract_layer_index(name)
-                # EP mapping
-                expert_map = self.layers[
-                    layer_idx].feed_forward.experts.expert_map
-                if expert_map is not None:
-                    local_expert_indices = (expert_map != -1) \
-                        .nonzero() \
-                        .flatten() \
-                        .to(new_loaded_weight.device)
-                    new_loaded_weight = new_loaded_weight[local_expert_indices]
-                    expert_id = local_expert_indices[0].item()
-            else:
-                # TODO: add EP support for non fused weights
-                pass
-            weight_loader(param,
-                          new_loaded_weight,
-                          full_param_name,
-                          shard_id=shard_id,
-                          expert_id=expert_id)
-
-            loaded_params.add(full_param_name)
-            expert_param_loaded = True
+
+            # Use return_success to check if weight loading succeeded
+            if hasattr(weight_loader, '__call__'):
+                # Check if weight_loader supports return_success parameter
+                import inspect
+                sig = inspect.signature(weight_loader)
+                if 'return_success' in sig.parameters:
+                    success = weight_loader(param,
+                                            new_loaded_weight,
+                                            full_param_name,
+                                            shard_id=shard_id,
+                                            expert_id=expert_id,
+                                            return_success=True)
+                    if success:
+                        loaded_params.add(full_param_name)
+                        expert_param_loaded = True
+                        break
+                else:
+                    # Fallback for weight loaders without return_success
+                    weight_loader(param,
+                                  new_loaded_weight,
+                                  full_param_name,
+                                  shard_id=shard_id,
+                                  expert_id=expert_id)
+                    loaded_params.add(full_param_name)
+                    expert_param_loaded = True
+                    break
+
        return expert_param_loaded

    def load_weights(self, weights: Iterable[tuple[str,
@@ -450,16 +468,19 @@ def load_weights(self, weights: Iterable[tuple[str,
            (".gate_up_proj", ".up_proj", 1),
        ]
        fused_experts_params = False
+        # Pass num_redundant_experts for EPLB support
        expert_params_mapping = FusedMoE.make_expert_params_mapping(
            ckpt_gate_proj_name="gate_proj",
            ckpt_down_proj_name="down_proj",
            ckpt_up_proj_name="up_proj",
-            num_experts=self.num_experts)
+            num_experts=self.num_experts,
+            num_redundant_experts=self._num_redundant_experts)
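+        # Fused checkpoints keep all experts in a single tensor, so they are
+        # mapped as one logical expert with no redundant experts.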
        expert_params_mapping_fused = FusedMoE.make_expert_params_mapping(
            ckpt_gate_proj_name="gate_up_proj",
            ckpt_down_proj_name="down_proj",
            ckpt_up_proj_name="gate_up_proj",
-            num_experts=1)
+            num_experts=1,
+            num_redundant_experts=0)
        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()
        for name, loaded_weight in weights:
@@ -531,7 +552,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
            layer_type=Llama4DecoderLayer)

        # For MixtureOfExperts protocol
-        self._num_redundant_experts = 0
+        parallel_config = vllm_config.parallel_config
+        self._num_redundant_experts = getattr(parallel_config,
+                                              'num_redundant_experts', 0)
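+        # Per-layer expert weight tensors, registered in set_eplb_state().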
+        self.expert_weights = []

        # Track FusedMoE layers for EPLB
        self.moe_layers: list[FusedMoE] = []
@@ -545,11 +568,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
                    layer.feed_forward, 'experts'):
                self.moe_layers.append(layer.feed_forward.experts)

-    @property
-    def expert_weights(self) -> list[torch.nn.Module]:
-        """Get all MoE layers"""
-        return self.moe_layers
-
    @property
    def num_moe_layers(self) -> int:
        """Get number of MoE layers"""
@@ -595,15 +613,16 @@ def num_redundant_experts(self) -> int:

    def set_eplb_state(
        self,
-        moe_layer_indices: list[int],
        expert_load_view: torch.Tensor,
        logical_to_physical_map: torch.Tensor,
        logical_replica_count: torch.Tensor,
    ) -> None:
        """Set EPLB state for MoE layers"""
-        for i, moe_layer_idx in enumerate(moe_layer_indices):
-            self.moe_layers[i].set_eplb_state(
-                moe_layer_idx=i,
+        for layer_idx, layer in enumerate(self.moe_layers):
+            # Register the expert weights.
+            self.expert_weights.append(layer.get_expert_weights())
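+            # moe_layer_idx is the position within self.moe_layers, not the
+            # decoder layer index.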
+            layer.set_eplb_state(
+                moe_layer_idx=layer_idx,
                expert_load_view=expert_load_view,
                logical_to_physical_map=logical_to_physical_map,
                logical_replica_count=logical_replica_count,