
Commit 13ede83

feat: enable llama4 EPLB
1 parent 7ffb754 commit 13ede83

File tree

1 file changed: +63, -44 lines


vllm/model_executor/models/llama4.py

Lines changed: 63 additions & 44 deletions
@@ -18,7 +18,7 @@
 # limitations under the License.
 """Inference-only LLaMA model compatible with HuggingFace weights."""
 from collections.abc import Iterable
-from typing import Any, Optional
+from typing import Any, Optional, Union
 
 import torch
 from torch import nn
@@ -335,6 +335,8 @@ def __init__(self,
         self.num_experts = vllm_config.model_config.hf_config.num_local_experts
         self.enable_eplb = vllm_config.parallel_config.enable_eplb if hasattr(
             vllm_config.parallel_config, 'enable_eplb') else False
+        self._num_redundant_experts = getattr(
+            vllm_config.parallel_config, 'num_redundant_experts', 0)
 
         # Create layers with enable_eplb parameter
         self.vllm_config = vllm_config
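For orientation, a minimal sketch of what these two settings imply for expert parallelism; the numbers and ep_size below are illustrative assumptions, not values from this commit. Redundant experts add extra physical expert slots on top of the checkpoint's logical experts, and those physical slots are what get split across expert-parallel ranks.

# Illustrative arithmetic only; all values are assumptions, not from this diff.
num_logical_experts = 16      # e.g. hf_config.num_local_experts
num_redundant_experts = 8     # parallel_config.num_redundant_experts (0 when EPLB is off)
ep_size = 8                   # assumed expert-parallel world size

num_physical_experts = num_logical_experts + num_redundant_experts
assert num_physical_experts % ep_size == 0, "physical experts should split evenly across EP ranks"
num_local_physical_experts = num_physical_experts // ep_size  # 3 physical experts hosted per rank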
@@ -390,53 +392,69 @@ def load_moe_expert_weights(
         fused: bool = True,
     ) -> bool:
         expert_param_loaded = False
+        is_expert_weight = False
+        loaded_weight_list = [loaded_weight]
         if "experts.gate_up_proj" in name:
-            loaded_weight = loaded_weight.chunk(2, dim=-1)
+            loaded_weight_list = list(loaded_weight.chunk(2, dim=-1))
+
         for (param_name, weight_name, expert_id,
              shard_id) in expert_params_mapping:
-            new_loaded_weight = loaded_weight
+            if weight_name not in name:
+                continue
+
+            is_expert_weight = True
+            new_loaded_weight = loaded_weight_list[0]
+
             if fused:
                 e_str, _, proj_str, _ = weight_name.split('.')
                 weight_name = f"{e_str}.{proj_str}"
                 param_name = f"{param_name}weight"
-                if weight_name not in name:
-                    continue
+
             full_param_name = name.replace(weight_name, param_name)
+
             # Skip layers on other devices.
-            if is_pp_missing_parameter(name, self):
+            if is_pp_missing_parameter(full_param_name, self):
                 continue
-            if ((name.endswith(".bias") or name.endswith("_bias"))
-                    and name not in params_dict):
+            if ((full_param_name.endswith(".bias") or full_param_name.endswith("_bias"))
+                    and full_param_name not in params_dict):
                 continue
+
             param = params_dict[full_param_name]
-            weight_loader = param.weight_loader
+            weight_loader = getattr(param, "weight_loader", default_weight_loader)
+
             if fused:
                 if "w13" in full_param_name:
                     shard_idx = 0 if shard_id == "w1" else 1
-                    new_loaded_weight = new_loaded_weight[shard_idx]
+                    new_loaded_weight = loaded_weight_list[shard_idx]
                 new_loaded_weight = new_loaded_weight.transpose(-1, -2)
-                layer_idx = extract_layer_index(name)
-                # EP mapping
-                expert_map = self.layers[
-                    layer_idx].feed_forward.experts.expert_map
-                if expert_map is not None:
-                    local_expert_indices = (expert_map != -1) \
-                        .nonzero() \
-                        .flatten() \
-                        .to(new_loaded_weight.device)
-                    new_loaded_weight = new_loaded_weight[local_expert_indices]
-                    expert_id = local_expert_indices[0].item()
-            else:
-                # TODO: add EP support for non fused weights
-                pass
-            weight_loader(param,
-                          new_loaded_weight,
-                          full_param_name,
-                          shard_id=shard_id,
-                          expert_id=expert_id)
-
-            loaded_params.add(full_param_name)
-            expert_param_loaded = True
+
+            # Use return_success to check if weight loading succeeded
+            if hasattr(weight_loader, '__call__'):
+                # Check if weight_loader supports return_success parameter
+                import inspect
+                sig = inspect.signature(weight_loader)
+                if 'return_success' in sig.parameters:
+                    success = weight_loader(param,
+                                            new_loaded_weight,
+                                            full_param_name,
+                                            shard_id=shard_id,
+                                            expert_id=expert_id,
+                                            return_success=True)
+                    if success:
+                        loaded_params.add(full_param_name)
+                        expert_param_loaded = True
+                        break
+                else:
+                    # Fallback for weight loaders without return_success
+                    weight_loader(param,
+                                  new_loaded_weight,
+                                  full_param_name,
+                                  shard_id=shard_id,
+                                  expert_id=expert_id)
+                    loaded_params.add(full_param_name)
+                    expert_param_loaded = True
+                    break
+
         return expert_param_loaded
 
     def load_weights(self, weights: Iterable[tuple[str,
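The reworked loader probes each parameter's weight_loader for a return_success keyword before calling it. A minimal standalone sketch of that probe, with a toy loader standing in for vLLM's FusedMoE weight loader (the helper and loader names here are hypothetical):

import inspect
from typing import Any

def call_weight_loader(weight_loader, *args: Any, **kwargs: Any) -> bool:
    """Call a weight loader, using return_success when the loader supports it.

    Mirrors the probe in load_moe_expert_weights above: loaders that accept
    return_success report whether the expert weight was actually consumed
    (relevant when EPLB/redundant experts mean not every rank owns every
    expert); older loaders are assumed successful if they do not raise.
    """
    sig = inspect.signature(weight_loader)
    if "return_success" in sig.parameters:
        return bool(weight_loader(*args, **kwargs, return_success=True))
    weight_loader(*args, **kwargs)
    return True

# Toy loader that accepts return_success, standing in for a FusedMoE loader.
def toy_loader(param, weight, name, *, shard_id=None, expert_id=None, return_success=False):
    loaded = expert_id is not None  # pretend only expert weights "succeed"
    return loaded if return_success else None

print(call_weight_loader(toy_loader, None, None, "w13_weight", shard_id="w1", expert_id=0))  # True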
@@ -450,16 +468,19 @@ def load_weights(self, weights: Iterable[tuple[str,
             (".gate_up_proj", ".up_proj", 1),
         ]
         fused_experts_params = False
+        # Pass num_redundant_experts for EPLB support
         expert_params_mapping = FusedMoE.make_expert_params_mapping(
             ckpt_gate_proj_name="gate_proj",
             ckpt_down_proj_name="down_proj",
             ckpt_up_proj_name="up_proj",
-            num_experts=self.num_experts)
+            num_experts=self.num_experts,
+            num_redundant_experts=self._num_redundant_experts)
         expert_params_mapping_fused = FusedMoE.make_expert_params_mapping(
             ckpt_gate_proj_name="gate_up_proj",
             ckpt_down_proj_name="down_proj",
             ckpt_up_proj_name="gate_up_proj",
-            num_experts=1)
+            num_experts=1,
+            num_redundant_experts=0)
         params_dict = dict(self.named_parameters())
         loaded_params: set[str] = set()
         for name, loaded_weight in weights:
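For reference, the loader loop above unpacks each mapping entry as a (param_name, weight_name, expert_id, shard_id) tuple. A hand-built, hypothetical mapping of that shape (the real entries come from FusedMoE.make_expert_params_mapping and, with num_redundant_experts > 0, also cover the redundant physical experts):

# Hypothetical entries in the shape the loader loop expects:
# (param_name, weight_name, expert_id, shard_id).
expert_params_mapping = [
    ("experts.w13_", "experts.0.gate_proj.", 0, "w1"),
    ("experts.w13_", "experts.0.up_proj.", 0, "w3"),
    ("experts.w2_", "experts.0.down_proj.", 0, "w2"),
]

checkpoint_name = "model.layers.0.feed_forward.experts.0.up_proj.weight"
for param_name, weight_name, expert_id, shard_id in expert_params_mapping:
    if weight_name in checkpoint_name:
        # Same renaming step as the hunk above.
        full_param_name = checkpoint_name.replace(weight_name, param_name)
        print(full_param_name, expert_id, shard_id)
        # -> model.layers.0.feed_forward.experts.w13_weight 0 w3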
@@ -531,7 +552,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
                         layer_type=Llama4DecoderLayer)
 
         # For MixtureOfExperts protocol
-        self._num_redundant_experts = 0
+        parallel_config = vllm_config.parallel_config
+        self._num_redundant_experts = getattr(parallel_config, 'num_redundant_experts', 0)
+        self.expert_weights = []
 
         # Track FusedMoE layers for EPLB
         self.moe_layers: list[FusedMoE] = []
@@ -545,11 +568,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
                     layer.feed_forward, 'experts'):
                 self.moe_layers.append(layer.feed_forward.experts)
 
-    @property
-    def expert_weights(self) -> list[torch.nn.Module]:
-        """Get all MoE layers"""
-        return self.moe_layers
-
     @property
     def num_moe_layers(self) -> int:
         """Get number of MoE layers"""
@@ -595,15 +613,16 @@ def num_redundant_experts(self) -> int:
 
     def set_eplb_state(
         self,
-        moe_layer_indices: list[int],
         expert_load_view: torch.Tensor,
         logical_to_physical_map: torch.Tensor,
         logical_replica_count: torch.Tensor,
     ) -> None:
         """Set EPLB state for MoE layers"""
-        for i, moe_layer_idx in enumerate(moe_layer_indices):
-            self.moe_layers[i].set_eplb_state(
-                moe_layer_idx=i,
+        for layer_idx, layer in enumerate(self.moe_layers):
+            # Register the expert weights.
+            self.expert_weights.append(layer.get_expert_weights())
+            layer.set_eplb_state(
+                moe_layer_idx=layer_idx,
                 expert_load_view=expert_load_view,
                 logical_to_physical_map=logical_to_physical_map,
                 logical_replica_count=logical_replica_count,
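A hedged sketch of how a caller might drive the new set_eplb_state signature; the helper, the tensor shapes, and the assumption that num_moe_layers and num_redundant_experts are properties are all illustrative, not taken from this commit:

import torch

def init_eplb_state(model, num_logical_experts: int) -> None:
    """Illustrative driver for the reworked set_eplb_state above.

    `model` is assumed to be a constructed Llama4ForCausalLM; the shapes are
    assumptions for this sketch: one row of EPLB bookkeeping per MoE layer.
    """
    num_layers = model.num_moe_layers                      # property kept above
    num_physical = num_logical_experts + model.num_redundant_experts

    # How often each physical expert has been routed to (starts at zero).
    expert_load_view = torch.zeros(num_layers, num_physical, dtype=torch.int64)
    # Identity mapping: each logical expert backed by exactly one physical slot.
    logical_to_physical_map = torch.full(
        (num_layers, num_logical_experts, 1), -1, dtype=torch.int64)
    logical_to_physical_map[..., 0] = torch.arange(num_logical_experts)
    logical_replica_count = torch.ones(
        num_layers, num_logical_experts, dtype=torch.int64)

    # New signature: no moe_layer_indices argument.
    model.set_eplb_state(expert_load_view, logical_to_physical_map,
                         logical_replica_count)
    # model.expert_weights now holds one entry per FusedMoE layer, appended
    # via layer.get_expert_weights() in the loop above.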

0 commit comments