 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Inference-only LLaMA model compatible with HuggingFace weights."""
-from collections.abc import Iterable
+from collections.abc import Callable, Iterable
 from typing import Any, Optional, Union
+import typing

 import torch
 from torch import nn
 from transformers import Llama4TextConfig

 from vllm.attention import Attention
 from vllm.compilation.decorators import support_torch_compile
-from vllm.config import CacheConfig, VllmConfig
-from vllm.distributed import get_tensor_model_parallel_world_size
+from vllm.config import CacheConfig, VllmConfig, get_current_vllm_config
+from vllm.distributed import get_ep_group, get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.fused_moe import FusedMoE
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (QKVParallelLinear,
@@ -66,15 +67,35 @@ def __init__(self,
         self.tp_size = get_tensor_model_parallel_world_size()
         self.top_k = config.num_experts_per_tok

+        # EP group support
+        self.ep_group = get_ep_group().device_group
+        self.ep_rank = self.ep_group.rank()
+        self.ep_size = self.ep_group.size()
+
+        # Get EPLB settings from current vllm config
+        vllm_config = get_current_vllm_config()
+        parallel_config = vllm_config.parallel_config
+        self.enable_eplb = enable_eplb
+
+        # Calculate expert distribution
+        self.n_logical_experts = config.num_local_experts
+        self.n_redundant_experts = parallel_config.num_redundant_experts
+        self.n_physical_experts = self.n_logical_experts + self.n_redundant_experts
+        self.n_local_physical_experts = self.n_physical_experts // self.ep_size
+
+        # Calculate which experts belong to this rank
+        self.physical_expert_start = self.ep_rank * self.n_local_physical_experts
+        self.physical_expert_end = self.physical_expert_start + self.n_local_physical_experts
+
         intermediate_size_moe = config.intermediate_size
         self.router = ReplicatedLinear(config.hidden_size,
-                                       config.num_local_experts,
+                                       self.n_logical_experts,
                                        bias=False,
                                        quant_config=None,
                                        prefix=f"{prefix}.router")

         self.experts = FusedMoE(
-            num_experts=config.num_local_experts,
+            num_experts=self.n_logical_experts,
             top_k=config.num_experts_per_tok,
             hidden_size=config.hidden_size,
             custom_routing_function=Llama4MoE.custom_routing_function,
@@ -84,7 +105,8 @@ def __init__(self,
             renormalize=False,
             quant_config=quant_config,
             prefix=f"{prefix}.experts",
-            enable_eplb=enable_eplb)
+            enable_eplb=self.enable_eplb,
+            num_redundant_experts=self.n_redundant_experts)

         self.shared_expert = LlamaMLP(
             hidden_size=config.hidden_size,
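A quick worked example of the expert-distribution arithmetic added above. This is an illustrative sketch only; the counts (16 logical experts, no redundant experts, an EP group of 4 ranks, rank index 1) are assumed values, not taken from any model config.

# Sketch of the per-rank expert partitioning above (assumed values only).
n_logical_experts = 16      # stands in for config.num_local_experts
n_redundant_experts = 0     # stands in for parallel_config.num_redundant_experts
ep_size, ep_rank = 4, 1     # EP world size and this rank's index

n_physical_experts = n_logical_experts + n_redundant_experts             # 16
n_local_physical_experts = n_physical_experts // ep_size                 # 4
physical_expert_start = ep_rank * n_local_physical_experts               # 4
physical_expert_end = physical_expert_start + n_local_physical_experts   # 8

assert (physical_expert_start, physical_expert_end) == (4, 8)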
@@ -280,6 +302,7 @@ def __init__(
         )
         is_moe_layer = config.interleave_moe_layer_step > 0 and (
             self.layer_idx + 1) % config.interleave_moe_layer_step == 0
+        self.feed_forward: Union[Llama4MoE, LlamaMLP]
         if is_moe_layer:
             self.feed_forward = Llama4MoE(
                 config=config,
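For context on the `is_moe_layer` test above: with a positive `interleave_moe_layer_step`, every step-th decoder layer (counting from 1) gets a Llama4MoE feed-forward and the rest get a dense LlamaMLP. A small illustrative sketch, assuming a step of 2 and 8 layers (not values from the source):

# Assumed values: interleave_moe_layer_step=2, num_hidden_layers=8.
interleave_moe_layer_step = 2
num_hidden_layers = 8
moe_layer_indices = [
    layer_idx for layer_idx in range(num_hidden_layers)
    if interleave_moe_layer_step > 0
    and (layer_idx + 1) % interleave_moe_layer_step == 0
]
assert moe_layer_indices == [1, 3, 5, 7]  # 0-based: every second layer is MoE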
@@ -467,26 +490,19 @@ def load_weights(self, weights: Iterable[tuple[str,
             (".gate_up_proj", ".gate_proj", 0),
             (".gate_up_proj", ".up_proj", 1),
         ]
-        fused_experts_params = False
+
         # Pass num_redundant_experts for EPLB support
         expert_params_mapping = FusedMoE.make_expert_params_mapping(
             ckpt_gate_proj_name="gate_proj",
             ckpt_down_proj_name="down_proj",
             ckpt_up_proj_name="up_proj",
             num_experts=self.num_experts,
             num_redundant_experts=self._num_redundant_experts)
-        expert_params_mapping_fused = FusedMoE.make_expert_params_mapping(
-            ckpt_gate_proj_name="gate_up_proj",
-            ckpt_down_proj_name="down_proj",
-            ckpt_up_proj_name="gate_up_proj",
-            num_experts=1,
-            num_redundant_experts=0)
+
         params_dict = dict(self.named_parameters())
         loaded_params: set[str] = set()
+
         for name, loaded_weight in weights:
-            if "experts.gate_up_proj" in name or "experts.down_proj" in name:
-                fused_experts_params = True
-                expert_params_mapping = expert_params_mapping_fused
             if (self.quant_config is not None and
                     (scale_name := self.quant_config.get_cache_scale(name))):
                 # Loading kv cache quantization scales
@@ -498,34 +514,93 @@ def load_weights(self, weights: Iterable[tuple[str,
                 weight_loader(param, loaded_weight)
                 loaded_params.add(scale_name)
                 continue
+
             for param_name, weight_name, shard_id in stacked_params_mapping:
                 if weight_name not in name or "experts" in name:
                     continue
                 name = name.replace(weight_name, param_name)
                 if is_pp_missing_parameter(name, self):
                     continue
                 param = params_dict[name]
-                weight_loader = param.weight_loader
+                weight_loader = getattr(param, "weight_loader",
+                                        default_weight_loader)
                 weight_loader(param, loaded_weight, shard_id)
                 loaded_params.add(name)
                 break
             else:
-                moe_loaded = self.load_moe_expert_weights(
-                    name,
-                    loaded_weight,
-                    params_dict,
-                    loaded_params,
-                    expert_params_mapping,
-                    fused=fused_experts_params)
-
-                if not moe_loaded:
+                # Check if this is an expert weight
+                is_expert_weight = False
+                for mapping in expert_params_mapping:
+                    param_name, weight_name, expert_id, shard_id = mapping
+                    if weight_name not in name:
+                        continue
+
+                    # This is an expert weight
+                    is_expert_weight = True
+
+                    # Create mapped name without modifying original
+                    name_mapped = name.replace(weight_name, param_name)
+
+                    if is_pp_missing_parameter(name_mapped, self):
+                        continue
+
+                    # Skip bias if not in params
+                    if ((name_mapped.endswith(".bias") or name_mapped.endswith("_bias"))
+                            and name_mapped not in params_dict):
+                        continue
+
+                    # Handle fused weights transformation
+                    if "experts.gate_up_proj" in name:
+                        loaded_weight_list = list(loaded_weight.chunk(2, dim=-1))
+                        if "w13" in name_mapped and "w1" in shard_id:
+                            loaded_weight_to_use = loaded_weight_list[0].transpose(-1, -2)
+                        elif "w13" in name_mapped and "w3" in shard_id:
+                            loaded_weight_to_use = loaded_weight_list[1].transpose(-1, -2)
+                        else:
+                            loaded_weight_to_use = loaded_weight.transpose(-1, -2)
+                    else:
+                        loaded_weight_to_use = loaded_weight.transpose(-1, -2) if "experts" in name else loaded_weight
+
+                    param = params_dict[name_mapped]
+                    weight_loader = typing.cast(Callable[..., bool],
                        getattr(param, "weight_loader", default_weight_loader))
+
+                    # Try to load with return_success
+                    try:
+                        success = weight_loader(
+                            param,
+                            loaded_weight_to_use,
+                            name_mapped,
+                            shard_id=shard_id,
+                            expert_id=expert_id,
+                            return_success=True
+                        )
+                        if success:
+                            loaded_params.add(name_mapped)
+                            break
+                    except TypeError:
+                        # Fallback for weight loaders without return_success
+                        weight_loader(
+                            param,
+                            loaded_weight_to_use,
+                            name_mapped
+                        )
+                        loaded_params.add(name_mapped)
+                        break
+                else:
+                    if is_expert_weight:
+                        # Expert weight identified but not local to this rank
+                        continue
+
+                    # Handle non-expert weights
                     if is_pp_missing_parameter(name, self):
                         continue
                     param = params_dict[name]
                     weight_loader = getattr(param, "weight_loader",
                                             default_weight_loader)
                     weight_loader(param, loaded_weight)
                     loaded_params.add(name)
+
         return loaded_params


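The fused-weight branch above splits a checkpoint `experts.gate_up_proj` tensor into its gate (`w1`) and up (`w3`) halves along the last dimension and transposes each expert matrix into the layout the `w13` parameter expects; the `try`/`except TypeError` around `weight_loader(...)` then keeps compatibility with loaders that do not accept `return_success`. A minimal sketch of the tensor manipulation, with made-up shapes (the sizes are assumptions, not from any checkpoint):

import torch

# Assumed fused checkpoint layout:
# [num_experts, hidden_size, 2 * intermediate_size], gate and up concatenated.
num_experts, hidden_size, intermediate_size = 4, 8, 16
gate_up = torch.randn(num_experts, hidden_size, 2 * intermediate_size)

w1, w3 = gate_up.chunk(2, dim=-1)   # gate ("w1") and up ("w3") halves
w1 = w1.transpose(-1, -2)           # -> [num_experts, intermediate_size, hidden_size]
w3 = w3.transpose(-1, -2)

assert w1.shape == (num_experts, intermediate_size, hidden_size)
assert w3.shape == w1.shape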
@@ -568,48 +643,32 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
                     layer.feed_forward, 'experts'):
                 self.moe_layers.append(layer.feed_forward.experts)

-    @property
-    def num_moe_layers(self) -> int:
-        """Get number of MoE layers"""
-        return len(self.moe_layers)
-
-    @property
-    def num_expert_groups(self) -> int:
-        """Get number of expert groups (1 for Llama4)"""
-        return 1
-
-    @property
-    def num_logical_experts(self) -> int:
-        """Get number of logical experts"""
-        return self.config.num_local_experts
-
-    @property
-    def num_physical_experts(self) -> int:
-        """Get number of physical experts (includes redundant)"""
-        return self.config.num_local_experts + self._num_redundant_experts
-
-    @property
-    def num_local_physical_experts(self) -> int:
-        """Get number of local physical experts"""
-        if len(self.moe_layers) > 0:
-            return self.moe_layers[0].local_num_experts
-        return self.config.num_local_experts
-
-    @property
-    def num_routed_experts(self) -> int:
-        """Get number of routed experts (excludes shared experts)"""
-        return self.config.num_local_experts
-
-    @property
-    def num_shared_experts(self) -> int:
-        """Get number of shared experts"""
-        # Llama4 has 1 shared expert per MoE layer
-        return 1 if self.num_moe_layers > 0 else 0
-
-    @property
-    def num_redundant_experts(self) -> int:
-        """Get number of redundant experts"""
-        return self._num_redundant_experts
+        # Get expert counts from an actual MoE layer
+        example_moe = None
+        for layer_idx in range(config.num_hidden_layers):
+            layer = self.model.layers[layer_idx]
+            if isinstance(layer, Llama4DecoderLayer) and hasattr(layer, 'feed_forward') and isinstance(layer.feed_forward, Llama4MoE):
+                example_moe = layer.feed_forward
+                break
+
+        if example_moe is not None:
+            self.num_logical_experts = example_moe.n_logical_experts
+            self.num_physical_experts = example_moe.n_physical_experts
+            self.num_local_physical_experts = example_moe.n_local_physical_experts
+            self.num_routed_experts = example_moe.n_logical_experts
+            self.num_redundant_experts = example_moe.n_redundant_experts
+        else:
+            # Fallback values if no MoE layers
+            self.num_logical_experts = 0
+            self.num_physical_experts = 0
+            self.num_local_physical_experts = 0
+            self.num_routed_experts = 0
+            self.num_redundant_experts = 0
+
+        # Set MixtureOfExperts attributes
+        self.num_moe_layers = len(self.moe_layers)
+        self.num_expert_groups = 1  # Llama4 has 1 expert group
+        self.num_shared_experts = 1 if self.num_moe_layers > 0 else 0  # 1 shared expert per MoE layer

     def set_eplb_state(
         self,
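The counts recorded here mirror the per-layer values computed in `Llama4MoE.__init__` (every MoE layer is built with the same expert geometry, so the first one is representative), and they obey a simple invariant: the physical experts, logical plus redundant, tile evenly across the EP ranks. A self-contained sketch of that bookkeeping with assumed counts (16 logical experts, 4 redundant, 4 EP ranks; not values from the source):

# Assumed values: 16 logical experts, 4 redundant experts, ep_size=4.
num_logical_experts = 16
num_redundant_experts = 4
ep_size = 4

num_physical_experts = num_logical_experts + num_redundant_experts   # 20
num_local_physical_experts = num_physical_experts // ep_size          # 5

# Each rank's [start, end) slice of physical experts; together they tile all 20.
rank_slices = [(rank * num_local_physical_experts,
                (rank + 1) * num_local_physical_experts)
               for rank in range(ep_size)]
assert rank_slices == [(0, 5), (5, 10), (10, 15), (15, 20)]
assert num_local_physical_experts * ep_size == num_physical_experts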