
Commit 5891bdd

feat: enable llama4 EPLB
1 parent cea35dd commit 5891bdd

File tree

1 file changed: +8, -4 lines


vllm/model_executor/models/llama4.py

Lines changed: 8 additions & 4 deletions
@@ -79,7 +79,8 @@ def __init__(self,
 
         # Calculate expert distribution
         self.n_logical_experts = config.num_local_experts
-        self.n_redundant_experts = parallel_config.num_redundant_experts
+        # Only use redundant experts if EPLB is enabled
+        self.n_redundant_experts = parallel_config.num_redundant_experts if self.enable_eplb else 0
         self.n_physical_experts = self.n_logical_experts + self.n_redundant_experts
         self.n_local_physical_experts = self.n_physical_experts // self.ep_size
 
@@ -276,7 +277,6 @@ def __init__(
         cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
-        enable_eplb: bool = False,
     ) -> None:
         super().__init__()
 
@@ -285,6 +285,11 @@ def __init__(
         rope_theta = config.rope_theta
         rope_scaling = config.rope_scaling
         max_position_embeddings = config.max_position_embeddings
+
+        # Get enable_eplb from current vllm config
+        vllm_config = get_current_vllm_config()
+        enable_eplb = vllm_config.parallel_config.enable_eplb if hasattr(
+            vllm_config.parallel_config, 'enable_eplb') else False
 
         self.self_attn = Llama4Attention(
             config=config,
@@ -386,7 +391,7 @@ def make_layers(
         layer_type: type[nn.Module],
         prefix: str,
     ) -> list[nn.Module]:
-        """Override to pass enable_eplb to decoder layers."""
+        """Create decoder layers."""
         layers = []
         for layer_idx in range(num_hidden_layers):
             if isinstance(self.start_layer,
@@ -400,7 +405,6 @@ def make_layers(
                 cache_config=self.cache_config,
                 quant_config=self.quant_config,
                 prefix=f"{prefix}.{layer_idx}",
-                enable_eplb=self.enable_eplb,
             )
             layers.append(layer)
         return layers
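
For reference, a minimal standalone sketch of the expert-distribution arithmetic touched in the first hunk. The ParallelConfig dataclass and expert_counts helper below are simplified stand-ins for illustration only, not vLLM classes; the field names simply mirror the diff above.

from dataclasses import dataclass


@dataclass
class ParallelConfig:
    # Stand-in for the vLLM parallel-config fields used in the diff.
    enable_eplb: bool = False
    num_redundant_experts: int = 0


def expert_counts(num_local_experts: int,
                  parallel_config: ParallelConfig,
                  ep_size: int) -> tuple[int, int, int]:
    """Mirror the expert-count arithmetic from the first hunk."""
    n_logical = num_local_experts
    # Redundant experts only count when EPLB is enabled, matching the
    # `if self.enable_eplb else 0` guard added by this commit.
    n_redundant = (parallel_config.num_redundant_experts
                   if parallel_config.enable_eplb else 0)
    n_physical = n_logical + n_redundant
    n_local_physical = n_physical // ep_size
    return n_logical, n_physical, n_local_physical


# Example: 16 logical experts, 2 redundant experts, expert-parallel size 2.
print(expert_counts(16, ParallelConfig(True, 2), 2))   # -> (16, 18, 9)
print(expert_counts(16, ParallelConfig(False, 2), 2))  # -> (16, 16, 8)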

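The second and third hunks move enable_eplb from a decoder-layer constructor argument to a lookup on the currently active config (get_current_vllm_config() in vLLM), with a hasattr guard so configs without the field fall back to False. Below is a self-contained sketch of that pattern; get_current_config and the SimpleNamespace config are hypothetical stand-ins used only to keep the example runnable, not the real vLLM helper.

from types import SimpleNamespace

# Hypothetical stand-in for vLLM's get_current_vllm_config().
_CURRENT_CONFIG = SimpleNamespace(
    parallel_config=SimpleNamespace(enable_eplb=True))


def get_current_config() -> SimpleNamespace:
    return _CURRENT_CONFIG


class DecoderLayerSketch:
    """Toy decoder layer that reads the flag instead of receiving it."""

    def __init__(self) -> None:
        parallel_config = get_current_config().parallel_config
        # getattr with a default is equivalent to the diff's hasattr guard:
        # fall back to False when the config has no enable_eplb field.
        self.enable_eplb = getattr(parallel_config, "enable_eplb", False)


print(DecoderLayerSketch().enable_eplb)  # -> True with the config above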