
Commit a3c69a9

fix: fix minor bugs
Removes unnecessary weight transpose operations within the fused MoE expert function to improve performance.

Refactors how quantization flags are passed for MoE communication primitives.

Skips a W8A8 MoE test, as the required All-Gather communication operation does not yet support this quantization mode.

Signed-off-by: Yizhou Liu <liu_yizhou@outlook.com>
1 parent de9d711 · commit a3c69a9

File tree: 4 files changed, +19 -17 lines

tests/e2e/multicard/test_qwen3_moe.py
Lines changed: 2 additions & 1 deletion

@@ -23,6 +23,7 @@

 import os

+import pytest
 from modelscope import snapshot_download  # type: ignore

 from tests.e2e.conftest import VllmRunner
@@ -107,4 +108,4 @@ def test_models_distributed_Qwen3_MOE_TP2_WITH_ACLGRAPH():
             tensor_parallel_size=2,
             enforce_eager=False,
     ) as vllm_model:
-        vllm_model.generate_greedy(example_prompts, max_tokens)
+        vllm_model.generate_greedy(example_prompts, max_tokens)
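Note: the only visible additions in this test file are the `import pytest` line and a re-emitted final line, so the W8A8 skip mentioned in the commit message presumably lives in a hunk not shown here. As a hedged sketch, such a skip would typically look like the following; the test name, model id, and extra kwargs are illustrative and not taken from this commit:

import pytest

from tests.e2e.conftest import VllmRunner


@pytest.mark.skip(
    reason="All-Gather MoE communication does not support W8A8 quantization yet")
def test_models_distributed_Qwen3_MOE_W8A8():
    # Hypothetical W8A8 variant of the Qwen3 MoE test above; kept skipped
    # until the All-Gather path supports 8-bit activation quantization.
    example_prompts = ["Hello, my name is"]
    max_tokens = 5
    with VllmRunner("Qwen3-30B-A3B-W8A8",  # placeholder model id
                    tensor_parallel_size=2,
                    enforce_eager=True) as vllm_model:
        vllm_model.generate_greedy(example_prompts, max_tokens)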

vllm_ascend/distributed/moe_comm_method.py
Lines changed: 16 additions & 11 deletions

@@ -54,7 +54,7 @@ def permute(
         topk_weights: torch.Tensor,
         expert_map: torch.Tensor,
         num_experts: int,
-        use_a8: bool,
+        apply_a8_quantization: bool,
     ) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], int]:
         """Pre-process before MLP.

@@ -65,6 +65,7 @@ def permute(
             expert_map (torch.Tensor): Tensor of shape (global_num_experts, )
                 Mapping from global expert IDs to local expert IDs.
             num_experts (int): Number of local experts (experts on this device).
+            apply_a8_quantization (bool): Whether to apply A8 quantization (W4A8 and W8A8).

         Returns:
             tuple[torch.Tensor, torch.Tensor, int]: Return a tuple containing:
@@ -73,6 +74,8 @@ def permute(
                 hidden_states based on topk_ids.
                 - expert_tokens (torch.Tensor): Tensor of shape (num_experts, )
                 Number of tokens assigned to each expert.
+                - dynamic_scale (torch.Tensor, optional): Tensor of shape (num_experts, )
+                Dynamic scale for each expert, used for quantization.
                 - group_list_type (int): Type of group list, 0 for `cumsum`
                 and 1 for `count`. This is mainly for `npu_grouped_matmul`
                 to determine how to handle the output.
@@ -160,7 +163,7 @@ def permute(
         topk_weights: torch.Tensor,
         expert_map: torch.Tensor,  # noqa: F841
         num_experts: int,
-        use_a8: bool,
+        apply_a8_quantization: bool,
     ) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], int]:
         num_tokens = hidden_states.shape[0]

@@ -221,7 +224,7 @@ def permute(
         topk_weights: torch.Tensor,
         expert_map: torch.Tensor,
         num_experts: int,
-        use_a8: bool,
+        apply_a8_quantization: bool,
     ) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], int]:
         num_tokens = hidden_states.shape[0]

@@ -378,7 +381,7 @@ def permute(
         topk_weights: torch.Tensor,
         expert_map: torch.Tensor,
         num_experts: int,
-        use_a8: bool,
+        apply_a8_quantization: bool,
     ) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], int]:
         # Store tensors needed for post_process
         self.topk_ids = topk_ids
@@ -392,7 +395,7 @@ def permute(
             "moe_expert_num": self.moe_config.num_experts,
             "global_bs": 0,
             "scales": None,
-            "quant_mode": 2 if use_a8 else 0,
+            "quant_mode": 2 if apply_a8_quantization else 0,
             "group_ep": self.mc2_comm_name,
             "ep_world_size": self.moe_config.ep_size,
             "ep_rank_id": self.moe_config.ep_rank,
@@ -536,13 +539,15 @@ def permute(
         topk_weights: torch.Tensor,
         expert_map: torch.Tensor,
         num_experts: int,
-        use_a8: bool,
+        apply_a8_quantization: bool,
     ) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], int]:
-        results = self.token_dispatcher.token_dispatch(hidden_states,
-                                                       topk_weights,
-                                                       topk_ids,
-                                                       None,
-                                                       log2phy=None)
+        results = self.token_dispatcher.token_dispatch(
+            hidden_states,
+            topk_weights,
+            topk_ids,
+            None,
+            log2phy=None,
+            with_quant=apply_a8_quantization)
         return results["hidden_states"], results["group_list"], results[
             "dynamic_scale"], results["group_list_type"]
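The net effect of this file is that 8-bit activation quantization is now signalled per call: every `permute` implementation takes `apply_a8_quantization`, the MC2 path turns it into `quant_mode` (2 for A8, 0 for unquantized), and the all-to-all path forwards it as `with_quant`. A rough sketch of how a caller might derive and pass the flag; the wrapper name and the argument order ahead of `topk_weights` are assumptions, not part of the diff:

def dispatch_tokens(comm_method, hidden_states, topk_ids, topk_weights,
                    expert_map, num_local_experts, quant_config):
    """Hypothetical caller illustrating the per-call quantization flag."""
    # W4A8 / W8A8 schemes quantize activations to 8 bits, so dispatch must
    # emit int8 tokens plus per-token dynamic scales.
    apply_a8_quantization = quant_config is not None

    permuted_states, expert_tokens, dynamic_scale, group_list_type = \
        comm_method.permute(hidden_states, topk_ids, topk_weights,
                            expert_map, num_local_experts,
                            apply_a8_quantization)

    # dynamic_scale is only meaningful when apply_a8_quantization is True;
    # the grouped matmul kernels use it to dequantize their outputs.
    return permuted_states, expert_tokens, dynamic_scale, group_list_type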

vllm_ascend/ops/common_fused_moe.py
Lines changed: 1 addition & 3 deletions

@@ -287,12 +287,10 @@ def __init__(
             has_bias,
         )

-        with_quant = quant_config is not None
         setup_token_dispatchers(self.moe_config.ep_size,
                                 top_k=self.top_k,
                                 num_experts=self.global_num_experts,
-                                num_local_experts=self.local_num_experts,
-                                with_quant=with_quant)
+                                num_local_experts=self.local_num_experts)

         self.moe_config.tp_group = get_tp_group()
         self.moe_config.dp_group = get_dp_group()
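With `with_quant` gone, dispatcher setup depends only on parallelism topology; whether a given forward pass quantizes activations is decided later, at `token_dispatch` time (see the `with_quant=apply_a8_quantization` argument in the previous file). A sketch of the resulting call with illustrative values; in the real code they come from `self.moe_config` and the layer constructor:

# Illustrative values only; setup_token_dispatchers is imported the same way
# common_fused_moe.py already imports it.
ep_size = 16                      # expert-parallel world size
global_num_experts = 128
setup_token_dispatchers(ep_size,
                        top_k=8,
                        num_experts=global_num_experts,
                        num_local_experts=global_num_experts // ep_size)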

vllm_ascend/ops/fused_moe.py
Lines changed: 0 additions & 2 deletions

@@ -230,7 +230,6 @@ def fused_experts_moge(
         0, sorted_topk_ids).unsqueeze(-1)
     group_list = num_tokens_per_expert.cumsum(dim=0).to(torch.int64)

-    w1 = w1.transpose(1, 2)
     gate_up_out = torch_npu.npu_grouped_matmul(
         x=[sorted_hidden_states],
         weight=[w1],
@@ -247,7 +246,6 @@ def fused_experts_moge(
     gate_up_out = torch_npu.npu_swiglu(gate_up_out)
     gate_up_out *= topk_scales

-    w2 = w2.transpose(1, 2)
     down_out_list = torch_npu.npu_grouped_matmul(
         x=[gate_up_out],
         weight=[w2],
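Dropping the transposes is only valid if `w1` and `w2` are already stored in the (num_experts, in_features, out_features) layout that the grouped matmul multiplies on the right; previously they were re-transposed on every forward pass, which is wasted work when the stored layout already matches. A CPU-only sketch of that shape contract, with `torch.bmm` standing in for `torch_npu.npu_grouped_matmul` and all dimensions invented for illustration:

import torch
import torch.nn.functional as F

num_experts, hidden, intermediate, tokens_per_expert = 4, 64, 128, 16

# Assumed stored layouts: (E, in_features, out_features), i.e. already the
# right-hand operand of a per-expert matmul, so no transpose is needed.
w1 = torch.randn(num_experts, hidden, 2 * intermediate)   # fused gate+up proj
w2 = torch.randn(num_experts, intermediate, hidden)       # down proj

x = torch.randn(num_experts, tokens_per_expert, hidden)   # tokens per expert

gate_up = torch.bmm(x, w1)                 # (E, T, 2*intermediate)
gate, up = gate_up.chunk(2, dim=-1)
act = F.silu(gate) * up                    # SwiGLU, mirroring npu_swiglu
out = torch.bmm(act, w2)                   # (E, T, hidden)

print(out.shape)                           # torch.Size([4, 16, 64])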
