
Commit 8aadcb7

shiyuan680 and yangcheng authored
[0.9.1]eplb support qwen3-moe (vllm-project#2000)
### What this PR does / why we need it?
This PR adds EPLB support for Qwen3-MoE. Testing Qwen3-MoE-235B (W8A8, tp1dp16ep16) shows a roughly 3% performance improvement. Signed-off-by: yangcheng <yangcheng104@huawei.com> Co-authored-by: yangcheng <yangcheng104@huawei.com>
1 parent 2f1dbe5 commit 8aadcb7

File tree

7 files changed: +159 -68 lines

examples/eplb_generate_map.py

Lines changed: 26 additions & 3 deletions
```diff
@@ -1,9 +1,33 @@
 import argparse
 import json
+import random
 
 import numpy as np
 
 
+def add_unique_number_with_retry(existing_numbers,
+                                 valid_range,
+                                 max_attempts=100):
+    '''
+    generate an unique number not in existing_numbers
+    Args:
+        existing_numbers:
+        valid_range:
+        max_attempts:
+
+    Returns:
+
+    '''
+    existing_set = set(existing_numbers)
+    min_val, max_val = valid_range
+    for _ in range(max_attempts):
+        candidate = random.randint(min_val, max_val)
+        if candidate not in existing_set:
+            return candidate
+
+    raise ValueError('No unique number found')
+
+
 def split_and_insert(n, k, m):
     '''
     n: expert num
@@ -13,13 +37,12 @@ def split_and_insert(n, k, m):
 
     A = np.arange(n)
 
-    B = np.random.choice(n, size=m, replace=False)
-
     groups = np.array_split(A, k)
 
     for j in range(m // k):
         for i in range(k):
-            groups[i] = np.append(groups[i], B[i + j * k])
+            candidate = add_unique_number_with_retry(groups[i], (0, n - 1))
+            groups[i] = np.append(groups[i], candidate)
     return np.concatenate(groups)
 
 
```
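The behavioral change is subtle: the old code drew the m redundant experts without replacement across the whole map, while the new helper only retries until the candidate is unique within the group it is appended to. A minimal, self-contained sketch of the new placement step (the values of n, k and m below are illustrative, not taken from the commit):

```python
import random

import numpy as np


def add_unique_number_with_retry(existing_numbers, valid_range, max_attempts=100):
    # Retry until we draw an expert id the group does not already hold.
    existing_set = set(existing_numbers)
    min_val, max_val = valid_range
    for _ in range(max_attempts):
        candidate = random.randint(min_val, max_val)
        if candidate not in existing_set:
            return candidate
    raise ValueError('No unique number found')


n, k, m = 16, 4, 4  # illustrative: 16 experts, 4 cards, 4 redundant slots
groups = np.array_split(np.arange(n), k)
for _ in range(m // k):
    for i in range(k):
        candidate = add_unique_number_with_retry(groups[i], (0, n - 1))
        groups[i] = np.append(groups[i], candidate)
print(np.concatenate(groups))  # 20 physical slots: each card's experts plus its redundant pick
```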

tests/singlecard/ops/test_fused_moe.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -98,7 +98,7 @@ def test_fused_experts(
     output = fused_experts(a, w1, w2, topk_weights, topk_ids, topk, e_map)
     torch_output = torch_moe(a, w1, w2, topk_weights, topk_ids, topk, e_map)
     # TODO: The native params are: atol=2e-2, rtol=0, maybe related to the nan problem
-    torch.testing.assert_close(output, torch_output, atol=4e-2, rtol=1)
+    torch.testing.assert_close(output[0], torch_output, atol=4e-2, rtol=1)
     torch.npu.empty_cache()
 
 
```
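The one-character test change follows from the new return signature of `fused_experts` in `vllm_ascend/ops/fused_moe.py` below: it now returns a tuple rather than a bare tensor, so the assertion compares only the hidden states. A hedged sketch of the equivalent unpacking inside `test_fused_experts` (the tensor names come from the test body above):

```python
# Hypothetical rewrite of the assertion, assuming the tuple return
# (final_hidden_states, expert_tokens, group_list_type) introduced below.
final_hidden_states, expert_tokens, group_list_type = fused_experts(
    a, w1, w2, topk_weights, topk_ids, topk, e_map)
torch.testing.assert_close(final_hidden_states, torch_output, atol=4e-2, rtol=1)
```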

vllm_ascend/eplb/adaptor/vllm_adaptor.py

Lines changed: 20 additions & 9 deletions
```diff
@@ -32,15 +32,22 @@ def __init__(self, model, **args):
         self.rank_id = dist.get_rank()
         self.world_size = dist.get_world_size()
         self.param_dict = dict(self.model.named_parameters())
-        self.num_dense_layers = self.model.config.first_k_dense_replace
+        if self.model.config.model_type == "qwen3_moe":
+            self.num_dense_layers = 0
+            self.global_expert_num = self.model.config.num_experts
+        else:
+            self.num_dense_layers = self.model.config.first_k_dense_replace
+            self.global_expert_num = self.model.config.n_routed_experts
         self.num_moe_layers = self.model.config.num_hidden_layers - self.num_dense_layers
-        self.global_expert_num = self.model.config.n_routed_experts
 
-        # TODO: init self.expert_weight_names depending on different model types, only deepseek v3 w8a8 is supported here
-        self.expert_weight_names = [
-            "w13_weight", "w2_weight", "w13_weight_scale", "w13_weight_offset",
-            "w2_weight_scale", "w2_weight_offset"
-        ]
+        # TODO: init self.expert_weight_names depending on different model types, only deepseek v3 w8a8 and qwen3-moe is supported here
+        if self.model.quant_config is not None:
+            self.expert_weight_names = [
+                "w13_weight", "w2_weight", "w13_weight_scale",
+                "w13_weight_offset", "w2_weight_scale", "w2_weight_offset"
+            ]
+        else:
+            self.expert_weight_names = ["w13_weight", "w2_weight"]
 
         self.expert_map_per_layer = dict(
         )  # reference to expert map on device for expert map update
@@ -127,8 +134,12 @@ def get_init_expert_map_from_file(self, num_moe_layers, expert_map_path):
         expert_map_all = self.determine_expert_map_all()
 
         for layer_idx in range(num_moe_layers):
-            self.expert_map_per_layer_cpu[layer_idx + 3] = \
-                expert_map_all[layer_idx][self.rank_id]
+            if self.model.config.model_type == "qwen3_moe":
+                self.expert_map_per_layer_cpu[layer_idx] = \
+                    expert_map_all[layer_idx][self.rank_id]
+            else:
+                self.expert_map_per_layer_cpu[layer_idx + 3] = \
+                    expert_map_all[layer_idx][self.rank_id]
         return expert_map_all
 
     def _expert_file_to_tensor(self, expert_map_path: str):
```
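The adaptor now derives its layer and expert counts from the model type: Qwen3-MoE has no leading dense layers and exposes `num_experts`, while DeepSeek-style configs use `first_k_dense_replace` and `n_routed_experts`. A self-contained sketch of that resolution logic (the config values are illustrative stand-ins, roughly in the spirit of Qwen3-235B-A22B and DeepSeek-V3):

```python
from types import SimpleNamespace


def resolve_moe_shape(config):
    # Mirrors the branching added in VllmEplbAdaptor.__init__ above.
    if config.model_type == "qwen3_moe":
        num_dense_layers = 0
        global_expert_num = config.num_experts
    else:
        num_dense_layers = config.first_k_dense_replace
        global_expert_num = config.n_routed_experts
    num_moe_layers = config.num_hidden_layers - num_dense_layers
    return num_dense_layers, num_moe_layers, global_expert_num


qwen3 = SimpleNamespace(model_type="qwen3_moe", num_hidden_layers=94, num_experts=128)
dsv3 = SimpleNamespace(model_type="deepseek_v3", num_hidden_layers=61,
                       first_k_dense_replace=3, n_routed_experts=256)
print(resolve_moe_shape(qwen3))  # (0, 94, 128)
print(resolve_moe_shape(dsv3))   # (3, 58, 256)
```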

vllm_ascend/eplb/utils.py

Lines changed: 68 additions & 0 deletions
```diff
@@ -0,0 +1,68 @@
+#
+# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# This file is a part of the vllm-ascend project.
+#
+import types
+
+import torch
+
+
+def get_expert_map(self, layer_id):
+    return self.model.layers[layer_id].mlp.experts.get_map()
+
+
+def get_log2phy_map(self, layer_id):
+    return self.model.layers[layer_id].mlp.experts.get_log2phy_map()
+
+
+def get_all_expert_map(self, num_moe_layers):
+    all_loads = []
+    for layer_id in range(num_moe_layers):
+        load_tensor = self.get_expert_map(layer_id)  # (num_experts_per_layer,)
+        all_loads.append(load_tensor)
+
+    return torch.stack(all_loads, dim=0)
+
+
+def get_all_moe_loads(self):
+    all_moe_loads = torch.stack(
+        [self.model.layers[layer_id].mlp.experts.moe_load \
+            for layer_id in range(self.num_moe_layers)],
+        dim=0
+    )
+    return all_moe_loads
+
+
+def clear_all_moe_loads(self):
+    for layer_id in range(self.num_moe_layers):
+        self.model.layers[layer_id].mlp.experts.clear_moe_load()
+
+
+def model_register(model, model_config):
+    model.get_expert_map = types.MethodType(get_expert_map, model)
+    model.get_log2phy_map = types.MethodType(get_log2phy_map, model)
+    model.get_all_expert_map = types.MethodType(get_all_expert_map, model)
+    model.get_all_moe_loads = types.MethodType(get_all_moe_loads, model)
+    model.clear_all_moe_loads = types.MethodType(clear_all_moe_loads, model)
+
+    config = model_config.hf_config
+
+    if config.model_type == "qwen3_moe":
+        model.num_moe_layers = config.num_hidden_layers
+    elif config.model_type == "deepseek_v2":
+        num_dense_layers = config.first_k_dense_replace
+        model.num_moe_layers = config.num_hidden_layers - num_dense_layers
+    else:
+        raise NotImplementedError("EPLB is not supported.")
```
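`model_register` attaches these helpers to the loaded model with `types.MethodType`, so DeepSeek and Qwen3-MoE expose the same interface to the EPLB machinery without model-specific methods (which is why the DeepSeek-specific copies are deleted below). A self-contained sketch of the binding and of what `get_all_moe_loads` returns, using a mock layer tree in place of the real vLLM model:

```python
from types import MethodType, SimpleNamespace

import torch


def get_all_moe_loads(self):
    # Same stacking as in vllm_ascend/eplb/utils.py above.
    return torch.stack([
        self.model.layers[layer_id].mlp.experts.moe_load
        for layer_id in range(self.num_moe_layers)
    ], dim=0)


def make_layer(num_experts):
    experts = SimpleNamespace(moe_load=torch.zeros(num_experts, dtype=torch.int64))
    return SimpleNamespace(mlp=SimpleNamespace(experts=experts))


# Mock model with 4 MoE layers of 8 local experts each.
mock = SimpleNamespace(model=SimpleNamespace(layers=[make_layer(8) for _ in range(4)]),
                       num_moe_layers=4)
mock.get_all_moe_loads = MethodType(get_all_moe_loads, mock)  # as model_register does
print(mock.get_all_moe_loads().shape)  # torch.Size([4, 8])
```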

vllm_ascend/models/deepseek_v2.py

Lines changed: 0 additions & 30 deletions
```diff
@@ -830,8 +830,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         quant_config = vllm_config.quant_config
         self.config = config
         self.quant_config = quant_config
-        self.num_dense_layers = self.config.first_k_dense_replace
-        self.num_moe_layers = self.config.num_hidden_layers - self.num_dense_layers
         self.model = CustomDeepseekV2Model(vllm_config=vllm_config,
                                            prefix=maybe_prefix(
                                                prefix, "model"))
@@ -870,34 +868,6 @@ def load_weights(self, weights: Iterable[tuple[str,
 
         return loaded_params
 
-    def get_expert_map(self, layer_id):
-        return self.model.layers[layer_id].mlp.experts.get_map()
-
-    def get_log2phy_map(self, layer_id):
-        return self.model.layers[layer_id].mlp.experts.get_log2phy_map()
-
-    def get_all_expert_map(self, num_moe_layers):
-        all_loads = []
-        for layer_id in range(num_moe_layers):
-            load_tensor = self.get_expert_map(
-                layer_id + self.num_dense_layers)  # (num_experts_per_layer,)
-            all_loads.append(load_tensor)
-
-        return torch.stack(all_loads, dim=0)
-
-    def get_all_moe_loads(self):
-        all_moe_loads = torch.stack(
-            [self.model.layers[layer_id + self.num_dense_layers].mlp.experts.moe_load \
-                for layer_id in range(self.num_moe_layers)],
-            dim=0
-        )
-        return all_moe_loads
-
-    def clear_all_moe_loads(self):
-        for layer_id in range(self.num_moe_layers):
-            self.model.layers[
-                layer_id + self.num_dense_layers].mlp.experts.clear_moe_load()
-
 
 class CustomDeepseekV3ForCausalLM(CustomDeepseekV2ForCausalLM):
     pass
```

vllm_ascend/ops/fused_moe.py

Lines changed: 41 additions & 25 deletions
```diff
@@ -141,7 +141,11 @@ def fused_experts_with_mc2(
     is_torchair: bool = False,
     hidden_states_for_share: Optional[Any] = None,
     mc2_mask: Optional[torch.Tensor] = None,
+    log2phy: Optional[torch.Tensor] = None,
+    global_redundant_expert_num: int = 0
 ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+    if log2phy is not None:
+        topk_ids = log2phy[topk_ids]
     quant_mode = 0
     ep_group = get_mc2_group()
     ep_rank_id = ep_group.rank_in_group
@@ -163,7 +167,7 @@ def fused_experts_with_mc2(
 
     enable_dispatch_v2 = hasattr(torch_npu, "npu_moe_distribute_dispatch_v2")
 
-    moe_expert_num = len(expert_map)
+    moe_expert_num = len(expert_map) + global_redundant_expert_num
     kwargs_mc2 = {
         "x": hidden_states,
         "expert_ids": topk_ids,
@@ -349,17 +353,16 @@ def apply_mlp(
 
 # currently expert parallelism implemented with all2all
 # is under-optimized.
-def fused_experts_with_all2all(
-    hidden_states: torch.Tensor,
-    w1: torch.Tensor,
-    w2: torch.Tensor,
-    topk_weights: torch.Tensor,
-    topk_ids: torch.Tensor,
-    top_k: int,
-    expert_map: torch.Tensor = None,
-    ep_group: GroupCoordinator = None,
-    max_num_tokens: Optional[int] = None,
-):
+def fused_experts_with_all2all(hidden_states: torch.Tensor,
+                               w1: torch.Tensor,
+                               w2: torch.Tensor,
+                               topk_weights: torch.Tensor,
+                               topk_ids: torch.Tensor,
+                               top_k: int,
+                               expert_map: torch.Tensor = None,
+                               ep_group: GroupCoordinator = None,
+                               max_num_tokens: Optional[int] = None,
+                               global_redundant_expert_num: int = 0):
     original_shape = hidden_states.shape
     if len(original_shape) == 3:
         hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
@@ -369,7 +372,7 @@ def fused_experts_with_all2all(
     device = hidden_states.device
 
     if expert_map is not None:
-        global_num_experts = len(expert_map)
+        global_num_experts = len(expert_map) + global_redundant_expert_num
         local_num_experts = global_num_experts // ep_group.world_size
         row_idx_len = num_tokens * top_k
         row_idx = (torch.arange(0,
@@ -639,7 +642,10 @@ def fused_experts_with_all2allv(
     hidden_states: torch.Tensor,
     w1: torch.Tensor,
     w2: torch.Tensor,
+    log2phy: Optional[torch.Tensor] = None,
 ):
+    if log2phy is not None:
+        routing_map = log2phy[routing_map]
     # Enable moe alltoallv, it's a balanced policy for precision and efficiency.
     (share_experts_output, dispatched_input,
      tokens_per_expert) = (token_dispatcher.token_permutation(
@@ -824,8 +830,8 @@ def fused_experts(
         expanded_src_to_dst_row=expanded_row_idx,
         export_for_source_row=topk_ids,
     )
-
-    return final_hidden_states
+    group_list_type = 0
+    return final_hidden_states, expert_tokens, group_list_type
 
 
 def native_grouped_topk(
@@ -1015,6 +1021,8 @@ def apply(
         enable_force_load_balance: bool = False,
         hidden_states_for_share: Optional[Any] = None,
         shared_experts: Optional[Any] = None,
+        log2phy: Optional[Any] = None,
+        global_redundant_expert_num: int = 0,
         **kwargs,
     ) -> torch.Tensor:
 
@@ -1071,6 +1079,8 @@ def apply(
                 is_torchair=self.torchair_graph_enabled,
                 hidden_states_for_share=hidden_states_for_share,
                 mc2_mask=mc2_mask,
+                log2phy=log2phy,
+                global_redundant_expert_num=global_redundant_expert_num,
             )
         elif fused_moe_state == FusedMoEState.AllGather:
             max_num_tokens = self.max_num_batched_tokens if self.use_aclgraph else None
@@ -1105,18 +1115,20 @@ def apply(
                 hidden_states=x,
                 w1=layer.w13_weight,
                 w2=layer.w2_weight,
-            )
+                log2phy=log2phy)
         else:
             max_num_tokens = self.max_num_batched_tokens if self.use_aclgraph else None
-            return fused_experts_with_all2all(hidden_states=x,
-                                              w1=layer.w13_weight,
-                                              w2=layer.w2_weight,
-                                              topk_weights=topk_weights,
-                                              topk_ids=topk_ids,
-                                              top_k=top_k,
-                                              expert_map=expert_map,
-                                              ep_group=get_ep_group(),
-                                              max_num_tokens=max_num_tokens)
+            return fused_experts_with_all2all(
+                hidden_states=x,
+                w1=layer.w13_weight,
+                w2=layer.w2_weight,
+                topk_weights=topk_weights,
+                topk_ids=topk_ids,
+                top_k=top_k,
+                expert_map=expert_map,
+                ep_group=get_ep_group(),
+                max_num_tokens=max_num_tokens,
+                global_redundant_expert_num=global_redundant_expert_num)
 
 
 class AscendFusedMoE(FusedMoE):
@@ -1273,6 +1285,10 @@ def __init__(
         if envs_ascend.VLLM_ASCEND_ENABLE_MOE_ALL2ALL_SEQ and isinstance(
                 self.quant_method, AscendUnquantizedFusedMoEMethod):
             self.reduce_results = False
+        if expert_map_path and os.path.exists(expert_map_path):
+            self.global_num_experts = self.global_num_experts + self.global_redundant_expert_num
+            self.local_num_experts = self.global_num_experts // self.ep_size
+
         moe_dispatcher_config = (
             MoEDispatcherConfig().set_num_moe_experts(
                 self.global_num_experts).set_num_local_experts(
```
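The core of the redundancy support is the logical-to-physical remap performed before dispatch (`topk_ids = log2phy[topk_ids]` in `fused_experts_with_mc2`, and the analogous `routing_map` remap in `fused_experts_with_all2allv`), plus counting the redundant replicas in `moe_expert_num` / `global_num_experts`. A minimal sketch with an illustrative table (the real `log2phy` tensor comes from the EPLB expert map, not from this example):

```python
import torch

# 6 logical experts; logical experts 1 and 4 also have physical replicas,
# so this rank's table points them at physical slots 6 and 7.
log2phy = torch.tensor([0, 6, 2, 3, 7, 5])
global_redundant_expert_num = 2

topk_ids = torch.tensor([[1, 4], [0, 1]])   # per-token logical expert choices
physical_ids = log2phy[topk_ids]            # ids actually used for dispatch
print(physical_ids)                         # tensor([[6, 7], [0, 6]])

# The dispatch kernels must size their expert dimension to include the replicas,
# which is what `len(expert_map) + global_redundant_expert_num` accounts for.
moe_expert_num = 6 + global_redundant_expert_num
print(moe_expert_num)                       # 8
```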

vllm_ascend/worker/model_runner_v1.py

Lines changed: 3 additions & 0 deletions
```diff
@@ -81,6 +81,7 @@
 from vllm_ascend.distributed.utils import is_lmhead_tp
 from vllm_ascend.eplb.adaptor.vllm_adaptor import VllmEplbAdaptor
 from vllm_ascend.eplb.eplb_updator import EplbUpdator
+from vllm_ascend.eplb.utils import model_register
 from vllm_ascend.multistream.ms_split import compute_split_seq_index
 from vllm_ascend.platform import NPUPlatform
 from vllm_ascend.sample.rejection_sampler import AscendRejectionSampler
@@ -1897,6 +1898,8 @@ def load_model(self) -> None:
 
         with DeviceMemoryProfiler() as m:  # noqa: SIM117
             self.model = get_model(vllm_config=self.vllm_config)
+            if self.dynamic_eplb:
+                model_register(self.model, self.model_config)
             if hasattr(self, "drafter"):
                 logger.info("Loading drafter model...")
                 self.drafter.load_model()
```
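For orientation, the hook runs right after the model is built and only when dynamic EPLB is enabled, so the `VllmEplbAdaptor`/`EplbUpdator` created later in `load_model` can rely on the bound helpers. A hedged sketch of the control flow; `runner` and `build_model` are hypothetical stand-ins for the real `NPUModelRunner` and `get_model(vllm_config=...)`:

```python
from vllm_ascend.eplb.utils import model_register  # the import added above


def attach_eplb_helpers(runner, build_model):
    # Mirrors the two lines added to load_model(): register the EPLB helpers
    # only when dynamic EPLB is switched on for this runner.
    runner.model = build_model(runner.vllm_config)
    if runner.dynamic_eplb:
        model_register(runner.model, runner.model_config)
    return runner.model
```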
