Commit cea35dd

feat: enable llama4 EPLB
1 parent 13ede83 commit cea35dd

1 file changed (+127, −68)


vllm/model_executor/models/llama4.py

Lines changed: 127 additions & 68 deletions
@@ -17,17 +17,18 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Inference-only LLaMA model compatible with HuggingFace weights."""
-from collections.abc import Iterable
+from collections.abc import Callable, Iterable
 from typing import Any, Optional, Union
+import typing
 
 import torch
 from torch import nn
 from transformers import Llama4TextConfig
 
 from vllm.attention import Attention
 from vllm.compilation.decorators import support_torch_compile
-from vllm.config import CacheConfig, VllmConfig
-from vllm.distributed import get_tensor_model_parallel_world_size
+from vllm.config import CacheConfig, VllmConfig, get_current_vllm_config
+from vllm.distributed import get_ep_group, get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.fused_moe import FusedMoE
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (QKVParallelLinear,
@@ -66,15 +67,35 @@ def __init__(self,
         self.tp_size = get_tensor_model_parallel_world_size()
         self.top_k = config.num_experts_per_tok
 
+        # EP group support
+        self.ep_group = get_ep_group().device_group
+        self.ep_rank = self.ep_group.rank()
+        self.ep_size = self.ep_group.size()
+
+        # Get EPLB settings from current vllm config
+        vllm_config = get_current_vllm_config()
+        parallel_config = vllm_config.parallel_config
+        self.enable_eplb = enable_eplb
+
+        # Calculate expert distribution
+        self.n_logical_experts = config.num_local_experts
+        self.n_redundant_experts = parallel_config.num_redundant_experts
+        self.n_physical_experts = self.n_logical_experts + self.n_redundant_experts
+        self.n_local_physical_experts = self.n_physical_experts // self.ep_size
+
+        # Calculate which experts belong to this rank
+        self.physical_expert_start = self.ep_rank * self.n_local_physical_experts
+        self.physical_expert_end = self.physical_expert_start + self.n_local_physical_experts
+
         intermediate_size_moe = config.intermediate_size
         self.router = ReplicatedLinear(config.hidden_size,
-                                       config.num_local_experts,
+                                       self.n_logical_experts,
                                        bias=False,
                                        quant_config=None,
                                        prefix=f"{prefix}.router")
 
         self.experts = FusedMoE(
-            num_experts=config.num_local_experts,
+            num_experts=self.n_logical_experts,
             top_k=config.num_experts_per_tok,
             hidden_size=config.hidden_size,
             custom_routing_function=Llama4MoE.custom_routing_function,
@@ -84,7 +105,8 @@ def __init__(self,
             renormalize=False,
             quant_config=quant_config,
             prefix=f"{prefix}.experts",
-            enable_eplb=enable_eplb)
+            enable_eplb=self.enable_eplb,
+            num_redundant_experts=self.n_redundant_experts)
 
         self.shared_expert = LlamaMLP(
             hidden_size=config.hidden_size,
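
For reference, the expert-distribution bookkeeping added to Llama4MoE.__init__ above reduces to simple integer arithmetic. The sketch below replays it with hypothetical counts (16 logical experts, 2 redundant experts, an EP group of size 2); none of these numbers come from this change.

    # Standalone sketch of the arithmetic in Llama4MoE.__init__; all values
    # here are hypothetical and chosen only for illustration.
    n_logical_experts = 16     # stands in for config.num_local_experts
    n_redundant_experts = 2    # stands in for parallel_config.num_redundant_experts
    ep_size = 2                # stands in for the EP process-group size

    n_physical_experts = n_logical_experts + n_redundant_experts   # 18
    n_local_physical_experts = n_physical_experts // ep_size       # 9

    for ep_rank in range(ep_size):
        start = ep_rank * n_local_physical_experts
        end = start + n_local_physical_experts
        print(f"EP rank {ep_rank} owns physical experts [{start}, {end})")
    # EP rank 0 owns physical experts [0, 9)
    # EP rank 1 owns physical experts [9, 18)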
@@ -280,6 +302,7 @@ def __init__(
         )
         is_moe_layer = config.interleave_moe_layer_step > 0 and (
             self.layer_idx + 1) % config.interleave_moe_layer_step == 0
+        self.feed_forward: Union[Llama4MoE, LlamaMLP]
         if is_moe_layer:
             self.feed_forward = Llama4MoE(
                 config=config,
@@ -467,26 +490,19 @@ def load_weights(self, weights: Iterable[tuple[str,
             (".gate_up_proj", ".gate_proj", 0),
             (".gate_up_proj", ".up_proj", 1),
         ]
-        fused_experts_params = False
+
         # Pass num_redundant_experts for EPLB support
         expert_params_mapping = FusedMoE.make_expert_params_mapping(
             ckpt_gate_proj_name="gate_proj",
             ckpt_down_proj_name="down_proj",
             ckpt_up_proj_name="up_proj",
             num_experts=self.num_experts,
             num_redundant_experts=self._num_redundant_experts)
-        expert_params_mapping_fused = FusedMoE.make_expert_params_mapping(
-            ckpt_gate_proj_name="gate_up_proj",
-            ckpt_down_proj_name="down_proj",
-            ckpt_up_proj_name="gate_up_proj",
-            num_experts=1,
-            num_redundant_experts=0)
+
         params_dict = dict(self.named_parameters())
         loaded_params: set[str] = set()
+
         for name, loaded_weight in weights:
-            if "experts.gate_up_proj" in name or "experts.down_proj" in name:
-                fused_experts_params = True
-                expert_params_mapping = expert_params_mapping_fused
             if (self.quant_config is not None and
                     (scale_name := self.quant_config.get_cache_scale(name))):
                 # Loading kv cache quantization scales
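
The expert_params_mapping built above is consumed later as 4-tuples of (param_name, weight_name, expert_id, shard_id). The sketch below shows that consumption pattern on hand-written entries; the strings are illustrative stand-ins, not the exact output of FusedMoE.make_expert_params_mapping.

    # Hand-written stand-ins for mapping entries; the real list comes from
    # FusedMoE.make_expert_params_mapping(...) and uses vLLM's own naming.
    expert_params_mapping = [
        ("experts.w13_weight", "experts.0.gate_proj.weight", 0, "w1"),
        ("experts.w13_weight", "experts.0.up_proj.weight", 0, "w3"),
        ("experts.w2_weight", "experts.0.down_proj.weight", 0, "w2"),
    ]

    checkpoint_name = "model.layers.1.feed_forward.experts.0.up_proj.weight"
    for param_name, weight_name, expert_id, shard_id in expert_params_mapping:
        if weight_name not in checkpoint_name:
            continue
        # Rewrite the checkpoint name to the fused parameter it loads into.
        name_mapped = checkpoint_name.replace(weight_name, param_name)
        print(name_mapped, expert_id, shard_id)
        # -> model.layers.1.feed_forward.experts.w13_weight 0 w3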
@@ -498,34 +514,93 @@ def load_weights(self, weights: Iterable[tuple[str,
                 weight_loader(param, loaded_weight)
                 loaded_params.add(scale_name)
                 continue
+
             for param_name, weight_name, shard_id in stacked_params_mapping:
                 if weight_name not in name or "experts" in name:
                     continue
                 name = name.replace(weight_name, param_name)
                 if is_pp_missing_parameter(name, self):
                     continue
                 param = params_dict[name]
-                weight_loader = param.weight_loader
+                weight_loader = getattr(param, "weight_loader",
+                                        default_weight_loader)
                 weight_loader(param, loaded_weight, shard_id)
                 loaded_params.add(name)
                 break
             else:
-                moe_loaded = self.load_moe_expert_weights(
-                    name,
-                    loaded_weight,
-                    params_dict,
-                    loaded_params,
-                    expert_params_mapping,
-                    fused=fused_experts_params)
-
-                if not moe_loaded:
+                # Check if this is an expert weight
+                is_expert_weight = False
+                for mapping in expert_params_mapping:
+                    param_name, weight_name, expert_id, shard_id = mapping
+                    if weight_name not in name:
+                        continue
+
+                    # This is an expert weight
+                    is_expert_weight = True
+
+                    # Create mapped name without modifying original
+                    name_mapped = name.replace(weight_name, param_name)
+
+                    if is_pp_missing_parameter(name_mapped, self):
+                        continue
+
+                    # Skip bias if not in params
+                    if ((name_mapped.endswith(".bias") or name_mapped.endswith("_bias"))
+                            and name_mapped not in params_dict):
+                        continue
+
+                    # Handle fused weights transformation
+                    if "experts.gate_up_proj" in name:
+                        loaded_weight_list = list(loaded_weight.chunk(2, dim=-1))
+                        if "w13" in name_mapped and "w1" in shard_id:
+                            loaded_weight_to_use = loaded_weight_list[0].transpose(-1, -2)
+                        elif "w13" in name_mapped and "w3" in shard_id:
+                            loaded_weight_to_use = loaded_weight_list[1].transpose(-1, -2)
+                        else:
+                            loaded_weight_to_use = loaded_weight.transpose(-1, -2)
+                    else:
+                        loaded_weight_to_use = loaded_weight.transpose(-1, -2) if "experts" in name else loaded_weight
+
+                    param = params_dict[name_mapped]
+                    weight_loader = typing.cast(Callable[..., bool],
+                        getattr(param, "weight_loader", default_weight_loader))
+
+                    # Try to load with return_success
+                    try:
+                        success = weight_loader(
+                            param,
+                            loaded_weight_to_use,
+                            name_mapped,
+                            shard_id=shard_id,
+                            expert_id=expert_id,
+                            return_success=True
+                        )
+                        if success:
+                            loaded_params.add(name_mapped)
+                            break
+                    except TypeError:
+                        # Fallback for weight loaders without return_success
+                        weight_loader(
+                            param,
+                            loaded_weight_to_use,
+                            name_mapped
+                        )
+                        loaded_params.add(name_mapped)
+                        break
+                else:
+                    if is_expert_weight:
+                        # Expert weight identified but not local to this rank
+                        continue
+
+                    # Handle non-expert weights
                     if is_pp_missing_parameter(name, self):
                         continue
                     param = params_dict[name]
                     weight_loader = getattr(param, "weight_loader",
                                             default_weight_loader)
                     weight_loader(param, loaded_weight)
                     loaded_params.add(name)
+
         return loaded_params
 
 
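The rewritten expert branch above first calls weight_loader with return_success=True and, when the loader does not accept those keyword arguments, falls back to the plain (param, weight, name) call. Below is a compressed sketch of that calling convention using stand-in loader functions; the real loaders are the parameters' weight_loader attributes.

    # Stand-in loaders, for illustration only.
    def eplb_aware_loader(param, weight, name, shard_id=None, expert_id=None,
                          return_success=False):
        loaded = True  # pretend this expert shard is local and was loaded
        return loaded if return_success else None

    def legacy_loader(param, weight, name):
        pass  # no shard_id/expert_id/return_success support

    def load_one(weight_loader, param, weight, name, shard_id, expert_id):
        try:
            # Preferred path: the loader reports whether the shard was taken.
            return bool(weight_loader(param, weight, name, shard_id=shard_id,
                                      expert_id=expert_id, return_success=True))
        except TypeError:
            # Fallback for loaders that only accept (param, weight, name).
            weight_loader(param, weight, name)
            return True

    print(load_one(eplb_aware_loader, None, None, "experts.w13_weight", "w1", 0))  # True
    print(load_one(legacy_loader, None, None, "experts.w13_weight", "w1", 0))      # True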
@@ -568,48 +643,32 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
                     layer.feed_forward, 'experts'):
                 self.moe_layers.append(layer.feed_forward.experts)
 
-    @property
-    def num_moe_layers(self) -> int:
-        """Get number of MoE layers"""
-        return len(self.moe_layers)
-
-    @property
-    def num_expert_groups(self) -> int:
-        """Get number of expert groups (1 for Llama4)"""
-        return 1
-
-    @property
-    def num_logical_experts(self) -> int:
-        """Get number of logical experts"""
-        return self.config.num_local_experts
-
-    @property
-    def num_physical_experts(self) -> int:
-        """Get number of physical experts (includes redundant)"""
-        return self.config.num_local_experts + self._num_redundant_experts
-
-    @property
-    def num_local_physical_experts(self) -> int:
-        """Get number of local physical experts"""
-        if len(self.moe_layers) > 0:
-            return self.moe_layers[0].local_num_experts
-        return self.config.num_local_experts
-
-    @property
-    def num_routed_experts(self) -> int:
-        """Get number of routed experts (excludes shared experts)"""
-        return self.config.num_local_experts
-
-    @property
-    def num_shared_experts(self) -> int:
-        """Get number of shared experts"""
-        # Llama4 has 1 shared expert per MoE layer
-        return 1 if self.num_moe_layers > 0 else 0
-
-    @property
-    def num_redundant_experts(self) -> int:
-        """Get number of redundant experts"""
-        return self._num_redundant_experts
+        # Get expert counts from an actual MoE layer
+        example_moe = None
+        for layer_idx in range(config.num_hidden_layers):
+            layer = self.model.layers[layer_idx]
+            if isinstance(layer, Llama4DecoderLayer) and hasattr(layer, 'feed_forward') and isinstance(layer.feed_forward, Llama4MoE):
+                example_moe = layer.feed_forward
+                break
+
+        if example_moe is not None:
+            self.num_logical_experts = example_moe.n_logical_experts
+            self.num_physical_experts = example_moe.n_physical_experts
+            self.num_local_physical_experts = example_moe.n_local_physical_experts
+            self.num_routed_experts = example_moe.n_logical_experts
+            self.num_redundant_experts = example_moe.n_redundant_experts
+        else:
+            # Fallback values if no MoE layers
+            self.num_logical_experts = 0
+            self.num_physical_experts = 0
+            self.num_local_physical_experts = 0
+            self.num_routed_experts = 0
+            self.num_redundant_experts = 0
+
+        # Set MixtureOfExperts attributes
+        self.num_moe_layers = len(self.moe_layers)
+        self.num_expert_groups = 1  # Llama4 has 1 expert group
+        self.num_shared_experts = 1 if self.num_moe_layers > 0 else 0  # 1 shared expert per MoE layer
 
     def set_eplb_state(
         self,
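
The fused-weights branch in load_weights above splits a checkpoint's experts.gate_up_proj tensor into its gate (w1) and up (w3) halves with chunk(2, dim=-1) and transposes each half before handing it to the loader. The sketch below mirrors that tensor manipulation under hypothetical shapes; the assumed [hidden_size, 2 * intermediate_size] checkpoint layout is inferred from the transformation, not stated in this change.

    import torch

    # Hypothetical per-expert shapes for illustration only.
    hidden_size, intermediate_size = 8, 4
    gate_up = torch.randn(hidden_size, 2 * intermediate_size)

    # Split into the two halves, then transpose each half to
    # [intermediate_size, hidden_size], as the diff does per shard.
    w1, w3 = (half.transpose(-1, -2) for half in gate_up.chunk(2, dim=-1))
    assert w1.shape == (intermediate_size, hidden_size)
    assert w3.shape == (intermediate_size, hidden_size)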
