Commit 7ffb754

feat: enable llama4 EPLB
Signed-off-by: Brayden Zhong <b8zhong@uwaterloo.ca>
1 parent 2f46c26 commit 7ffb754

1 file changed: 24 additions, 19 deletions


vllm/model_executor/models/llama4.py

Lines changed: 24 additions & 19 deletions
@@ -533,15 +533,27 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         # For MixtureOfExperts protocol
         self._num_redundant_experts = 0

+        # Track FusedMoE layers for EPLB
+        self.moe_layers: list[FusedMoE] = []
+        config = vllm_config.model_config.hf_config
+        for i, layer in enumerate(self.model.layers):
+            if isinstance(layer, Llama4DecoderLayer):
+                is_moe_layer = (config.interleave_moe_layer_step > 0
+                                and (i + 1) % config.interleave_moe_layer_step
+                                == 0)
+                if is_moe_layer and hasattr(layer, 'feed_forward') and hasattr(
+                        layer.feed_forward, 'experts'):
+                    self.moe_layers.append(layer.feed_forward.experts)
+
     @property
     def expert_weights(self) -> list[torch.nn.Module]:
         """Get all MoE layers"""
-        return self.model.moe_layers
+        return self.moe_layers

     @property
     def num_moe_layers(self) -> int:
         """Get number of MoE layers"""
-        return len(self.model.moe_layers)
+        return len(self.moe_layers)

     @property
     def num_expert_groups(self) -> int:
@@ -561,11 +573,8 @@ def num_physical_experts(self) -> int:
     @property
     def num_local_physical_experts(self) -> int:
         """Get number of local physical experts"""
-        if hasattr(self.model, 'moe_layers') and len(
-                self.model.moe_layers) > 0:
-            moe_layer = self.model.moe_layers[0]
-            if hasattr(moe_layer.feed_forward, 'experts'):
-                return moe_layer.feed_forward.experts.local_num_experts
+        if len(self.moe_layers) > 0:
+            return self.moe_layers[0].local_num_experts
         return self.config.num_local_experts

     @property
@@ -593,22 +602,18 @@ def set_eplb_state(
     ) -> None:
         """Set EPLB state for MoE layers"""
         for i, moe_layer_idx in enumerate(moe_layer_indices):
-            moe_layer = self.model.moe_layers[i]
-            if hasattr(moe_layer.feed_forward, 'experts'):
-                moe_layer.feed_forward.experts.set_eplb_state(
-                    moe_layer_idx=i,
-                    expert_load_view=expert_load_view,
-                    logical_to_physical_map=logical_to_physical_map,
-                    logical_replica_count=logical_replica_count,
-                )
+            self.moe_layers[i].set_eplb_state(
+                moe_layer_idx=i,
+                expert_load_view=expert_load_view,
+                logical_to_physical_map=logical_to_physical_map,
+                logical_replica_count=logical_replica_count,
+            )

     def get_expert_weights(self) -> list[list[torch.Tensor]]:
         """Get expert weights from all MoE layers"""
         expert_weights = []
-        for moe_layer in self.model.moe_layers:
-            if hasattr(moe_layer.feed_forward, 'experts'):
-                expert_weights.append(
-                    list(moe_layer.feed_forward.experts.get_expert_weights()))
+        for moe_layer in self.moe_layers:
+            expert_weights.append(list(moe_layer.get_expert_weights()))
         return expert_weights

     def _init_model(self,
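
For reference, a minimal standalone sketch of the layer-selection rule used in the constructor above. The numeric values here are hypothetical; in vLLM they would come from the HF config of the loaded Llama 4 checkpoint.

# Hypothetical values for illustration only; the real ones come from
# vllm_config.model_config.hf_config for the loaded checkpoint.
num_hidden_layers = 8
interleave_moe_layer_step = 2  # every 2nd decoder layer carries experts

# Same condition as in the diff: layer i is an MoE layer when the step is
# positive and (i + 1) is a multiple of interleave_moe_layer_step.
moe_layer_indices = [
    i for i in range(num_hidden_layers)
    if interleave_moe_layer_step > 0
    and (i + 1) % interleave_moe_layer_step == 0
]
print(moe_layer_indices)  # [1, 3, 5, 7]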
