37 commits
fcc3406
[EPLB]Add eplb to main.
Sep 16, 2025
695de79
[EPLB]Fix ci problem.
Sep 16, 2025
134f4ec
[EPLB]Fix ci.
Sep 16, 2025
01545c1
[EPLB]Format.
Sep 16, 2025
3fd79c3
[EPLB]Fix ci.
Sep 16, 2025
5298caf
[EPLB]Fix ci.
Sep 16, 2025
9a6fc7c
[EPLB]Fix ci.
Sep 16, 2025
981bb38
[EPLB]Add eplb feature guide.
Sep 16, 2025
7409960
Merge remote-tracking branch 'upstream_gitee/main' into main_eplb_0916
Sep 16, 2025
20b3b85
[EPLB]Fix doc ci.
Sep 16, 2025
41c5ad8
[EPLB]Reformat eplb doc.
Sep 16, 2025
246eb2c
[EPLB]Add comment to eplb initialization
Sep 16, 2025
ab55a50
[EPLB]Add dynamic eplb to fused_experts method.
Sep 16, 2025
2037241
[EPLB]Add type
Sep 16, 2025
76a134f
[EPLB]Fix ci.
Sep 16, 2025
2c3bd98
[EPLB]Fix ci.
Sep 16, 2025
d215cef
[EPLB]Add type check for local_count.
Sep 16, 2025
5126bdc
[EPLB]Fix ut.
Sep 16, 2025
46ed5bd
[EPLB]Fix dense layers num.
Sep 16, 2025
345fd09
[EPLB]Check log2phy_map is not empty.
Sep 16, 2025
7235c4c
[EPLB]Fix ci.
Sep 16, 2025
531df4b
[EPLB]Fix ci.
Sep 16, 2025
64cd7bd
[EPLB]Fix hidden states.
Sep 16, 2025
c6bdc59
[EPLB]Fix ci.
Sep 16, 2025
0009b6a
[EPLB]Fix ci.
Sep 16, 2025
01eb40a
[EPLB]Add try catch when get log2phy map.
Sep 16, 2025
7e5450a
[EPLB]Fix ci.
Sep 16, 2025
a58ccbf
[EPLB]Fix ci.
Sep 16, 2025
505b317
[EPLB]Fix ci.
Sep 16, 2025
a75aeef
[EPLB]Add value check for group list type.
Sep 16, 2025
d0502ef
D2D transfer task
mercykid Sep 19, 2025
b9c0b01
Merge pull request #130 from Mercykid-bash/D2D_transfer
offline893 Sep 19, 2025
627bbdd
[BugFix]Modify eplb feature guide.
Sep 22, 2025
cce88a5
[BugFix]Fix expert map.
Sep 22, 2025
65e348f
Merge remote-tracking branch 'upstream_gitee/main' into main_eplb_0916
Sep 22, 2025
843398c
Merge remote-tracking branch 'upstream_gitee/main' into main_eplb_0916
offline0806 Sep 22, 2025
1164840
[Fix ci.]
offline0806 Sep 22, 2025
6 changes: 6 additions & 0 deletions docs/source/user_guide/configuration/additional_config.md
@@ -36,6 +36,12 @@ The following table lists the additional configuration options available in vLLM
| `lmhead_tensor_parallel_size` | int | `None` | The custom tensor parallel size of lmhead. |
| `oproj_tensor_parallel_size` | int | `None` | The custom tensor parallel size of oproj. |
| `multistream_overlap_shared_expert`| bool | `False` | Whether to enable multistream shared expert. This option only takes effects on moe models with shared experts. |
| `dynamic_eplb` | bool | `False` | Whether to enable dynamic eplb. |
| `num_iterations_eplb_update` | int | `400` | Number of forward iterations after which eplb rebalancing begins. |
| `gate_eplb` | bool | `False` | Whether to run eplb only once. |
| `num_wait_worker_iterations` | int | `30` | Number of forward iterations the eplb worker has to finish its CPU task. In our tests the default value of 30 covers most cases. |
| `expert_map_record_path` | str | `None` | When dynamic eplb completes, save the current expert load heatmap to the specified path. |
| `init_redundancy_expert` | int | `0` | Number of redundant experts to specify during initialization. |

The details of each config option are as follows:

43 changes: 23 additions & 20 deletions docs/source/user_guide/feature_guide/eplb_swift_balancer.md
@@ -1,7 +1,7 @@
# Swift Balancer

## Overview
Expert load balancing is essential when serving MoE models for LLMs, but rebalancing experts dynamically with a stop-the-world approach hurts TTFT and TPOT.
Asynchronous expert load balancing is a better choice.
We have launched SwiftBalancer to support dynamic expert load balancing with zero-overhead expert movement.

@@ -15,31 +15,34 @@ The overall workflow involves:
2. Do an all-gather of the expert distribution. All-gather is used instead of all-reduce because it produces less traffic.
3. When the configured number of iterations is reached, wake up the eplb worker process with the expert distribution and run the eplb algorithm in the worker.
4. Generate p2p send/recv ops and other operations, such as log2phy, that would otherwise cost significant CPU time.
5. Launch ibatch_send_recv on an async stream before the forward pass.
6. After the forward pass, wait for ibatch_send_recv to finish, then update the expert map and expert weights (a minimal sketch of this overlap follows below).

Our profiling shows that expert movement is hidden in the bubble between forward iterations. The CPU cost of the eplb algorithm and of other Python operations, such as log2phy,
is likewise hidden inside the eplb worker process.
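
The overlap described in steps 5 and 6 can be pictured with the minimal PyTorch sketch below. It is an illustrative outline, not the vllm-ascend implementation: it assumes `send_ops`/`recv_ops` are prebuilt lists of `torch.distributed.P2POp` and that a dedicated side stream is available (`torch.cuda.Stream` here; `torch.npu.Stream` on Ascend).

```python
import torch
import torch.distributed as dist


def forward_with_async_expert_transfer(model, batch, send_ops, recv_ops, comm_stream):
    # Launch the batched P2P expert-weight transfer on a side stream before forward.
    with torch.cuda.stream(comm_stream):
        reqs = dist.batch_isend_irecv(send_ops + recv_ops)

    # The forward pass runs on the default stream and hides the transfer cost.
    output = model(batch)

    # After forward, wait for the transfer to finish, then the new expert map and
    # weights can be swapped in (update_expert_map_and_weights is a hypothetical hook).
    for req in reqs:
        req.wait()
    torch.cuda.current_stream().wait_stream(comm_stream)
    # update_expert_map_and_weights(...)
    return output
```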

## Config Params

Currently Swift Balancer improves TPOT by about 5 ms at EP size 64, while each layer's expert movement costs less than 2 ms.

We add new parameters for eplb:

- `"dynamic_eplb": true` --- enable dynamic eplb.
- `"num_iterations_eplb_update": 400` --- number of forward iterations after which eplb begins.
- `"gate_eplb": true` --- run eplb only once; `false` by default.
- `"num_wait_worker_iterations": 30` --- number of forward iterations the eplb worker has to finish its CPU task. In our tests the default value of 30 covers most cases.
- `"expert_map_record_path"` --- when dynamic eplb completes, save the current expert load heatmap to the specified path.
- `"init_redundancy_expert"` --- number of redundant experts to specify during initialization.

## Examples
### Dynamic eplb
Enable dynamic eplb and specify the trigger rounds.
```shell
vllm serve Qwen/Qwen3-235B-A22 \
--tensor-parallel-size 16 \
--enable-expert-parallel \
--additional-config '{ "dynamic_eplb":true,"num_iterations_eplb_update":400, "gate_eplb":true, "num_wait_worker_iterations":30}'
```
### Static eplb
1. Specify the path for the static eplb initialization file.
```shell
vllm serve Qwen/Qwen3-235B-A22 \
--tensor-parallel-size 16 \
--enable-expert-parallel \
--additional-config '{ "expert_map_record_path": "/path/to/eplb.json", "init_redundancy_expert": 16, dynamic_eplb":true,"num_iterations_eplb_update":400, "gate_eplb":true, "num_wait_worker_iterations":30}'
```

Contributor comment (high): The JSON string in this example is invalid. There is a missing double quote `"` before `dynamic_eplb`. This will cause a parsing error for users who copy and paste this command.

Suggested change:

--additional-config '{ "expert_map_record_path": "/path/to/eplb.json", "init_redundancy_expert": 16, "dynamic_eplb":true,"num_iterations_eplb_update":400, "gate_eplb":true, "num_wait_worker_iterations":30}'
2. If expert map has been recorded, enable static eplb with expert map path.
```shell
vllm serve Qwen/Qwen3-235B-A22 \
--tensor-parallel-size 16 \
--enable-expert-parallel \
--additional-config '{ "expert_map_path": "/path/to/eplb.json"}'
```
86 changes: 51 additions & 35 deletions vllm_ascend/eplb/core/eplb_utils.py
@@ -15,10 +15,7 @@
# This file is a part of the vllm-ascend project.
#
# Todo: Once https://github.com/vllm-project/vllm/issues/22246 is resolved in vllm, remove these eplb utils.

import torch


def determine_default_expert_map(global_expert_num, world_size, rank_id,
@@ -56,40 +53,59 @@ def determine_default_expert_map(global_expert_num, world_size, rank_id,


def generate_log2phy_map(expert_map):
    """
    Generate a log-to-physical map for experts in a fully vectorized manner.
    Args:
        expert_map: Tensor of shape [num_ranks, num_global_expert], with -1 indicating
            rank does not hold the expert.
    Returns:
        log2phy_map: Tensor of same shape, mapping logical experts to physical IDs.
    """
    num_ranks, _ = expert_map.shape
    num_local_experts = expert_map.max() + 1
    device = expert_map.device

    # Step 1: linear mapping based on rank
    log2phy_map = expert_map.clone()
    row_indices = torch.arange(num_ranks, device=device).view(
        -1, 1) * num_local_experts
    mask = log2phy_map != -1
    # broadcast addition; -1 entries stay -1
    log2phy_map = log2phy_map + row_indices * mask.long()

    # Step 2: find positive/negative positions
    positive_mask = log2phy_map != -1
    negative_mask = ~positive_mask

    # Count the number of ranks holding each global expert
    num_positive_per_col = positive_mask.sum(dim=0)  # [num_global_expert]

    # Step 3: handle columns with only one rank holding the expert
    single_pos_mask = num_positive_per_col == 1
    if single_pos_mask.any():
        # broadcast the single positive entry to fill the negative positions
        for col_idx in torch.nonzero(single_pos_mask, as_tuple=True)[0]:
            pos_row = torch.nonzero(positive_mask[:, col_idx])[0]
            neg_rows = torch.nonzero(negative_mask[:, col_idx])[:, 0]
            log2phy_map[neg_rows, col_idx] = log2phy_map[pos_row, col_idx]

    # Step 4: handle columns with multiple ranks holding the expert
    multi_pos_mask = num_positive_per_col > 1
    if multi_pos_mask.any():
        for col_idx in torch.nonzero(multi_pos_mask, as_tuple=True)[0]:
            pos_rows = torch.nonzero(positive_mask[:, col_idx])[:, 0]
            neg_rows = torch.nonzero(negative_mask[:, col_idx])[:, 0]
            if len(neg_rows) > 0:
                # random assignment from the available positive ranks
                rand_idx = torch.randint(0,
                                         len(pos_rows), (len(neg_rows), ),
                                         device=device)
                log2phy_map[neg_rows,
                            col_idx] = log2phy_map[pos_rows[rand_idx], col_idx]

    return log2phy_map
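
For a quick illustration of the function above, here is a hypothetical two-rank, four-expert layout (the tensor values are made up for this example):

```python
import torch

# -1 means "this rank does not hold the expert".
# Rank 0 holds logical experts 0 and 1; rank 1 holds logical experts 2 and 3.
expert_map = torch.tensor([
    [0, 1, -1, -1],
    [-1, -1, 0, 1],
])

log2phy = generate_log2phy_map(expert_map)
# Held entries become globally unique physical ids (rank * num_local_experts + local id),
# and the -1 holes are filled from a rank that does hold the expert, so both rows
# end up as [0, 1, 2, 3] for this layout.
print(log2phy)
```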

1 change: 1 addition & 0 deletions vllm_ascend/eplb/core/eplb_worker.py
@@ -243,6 +243,7 @@ def compose_expert_update_info_greedy(self, updated_expert_maps,
yield (expert_send_info_this_layer,
expert_recv_info_this_layer,
updated_expert_maps_this_layer, layer_id)
continue

# Parse expert_ids each rank needs to receive from other ranks
dst_rank_indices, experts_to_recv = torch.where((current_expert_maps_this_layer == -1) \
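
The truncated `torch.where` expression above computes, for one layer, which experts each rank still needs to receive. A minimal sketch of that idea, with illustrative names rather than the exact implementation:

```python
import torch


def experts_to_receive(current_map: torch.Tensor, updated_map: torch.Tensor):
    """Find (rank, expert) pairs that a rank must newly receive.

    A rank must receive an expert that it will hold after the update
    (updated_map != -1) but does not hold now (current_map == -1).
    Both maps have shape [num_ranks, num_global_experts].
    """
    dst_rank_indices, experts_to_recv = torch.where((current_map == -1)
                                                    & (updated_map != -1))
    return dst_rank_indices, experts_to_recv
```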
3 changes: 2 additions & 1 deletion vllm_ascend/ops/common_fused_moe.py
@@ -346,7 +346,8 @@ def forward_impl(self, hidden_states: torch.Tensor,
expert_load_view=self.expert_load_view,
logical_to_physical_map=self.logical_to_physical_map,
logical_replica_count=self.logical_replica_count,
log2phy=self.log2phy,
global_redundant_expert_num=self.global_redundant_expert_num)
if isinstance(final_hidden_states, tuple):
final_hidden_states, group_list_type, expert_tokens = final_hidden_states

4 changes: 3 additions & 1 deletion vllm_ascend/quantization/w8a8_dynamic.py
@@ -230,7 +230,9 @@ def apply(
w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
expert_map=expert_map,
dynamic_eplb=self.dynamic_eplb,
log2phy=log2phy,
global_redundant_expert_num=global_redundant_expert_num)

# this is a naive implementation for experts load balance so as
# to avoid accumulating too much tokens on a single rank.
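
The `log2phy` argument threaded into `fused_experts` here maps the router's logical expert ids to physical (possibly redundant) expert ids. A hedged sketch of how such a table could be applied to the router output; the helper name and shapes are assumptions, not the vllm-ascend API:

```python
import torch


def remap_topk_ids(topk_ids: torch.Tensor, log2phy: torch.Tensor) -> torch.Tensor:
    # topk_ids: [num_tokens, top_k] logical expert ids chosen by the router.
    # log2phy:  [num_logical_experts] physical expert id selected for this rank.
    # Indexing replaces each logical id with its physical counterpart.
    return log2phy[topk_ids]
```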