@@ -533,15 +533,27 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         # For MixtureOfExperts protocol
         self._num_redundant_experts = 0
 
+        # Track FusedMoE layers for EPLB
+        self.moe_layers: list[FusedMoE] = []
+        config = vllm_config.model_config.hf_config
+        for i, layer in enumerate(self.model.layers):
+            if isinstance(layer, Llama4DecoderLayer):
+                is_moe_layer = (config.interleave_moe_layer_step > 0
+                                and (i + 1) % config.interleave_moe_layer_step
+                                == 0)
+                if is_moe_layer and hasattr(layer, 'feed_forward') and hasattr(
+                        layer.feed_forward, 'experts'):
+                    self.moe_layers.append(layer.feed_forward.experts)
+
     @property
     def expert_weights(self) -> list[torch.nn.Module]:
         """Get all MoE layers"""
-        return self.model.moe_layers
+        return self.moe_layers
 
     @property
     def num_moe_layers(self) -> int:
         """Get number of MoE layers"""
-        return len(self.model.moe_layers)
+        return len(self.moe_layers)
 
     @property
     def num_expert_groups(self) -> int:
@@ -561,11 +573,8 @@ def num_physical_experts(self) -> int:
     @property
     def num_local_physical_experts(self) -> int:
         """Get number of local physical experts"""
-        if hasattr(self.model, 'moe_layers') and len(
-                self.model.moe_layers) > 0:
-            moe_layer = self.model.moe_layers[0]
-            if hasattr(moe_layer.feed_forward, 'experts'):
-                return moe_layer.feed_forward.experts.local_num_experts
+        if len(self.moe_layers) > 0:
+            return self.moe_layers[0].local_num_experts
         return self.config.num_local_experts
 
     @property
@@ -593,22 +602,18 @@ def set_eplb_state(
     ) -> None:
         """Set EPLB state for MoE layers"""
         for i, moe_layer_idx in enumerate(moe_layer_indices):
-            moe_layer = self.model.moe_layers[i]
-            if hasattr(moe_layer.feed_forward, 'experts'):
-                moe_layer.feed_forward.experts.set_eplb_state(
-                    moe_layer_idx=i,
-                    expert_load_view=expert_load_view,
-                    logical_to_physical_map=logical_to_physical_map,
-                    logical_replica_count=logical_replica_count,
-                )
+            self.moe_layers[i].set_eplb_state(
+                moe_layer_idx=i,
+                expert_load_view=expert_load_view,
+                logical_to_physical_map=logical_to_physical_map,
+                logical_replica_count=logical_replica_count,
+            )
 
     def get_expert_weights(self) -> list[list[torch.Tensor]]:
         """Get expert weights from all MoE layers"""
         expert_weights = []
-        for moe_layer in self.model.moe_layers:
-            if hasattr(moe_layer.feed_forward, 'experts'):
-                expert_weights.append(
-                    list(moe_layer.feed_forward.experts.get_expert_weights()))
+        for moe_layer in self.moe_layers:
+            expert_weights.append(list(moe_layer.get_expert_weights()))
         return expert_weights
 
     def _init_model(self,
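For context, a minimal standalone sketch of the layer-selection rule used in the new `__init__` loop: a decoder layer at index `i` is treated as a MoE layer when `(i + 1) % config.interleave_moe_layer_step == 0`. The layer count and interleave step below are hypothetical values chosen only for illustration, not taken from any particular Llama 4 checkpoint.

```python
# Illustrative sketch of the MoE-layer selection rule from the __init__ loop
# above. num_layers and interleave_moe_layer_step are hypothetical values.
num_layers = 8
interleave_moe_layer_step = 2

moe_layer_indices = [
    i for i in range(num_layers)
    if interleave_moe_layer_step > 0
    and (i + 1) % interleave_moe_layer_step == 0
]
print(moe_layer_indices)  # [1, 3, 5, 7]: every second decoder layer is MoE
```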