
Commit 2f46c26

feat: enable llama4 EPLB

Signed-off-by: Brayden Zhong <b8zhong@uwaterloo.ca>
Parent: cf75cd2

1 file changed: vllm/model_executor/models/llama4.py (+134, −5)
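Context for the diff below: the model reads the flag from vllm_config.parallel_config.enable_eplb and threads it down to the fused-experts layer, so EPLB is expected to be switched on through vLLM's parallel configuration. A minimal sketch of what that could look like at the engine level (not shown in this commit; enable_eplb and enable_expert_parallel as engine arguments are assumptions about the surrounding vLLM build):

# Sketch only: the engine args below are assumed to populate
# vllm_config.parallel_config.enable_eplb, which the model code reads.
from vllm import LLM

llm = LLM(
    model="meta-llama/Llama-4-Scout-17B-16E-Instruct",
    tensor_parallel_size=8,
    enable_expert_parallel=True,  # EPLB rebalances experts across EP ranks
    enable_eplb=True,             # assumed flag backing parallel_config.enable_eplb
)

Note that the hasattr guard in Llama4Model.__init__ below keeps the model loadable on builds whose ParallelConfig predates the flag.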
@@ -37,9 +37,10 @@
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 
+from .interfaces import MixtureOfExperts
 from .llama import LlamaForCausalLM, LlamaMLP, LlamaModel
-from .utils import (AutoWeightsLoader, extract_layer_index, fast_topk,
-                    is_pp_missing_parameter)
+from .utils import (AutoWeightsLoader, PPMissingLayer, extract_layer_index,
+                    fast_topk, is_pp_missing_parameter)
 
 
 class Llama4MoE(nn.Module):
@@ -59,7 +60,8 @@ def custom_routing_function(
     def __init__(self,
                  config: Llama4TextConfig,
                  quant_config: Optional[QuantizationConfig] = None,
-                 prefix: str = ""):
+                 prefix: str = "",
+                 enable_eplb: bool = False):
         super().__init__()
         self.tp_size = get_tensor_model_parallel_world_size()
         self.top_k = config.num_experts_per_tok
@@ -81,7 +83,8 @@ def __init__(self,
             reduce_results=False,
             renormalize=False,
             quant_config=quant_config,
-            prefix=f"{prefix}.experts")
+            prefix=f"{prefix}.experts",
+            enable_eplb=enable_eplb)
 
         self.shared_expert = LlamaMLP(
             hidden_size=config.hidden_size,
@@ -251,6 +254,7 @@ def __init__(
         cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
+        enable_eplb: bool = False,
     ) -> None:
         super().__init__()
 
@@ -281,6 +285,7 @@ def __init__(
                 config=config,
                 quant_config=quant_config,
                 prefix=f"{prefix}.feed_forward",
+                enable_eplb=enable_eplb,
             )
         else:
             self.feed_forward = LlamaMLP(
@@ -328,10 +333,53 @@ def __init__(self,
                  prefix: str = "",
                  layer_type: type[Llama4DecoderLayer] = Llama4DecoderLayer):
         self.num_experts = vllm_config.model_config.hf_config.num_local_experts
+        self.enable_eplb = vllm_config.parallel_config.enable_eplb if hasattr(
+            vllm_config.parallel_config, 'enable_eplb') else False
+
+        # Create layers with enable_eplb parameter
+        self.vllm_config = vllm_config
+        self.layer_type = layer_type
+
         super().__init__(vllm_config=vllm_config,
                          prefix=prefix,
                          layer_type=layer_type)
 
+        # Track MoE layers for EPLB
+        self.moe_layers = []
+        config = vllm_config.model_config.hf_config
+        for i, layer in enumerate(self.layers):
+            if isinstance(layer, layer_type):
+                is_moe_layer = (config.interleave_moe_layer_step > 0
+                                and (i + 1) % config.interleave_moe_layer_step
+                                == 0)
+                if is_moe_layer:
+                    self.moe_layers.append(layer)
+
+    def make_layers(
+        self,
+        num_hidden_layers: int,
+        layer_type: type[nn.Module],
+        prefix: str,
+    ) -> list[nn.Module]:
+        """Override to pass enable_eplb to decoder layers."""
+        layers = []
+        for layer_idx in range(num_hidden_layers):
+            if isinstance(self.start_layer,
+                          int) and layer_idx < self.start_layer or (
+                              isinstance(self.end_layer, int)
+                              and layer_idx >= self.end_layer):
+                layers.append(PPMissingLayer())
+            else:
+                layer = layer_type(
+                    config=self.config,
+                    cache_config=self.cache_config,
+                    quant_config=self.quant_config,
+                    prefix=f"{prefix}.{layer_idx}",
+                    enable_eplb=self.enable_eplb,
+                )
+                layers.append(layer)
+        return layers
+
     def load_moe_expert_weights(
         self,
         name: str,
@@ -460,7 +508,7 @@ def load_weights(self, weights: Iterable[tuple[str,
         return loaded_params
 
 
-class Llama4ForCausalLM(LlamaForCausalLM):
+class Llama4ForCausalLM(LlamaForCausalLM, MixtureOfExperts):
 
     packed_modules_mapping = {
         "qkv_proj": ["q_proj", "k_proj", "v_proj"],
@@ -482,6 +530,87 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
             prefix=prefix,
             layer_type=Llama4DecoderLayer)
 
+        # For MixtureOfExperts protocol
+        self._num_redundant_experts = 0
+
+    @property
+    def expert_weights(self) -> list[torch.nn.Module]:
+        """Get all MoE layers"""
+        return self.model.moe_layers
+
+    @property
+    def num_moe_layers(self) -> int:
+        """Get number of MoE layers"""
+        return len(self.model.moe_layers)
+
+    @property
+    def num_expert_groups(self) -> int:
+        """Get number of expert groups (1 for Llama4)"""
+        return 1
+
+    @property
+    def num_logical_experts(self) -> int:
+        """Get number of logical experts"""
+        return self.config.num_local_experts
+
+    @property
+    def num_physical_experts(self) -> int:
+        """Get number of physical experts (includes redundant)"""
+        return self.config.num_local_experts + self._num_redundant_experts
+
+    @property
+    def num_local_physical_experts(self) -> int:
+        """Get number of local physical experts"""
+        if hasattr(self.model, 'moe_layers') and len(
+                self.model.moe_layers) > 0:
+            moe_layer = self.model.moe_layers[0]
+            if hasattr(moe_layer.feed_forward, 'experts'):
+                return moe_layer.feed_forward.experts.local_num_experts
+        return self.config.num_local_experts
+
+    @property
+    def num_routed_experts(self) -> int:
+        """Get number of routed experts (excludes shared experts)"""
+        return self.config.num_local_experts
+
+    @property
+    def num_shared_experts(self) -> int:
+        """Get number of shared experts"""
+        # Llama4 has 1 shared expert per MoE layer
+        return 1 if self.num_moe_layers > 0 else 0
+
+    @property
+    def num_redundant_experts(self) -> int:
+        """Get number of redundant experts"""
+        return self._num_redundant_experts
+
+    def set_eplb_state(
+        self,
+        moe_layer_indices: list[int],
+        expert_load_view: torch.Tensor,
+        logical_to_physical_map: torch.Tensor,
+        logical_replica_count: torch.Tensor,
+    ) -> None:
+        """Set EPLB state for MoE layers"""
+        for i, moe_layer_idx in enumerate(moe_layer_indices):
+            moe_layer = self.model.moe_layers[i]
+            if hasattr(moe_layer.feed_forward, 'experts'):
+                moe_layer.feed_forward.experts.set_eplb_state(
+                    moe_layer_idx=i,
+                    expert_load_view=expert_load_view,
+                    logical_to_physical_map=logical_to_physical_map,
+                    logical_replica_count=logical_replica_count,
+                )
+
+    def get_expert_weights(self) -> list[list[torch.Tensor]]:
+        """Get expert weights from all MoE layers"""
+        expert_weights = []
+        for moe_layer in self.model.moe_layers:
+            if hasattr(moe_layer.feed_forward, 'experts'):
+                expert_weights.append(
+                    list(moe_layer.feed_forward.experts.get_expert_weights()))
+        return expert_weights
+
     def _init_model(self,
                     vllm_config: VllmConfig,
                     prefix: str = "",
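For illustration only (not part of the commit): one way a caller could exercise the MixtureOfExperts surface added above. The property and method names come from the diff; the helper function, tensor shapes, and identity mapping are assumptions rather than vLLM's actual EPLB bookkeeping.

import torch


def init_identity_eplb_state(model) -> list[list[torch.Tensor]]:
    """Hypothetical helper: seed EPLB state with a no-redundancy identity mapping."""
    n_moe = model.num_moe_layers          # properties added in this commit
    n_logical = model.num_logical_experts
    n_physical = model.num_physical_experts

    # Assumed shapes: per-layer load counters, a logical->physical map with a
    # single replica slot, and one replica per logical expert.
    expert_load_view = torch.zeros(n_moe, n_physical, dtype=torch.int64)
    logical_to_physical_map = (torch.arange(n_logical, dtype=torch.int64)
                               .view(1, n_logical, 1)
                               .expand(n_moe, n_logical, 1)
                               .contiguous())
    logical_replica_count = torch.ones(n_moe, n_logical, dtype=torch.int64)

    model.set_eplb_state(
        moe_layer_indices=list(range(n_moe)),
        expert_load_view=expert_load_view,
        logical_to_physical_map=logical_to_physical_map,
        logical_replica_count=logical_replica_count,
    )
    # One inner list of weight tensors per MoE layer.
    return model.get_expert_weights()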
