Commit 779d855

Merge branch 'llama4-eplb-clean' into enable-llama4-eplb

2 parents: 8632e83 + 979b2bc

File tree: 3 files changed (+232, -32 lines)

test_llama4_eplb.py

Lines changed: 46 additions & 0 deletions
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from vllm import LLM, SamplingParams

# Sample prompts.
prompts = [
    "Hello, my name is",
    "The president of the United States is",
    "The capital of France is",
    "The future of AI is",
]
# Create a sampling params object.
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)


def main():
    # Create an LLM with EPLB parameters.
    llm = LLM(
        model="/fp8-llama/llama4scout-fp8/",
        tensor_parallel_size=8,
        max_model_len=2048,
        enable_expert_parallel=True,
        enable_eplb=True,
        num_redundant_experts=16,
        eplb_window_size=1000,
        eplb_step_interval=3000,
        trust_remote_code=True,
        enforce_eager=True,
    )
    # Generate texts from the prompts.
    # The output is a list of RequestOutput objects
    # that contain the prompt, generated text, and other information.
    outputs = llm.generate(prompts, sampling_params)
    # Print the outputs.
    print("\nGenerated Outputs:\n" + "-" * 60)
    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text
        print(f"Prompt: {prompt!r}")
        print(f"Output: {generated_text!r}")
        print("-" * 60)


if __name__ == "__main__":
    main()
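
For scale, a quick back-of-envelope on the settings above. It assumes Llama 4 Scout's 16 routed experts per MoE layer, which is background knowledge about the model rather than anything this diff states:

# Assumes 16 routed experts per MoE layer (Llama 4 Scout); not from this diff.
num_logical_experts = 16
num_redundant_experts = 16   # matches the LLM(...) argument above
ep_size = 8                  # tensor_parallel_size with expert parallel on

num_physical_experts = num_logical_experts + num_redundant_experts  # 32
experts_per_rank = num_physical_experts // ep_size                  # 4 per GPU
print(num_physical_experts, experts_per_rank)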

vllm/model_executor/layers/fused_moe/layer.py

Lines changed: 21 additions & 5 deletions
@@ -362,8 +362,10 @@ def apply(
         logical_replica_count: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         if enable_eplb:
-            raise NotImplementedError(
-                "EPLB not supported for `UnquantizedFusedMoEMethod` yet.")
+            assert expert_load_view is not None
+            assert logical_to_physical_map is not None
+            assert logical_replica_count is not None
+            assert isinstance(layer, FusedMoE)

         return self.forward(
             x=x,
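
Instead of rejecting EPLB outright, apply() now states its contract up front: any caller that sets enable_eplb must hand over the load counters and replica maps. A rough picture of what the three tensors hold, with illustrative per-layer shapes that are assumptions here, not something this diff specifies:

import torch

# Illustrative shapes only (assumed; not read from vLLM's EPLB state).
num_logical, num_physical, max_replicas = 16, 32, 2

# Tokens routed to each physical expert, accumulated while routing.
expert_load_view = torch.zeros(num_physical, dtype=torch.int64)

# For each logical expert, the physical ids of its replicas (-1 = unused slot).
logical_to_physical_map = torch.full((num_logical, max_replicas), -1,
                                     dtype=torch.int64)

# How many replicas each logical expert currently has.
logical_replica_count = torch.ones(num_logical, dtype=torch.int64)
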
@@ -380,7 +382,12 @@ def apply(
             scoring_func=scoring_func,
             e_score_correction_bias=e_score_correction_bias,
             activation=activation,
-            apply_router_weight_on_input=apply_router_weight_on_input)
+            apply_router_weight_on_input=apply_router_weight_on_input,
+            enable_eplb=enable_eplb,
+            expert_load_view=expert_load_view,
+            logical_to_physical_map=logical_to_physical_map,
+            logical_replica_count=logical_replica_count,
+        )

     def forward_cuda(
         self,
@@ -399,6 +406,10 @@ def forward_cuda(
         e_score_correction_bias: Optional[torch.Tensor] = None,
         apply_router_weight_on_input: bool = False,
         activation: str = "silu",
+        enable_eplb: bool = False,
+        expert_load_view: Optional[torch.Tensor] = None,
+        logical_to_physical_map: Optional[torch.Tensor] = None,
+        logical_replica_count: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:

         topk_weights, topk_ids = FusedMoE.select_experts(
@@ -412,7 +423,11 @@ def forward_cuda(
             custom_routing_function=custom_routing_function,
             scoring_func=scoring_func,
             e_score_correction_bias=e_score_correction_bias,
-            indices_type=self.topk_indices_dtype)
+            indices_type=self.topk_indices_dtype,
+            enable_eplb=enable_eplb,
+            expert_load_view=expert_load_view,
+            logical_to_physical_map=logical_to_physical_map,
+            logical_replica_count=logical_replica_count)

         if self.rocm_aiter_moe_enabled:
             return self.rocm_aiter_fused_experts(
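
Threading this state into FusedMoE.select_experts is the heart of the change: expert selection is where logical expert ids get swapped for physical replicas and where load is recorded. A minimal sketch of that remapping, reusing the assumed shapes from the block above (an illustrative helper, not vLLM's implementation):

import torch

def remap_logical_to_physical(
        topk_ids: torch.Tensor,                 # [num_tokens, top_k], logical
        logical_to_physical_map: torch.Tensor,  # [num_logical, max_replicas]
        logical_replica_count: torch.Tensor,    # [num_logical], all >= 1
        expert_load_view: torch.Tensor,         # [num_physical], int64 counts
) -> torch.Tensor:
    num_tokens, _ = topk_ids.shape
    # Spread tokens over replicas; round-robin by token index is one cheap,
    # stateless policy (vLLM may well choose differently).
    token_idx = torch.arange(num_tokens, device=topk_ids.device).unsqueeze(1)
    replica_idx = token_idx % logical_replica_count[topk_ids]
    physical_ids = logical_to_physical_map[topk_ids, replica_idx]
    # Record per-physical-expert load for the next rebalancing pass.
    expert_load_view.scatter_add_(0, physical_ids.flatten(),
                                  torch.ones_like(physical_ids.flatten()))
    return physical_ids

Whatever replica policy is used, the counters in expert_load_view are what the periodic rebalancer consumes; as the parameter names in the test script suggest, it runs every eplb_step_interval steps over the trailing eplb_window_size steps of load data.
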
@@ -753,7 +768,8 @@ def __init__(
         if self.enable_eplb:
             from vllm.model_executor.layers.quantization.fp8 import (
                 Fp8MoEMethod)
-            if not isinstance(quant_method, Fp8MoEMethod):
+            if not isinstance(quant_method, Fp8MoEMethod) and not isinstance(
+                    quant_method, UnquantizedFusedMoEMethod):
                 # TODO: Add support for additional quantization methods.
                 # The implementation for other quantization methods does not
                 # contain essential differences, but the current quant API
