 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader

+from .interfaces import MixtureOfExperts
 from .llama import LlamaForCausalLM, LlamaMLP, LlamaModel
-from .utils import (AutoWeightsLoader, extract_layer_index, fast_topk,
-                    is_pp_missing_parameter)
+from .utils import (AutoWeightsLoader, PPMissingLayer, extract_layer_index,
+                    fast_topk, is_pp_missing_parameter)


 class Llama4MoE(nn.Module):
@@ -59,7 +60,8 @@ def custom_routing_function(
     def __init__(self,
                  config: Llama4TextConfig,
                  quant_config: Optional[QuantizationConfig] = None,
-                 prefix: str = ""):
+                 prefix: str = "",
+                 enable_eplb: bool = False):
         super().__init__()
         self.tp_size = get_tensor_model_parallel_world_size()
         self.top_k = config.num_experts_per_tok
@@ -81,7 +83,8 @@ def __init__(self,
             reduce_results=False,
             renormalize=False,
             quant_config=quant_config,
-            prefix=f"{prefix}.experts")
+            prefix=f"{prefix}.experts",
+            enable_eplb=enable_eplb)

         self.shared_expert = LlamaMLP(
             hidden_size=config.hidden_size,
@@ -251,6 +254,7 @@ def __init__(
         cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
+        enable_eplb: bool = False,
     ) -> None:
         super().__init__()

@@ -281,6 +285,7 @@ def __init__(
                 config=config,
                 quant_config=quant_config,
                 prefix=f"{prefix}.feed_forward",
+                enable_eplb=enable_eplb,
             )
         else:
             self.feed_forward = LlamaMLP(
@@ -328,10 +333,53 @@ def __init__(self,
                  prefix: str = "",
                  layer_type: type[Llama4DecoderLayer] = Llama4DecoderLayer):
         self.num_experts = vllm_config.model_config.hf_config.num_local_experts
+        self.enable_eplb = vllm_config.parallel_config.enable_eplb if hasattr(
+            vllm_config.parallel_config, 'enable_eplb') else False
+
+        # Create layers with enable_eplb parameter
+        self.vllm_config = vllm_config
+        self.layer_type = layer_type
+
         super().__init__(vllm_config=vllm_config,
                          prefix=prefix,
                          layer_type=layer_type)

+        # Track MoE layers for EPLB
+        self.moe_layers = []
+        config = vllm_config.model_config.hf_config
+        for i, layer in enumerate(self.layers):
+            if isinstance(layer, layer_type):
+                is_moe_layer = (config.interleave_moe_layer_step > 0
+                                and (i + 1) % config.interleave_moe_layer_step
+                                == 0)
+                if is_moe_layer:
+                    self.moe_layers.append(layer)
+
+    def make_layers(
+        self,
+        num_hidden_layers: int,
+        layer_type: type[nn.Module],
+        prefix: str,
+    ) -> list[nn.Module]:
+        """Override to pass enable_eplb to decoder layers."""
+        layers = []
+        for layer_idx in range(num_hidden_layers):
+            if isinstance(self.start_layer,
+                          int) and layer_idx < self.start_layer or (
+                              isinstance(self.end_layer, int)
+                              and layer_idx >= self.end_layer):
+                layers.append(PPMissingLayer())
+            else:
+                layer = layer_type(
+                    config=self.config,
+                    cache_config=self.cache_config,
+                    quant_config=self.quant_config,
+                    prefix=f"{prefix}.{layer_idx}",
+                    enable_eplb=self.enable_eplb,
+                )
+                layers.append(layer)
+        return layers
+
     def load_moe_expert_weights(
         self,
         name: str,
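Note on the MoE-layer detection above (an aside, not part of the patch): with the rule `(i + 1) % config.interleave_moe_layer_step == 0`, a step of 2 marks the odd 0-indexed layers (1, 3, 5, ...) as MoE, a step of 1 marks every decoder layer, and a step of 0 or less marks none, because the `> 0` guard short-circuits the check.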
@@ -460,7 +508,7 @@ def load_weights(self, weights: Iterable[tuple[str,
         return loaded_params


-class Llama4ForCausalLM(LlamaForCausalLM):
+class Llama4ForCausalLM(LlamaForCausalLM, MixtureOfExperts):

     packed_modules_mapping = {
         "qkv_proj": ["q_proj", "k_proj", "v_proj"],
@@ -482,6 +530,87 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
                          prefix=prefix,
                          layer_type=Llama4DecoderLayer)

+        # For MixtureOfExperts protocol
+        self._num_redundant_experts = 0
+
+    @property
+    def expert_weights(self) -> list[torch.nn.Module]:
+        """Get all MoE layers"""
+        return self.model.moe_layers
+
+    @property
+    def num_moe_layers(self) -> int:
+        """Get number of MoE layers"""
+        return len(self.model.moe_layers)
+
+    @property
+    def num_expert_groups(self) -> int:
+        """Get number of expert groups (1 for Llama4)"""
+        return 1
+
+    @property
+    def num_logical_experts(self) -> int:
+        """Get number of logical experts"""
+        return self.config.num_local_experts
+
+    @property
+    def num_physical_experts(self) -> int:
+        """Get number of physical experts (includes redundant)"""
+        return self.config.num_local_experts + self._num_redundant_experts
+
+    @property
+    def num_local_physical_experts(self) -> int:
+        """Get number of local physical experts"""
+        if hasattr(self.model, 'moe_layers') and len(
+                self.model.moe_layers) > 0:
+            moe_layer = self.model.moe_layers[0]
+            if hasattr(moe_layer.feed_forward, 'experts'):
+                return moe_layer.feed_forward.experts.local_num_experts
+        return self.config.num_local_experts
+
+    @property
+    def num_routed_experts(self) -> int:
+        """Get number of routed experts (excludes shared experts)"""
+        return self.config.num_local_experts
+
+    @property
+    def num_shared_experts(self) -> int:
+        """Get number of shared experts"""
+        # Llama4 has 1 shared expert per MoE layer
+        return 1 if self.num_moe_layers > 0 else 0
+
+    @property
+    def num_redundant_experts(self) -> int:
+        """Get number of redundant experts"""
+        return self._num_redundant_experts
+
+    def set_eplb_state(
+        self,
+        moe_layer_indices: list[int],
+        expert_load_view: torch.Tensor,
+        logical_to_physical_map: torch.Tensor,
+        logical_replica_count: torch.Tensor,
+    ) -> None:
+        """Set EPLB state for MoE layers"""
+        for i, moe_layer_idx in enumerate(moe_layer_indices):
+            moe_layer = self.model.moe_layers[i]
+            if hasattr(moe_layer.feed_forward, 'experts'):
+                moe_layer.feed_forward.experts.set_eplb_state(
+                    moe_layer_idx=i,
+                    expert_load_view=expert_load_view,
+                    logical_to_physical_map=logical_to_physical_map,
+                    logical_replica_count=logical_replica_count,
+                )
+
+    def get_expert_weights(self) -> list[list[torch.Tensor]]:
+        """Get expert weights from all MoE layers"""
+        expert_weights = []
+        for moe_layer in self.model.moe_layers:
+            if hasattr(moe_layer.feed_forward, 'experts'):
+                expert_weights.append(
+                    list(moe_layer.feed_forward.experts.get_expert_weights()))
+        return expert_weights
+
     def _init_model(self,
                     vllm_config: VllmConfig,
                     prefix: str = "",
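Appendix (illustrative, not part of the patch): a minimal sketch of how a caller might initialize the state consumed by the `set_eplb_state` hook added above. The tensor shapes, the identity logical-to-physical placement, and the `init_eplb_state_sketch` helper are assumptions for illustration; vLLM's EPLB machinery owns and updates these buffers in practice.

import torch

def init_eplb_state_sketch(model) -> None:
    # Counts exposed by the MixtureOfExperts properties in this patch.
    num_layers = model.num_moe_layers
    num_physical = model.num_physical_experts
    num_logical = model.num_logical_experts

    # Per-layer, per-physical-expert load counters, starting at zero
    # (assumed shape: [num_moe_layers, num_physical_experts]).
    expert_load_view = torch.zeros(num_layers, num_physical,
                                   dtype=torch.int64)

    # Identity placement assumption: logical expert i is served by physical
    # expert i with exactly one replica, until a rebalance adds redundancy.
    logical_to_physical_map = (torch.arange(num_logical)
                               .expand(num_layers, num_logical)
                               .clone())
    logical_replica_count = torch.ones(num_layers, num_logical,
                                       dtype=torch.int64)

    model.set_eplb_state(
        moe_layer_indices=list(range(num_layers)),
        expert_load_view=expert_load_view,
        logical_to_physical_map=logical_to_physical_map,
        logical_replica_count=logical_replica_count,
    )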