@@ -79,7 +79,8 @@ def __init__(self,

         # Calculate expert distribution
         self.n_logical_experts = config.num_local_experts
-        self.n_redundant_experts = parallel_config.num_redundant_experts
+        # Only use redundant experts if EPLB is enabled
+        self.n_redundant_experts = parallel_config.num_redundant_experts if self.enable_eplb else 0
         self.n_physical_experts = self.n_logical_experts + self.n_redundant_experts
         self.n_local_physical_experts = self.n_physical_experts // self.ep_size

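A minimal sketch of the expert-count arithmetic in the hunk above, using illustrative values that are assumptions rather than numbers from this PR: with EPLB disabled the redundant-expert count collapses to zero, so the physical expert count equals the logical count.

```python
# Illustrative values only (not from this PR): 16 logical experts,
# 2 configured redundant experts, expert-parallel size 4.
n_logical_experts = 16
num_redundant_experts = 2
enable_eplb = False
ep_size = 4

# Redundant experts are counted only when EPLB is enabled.
n_redundant_experts = num_redundant_experts if enable_eplb else 0
n_physical_experts = n_logical_experts + n_redundant_experts
n_local_physical_experts = n_physical_experts // ep_size

print(n_physical_experts, n_local_physical_experts)  # -> 16 4
```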
@@ -276,7 +277,6 @@ def __init__(
         cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
-        enable_eplb: bool = False,
     ) -> None:
         super().__init__()

@@ -285,6 +285,11 @@ def __init__(
         rope_theta = config.rope_theta
         rope_scaling = config.rope_scaling
         max_position_embeddings = config.max_position_embeddings
+
+        # Get enable_eplb from current vllm config
+        vllm_config = get_current_vllm_config()
+        enable_eplb = vllm_config.parallel_config.enable_eplb if hasattr(
+            vllm_config.parallel_config, 'enable_eplb') else False

         self.self_attn = Llama4Attention(
             config=config,
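The hunk above reads `enable_eplb` from the live vLLM config instead of threading it through the layer constructor. A standalone sketch of the `hasattr`-guarded lookup is below; `SimpleNamespace` stands in for `vllm_config.parallel_config` only so the snippet runs without vLLM installed.

```python
from types import SimpleNamespace

def resolve_enable_eplb(parallel_config) -> bool:
    # Fall back to False when the config predates the enable_eplb field.
    return parallel_config.enable_eplb if hasattr(
        parallel_config, "enable_eplb") else False

# SimpleNamespace is a stand-in for vllm_config.parallel_config here.
print(resolve_enable_eplb(SimpleNamespace(enable_eplb=True)))  # True
print(resolve_enable_eplb(SimpleNamespace()))                  # False
```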
@@ -386,7 +391,7 @@ def make_layers(
        layer_type: type[nn.Module],
        prefix: str,
    ) -> list[nn.Module]:
-        """Override to pass enable_eplb to decoder layers."""
+        """Create decoder layers."""
        layers = []
        for layer_idx in range(num_hidden_layers):
            if isinstance(self.start_layer,
@@ -400,7 +405,6 @@ def make_layers(
                cache_config=self.cache_config,
                quant_config=self.quant_config,
                prefix=f"{prefix}.{layer_idx}",
-                enable_eplb=self.enable_eplb,
            )
            layers.append(layer)
        return layers