
Commit aff5189

[main] Fuse GroupedMatmul, Swiglu and DynamicQuant in W8A8_DYNAMIC quantized MoE layers (#2275)
### What this PR does / why we need it?
Fuse `GroupedMatmul`, `Swiglu` and `DynamicQuant` into a single fused operation, `GroupedMatmulSwigluQuant`:
1. Extract the functions shared by `w4a8_dynamic.py` and `w8a8_dynamic.py` into common helpers.
2. Where supported, call the fused operator `npu_grouped_matmul_swiglu_quant` instead of the separate kernels.

### Does this PR introduce _any_ user-facing change?

### How was this patch tested?
Tested on a W8A8-quantized Qwen3-235B-A22B model with `bs=16`:
1. `tp=8`, `dp=1`, `moe_tp=8`, `moe_ep=1`: TPOP increased 21.54%, Output Token Throughput increased 27.35%
   <img width="3443" height="211" alt="image" src="https://github.com/user-attachments/assets/a1a9c14d-2310-41be-9a03-36125dabae6e" />
2. `tp=8`, `dp=1`, `moe_tp=1`, `moe_ep=8`: TPOP increased 17.38%, Output Token Throughput increased 6.86%
   <img width="3443" height="211" alt="image" src="https://github.com/user-attachments/assets/1ce92e92-720d-40c0-8b4d-c493e5cb10a6" />

- vLLM version: v0.10.1.1
- vLLM main: vllm-project/vllm@6997a25

---------

Signed-off-by: Ruri <33858552+zhoux77899@users.noreply.github.com>
Signed-off-by: zhoux77899 <zhouxiang100@huawei.com>
1 parent 37f5a29 commit aff5189

5 files changed: +257 −220 lines changed
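As context for the diffs below, here is a condensed sketch of what the fusion changes on the quantized MoE MLP path: the gate_up grouped matmul, the SwiGLU activation and the per-token re-quantization collapse into a single `torch_npu.npu_grouped_matmul_swiglu_quant` call. The kernel names and keyword arguments are taken from the `moe_mlp.py` diff in this commit; the wrapper `gate_up_swiglu_quant` and its tensor arguments are placeholders for illustration only, and the code runs only on an Ascend NPU with `torch_npu` and the patched `vllm_ascend` available.

```python
import torch
import torch_npu

from vllm_ascend.ops.layers.moe_mlp import cumsum_group_list


def gate_up_swiglu_quant(hidden_states, w1, w1_scale, pertoken_scale,
                         group_list, group_list_type, fusion):
    """Illustrative wrapper: gate_up projection + SwiGLU + dynamic re-quant."""
    if fusion:
        # New path: one fused kernel (GroupedMatmul + Swiglu + DynamicQuant).
        out, out_scale, _ = torch_npu.npu_grouped_matmul_swiglu_quant(
            x=hidden_states,
            weight=w1,
            group_list=cumsum_group_list(group_list, group_list_type),
            weight_scale=w1_scale,
            x_scale=pertoken_scale)
    else:
        # Previous path: grouped matmul, then a separate dequant+SwiGLU+quant kernel.
        out = torch_npu.npu_grouped_matmul(
            x=[hidden_states],
            weight=[w1],
            split_item=3,
            group_list_type=group_list_type,
            group_type=0,
            group_list=group_list,
            output_dtype=torch.int32)[0]
        out, out_scale = torch_npu.npu_dequant_swiglu_quant(
            x=out,
            weight_scale=w1_scale,
            activation_scale=pertoken_scale,
            bias=None,
            quant_scale=None,
            quant_offset=None,
            group_index=group_list,
            activate_left=True,
            quant_mode=1)
    # The down_proj grouped matmul that follows is unchanged in both cases.
    return out, out_scale
```

In the commit this logic lives in `quant_apply_mlp()` and is selected by the new `fusion` flag, which `unified_fused_experts_eager()` exposes as `fusion_mlp`.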

tests/ut/ops/test_fused_ops.py

Lines changed: 103 additions & 1 deletion
@@ -29,7 +29,7 @@
 from vllm_ascend.ops.fused_moe import (AscendFusedMoE,
                                        AscendUnquantizedFusedMoEMethod)
 from vllm_ascend.ops.layers.experts_selector import select_experts
-from vllm_ascend.ops.layers.moe_mlp import unified_apply_mlp
+from vllm_ascend.ops.layers.moe_mlp import cumsum_group_list, unified_apply_mlp
 from vllm_ascend.utils import AscendSocVersion, adapt_patch

 adapt_patch(True)
@@ -524,6 +524,43 @@ def test_select_experts(self, mock_dist_env, mock_moe_env,
         assert topk_ids.shape == (8, 2)


+class TestCumsumGroupList(TestBase):
+
+    def setUp(self):
+        self.active_num = 8
+        self.expert_num = 128
+        self.experts = torch.zeros((self.expert_num, ), dtype=torch.int64)
+        self.experts[:self.active_num] = 1
+        self.experts = self.experts[torch.randperm(self.expert_num)]
+        self.group_list = self.experts.cumsum(dim=0)
+
+    def test_cumsum_group_list_with_type_0(self):
+        group_list = self.experts.cumsum(dim=0)
+        group_list_type = 0
+        result = cumsum_group_list(group_list, group_list_type)
+        self.assertTrue(torch.equal(result, self.group_list))
+
+    def test_cumsum_group_list_with_type_1(self):
+        group_list = self.experts
+        group_list_type = 1
+        result = cumsum_group_list(group_list, group_list_type)
+        self.assertTrue(torch.equal(result, self.group_list))
+
+    def test_cumsum_group_list_with_type_2(self):
+        tokens = torch.arange(self.expert_num, dtype=torch.int64)
+        group_list = torch.cat([
+            tokens.reshape(self.expert_num, 1),
+            self.experts.reshape(self.expert_num, 1)
+        ],
+                               dim=1)
+        group_list_type = 2
+        result = cumsum_group_list(group_list,
+                                   group_list_type,
+                                   active_num=self.active_num,
+                                   expert_num=self.expert_num)
+        self.assertTrue(torch.equal(result, self.group_list))
+
+
 class TestUnifiedApplyMLP(TestBase):

     @patch('vllm_ascend.ops.layers.moe_mlp.get_forward_context')
@@ -739,3 +776,68 @@ def test_unified_apply_mlp_without_quantization_310p(

         self.assertEqual(result.shape, hidden_states.shape)
         self.assertEqual(result.dtype, torch.float16)
+
+    @patch("vllm_ascend.ops.layers.moe_mlp.get_forward_context")
+    @patch("torch_npu.npu_grouped_matmul")
+    @patch("torch_npu.npu_swiglu")
+    @patch("torch_npu.npu_grouped_matmul_swiglu_quant")
+    @patch("torch_npu.npu_dynamic_quant")
+    def test_unified_apply_mlp_with_quantization_and_fusion_mlp(
+            self, mock_npu_dynamic_quant, mock_npu_grouped_matmul_swiglu_quant,
+            mock_npu_swiglu, mock_npu_grouped_matmul,
+            mock_get_forward_context):
+
+        mock_forward_context = MagicMock()
+        mock_forward_context.with_quant = True
+        mock_forward_context.fused_moe_state = "NOT_MC2"
+        mock_get_forward_context.return_value = mock_forward_context
+
+        mock_npu_grouped_matmul_swiglu_quant.return_value = (
+            torch.randint(-128, 127, (10, 40), dtype=torch.int8),
+            torch.rand(10, 1, dtype=torch.float32),
+            torch.rand(10, 1, dtype=torch.float32))
+        mock_npu_grouped_matmul.side_effect = [[
+            torch.randn(10, 20, dtype=torch.bfloat16)
+        ]]
+        mock_npu_swiglu.return_value = torch.randn(10, 40, dtype=torch.bfloat16)
+        mock_npu_dynamic_quant.return_value = (
+            torch.randint(-128, 127, (10, 40), dtype=torch.int8),
+            torch.rand(10, 1, dtype=torch.float32))
+
+        hidden_states = torch.randn(10, 20, dtype=torch.bfloat16)
+        w1 = torch.randn(5, 20, 40, dtype=torch.bfloat16)
+        w1_scale = torch.randn(5, 40, dtype=torch.bfloat16)
+        w2 = torch.randn(5, 40, 20, dtype=torch.bfloat16)
+        w2_scale = torch.randn(5, 20, dtype=torch.bfloat16)
+        w1_scale_bias = torch.randn(5, 40, dtype=torch.bfloat16)
+        w2_scale_bias = torch.randn(5, 20, dtype=torch.bfloat16)
+        group_list = torch.tensor([2, 4, 6, 8, 10], dtype=torch.int64)
+        provided_dynamic_scale = torch.rand(10, 1, dtype=torch.float32)
+
+        result = unified_apply_mlp(hidden_states=hidden_states,
+                                   w1=w1,
+                                   w1_scale=w1_scale,
+                                   w2=w2,
+                                   w2_scale=w2_scale,
+                                   group_list=group_list,
+                                   dynamic_scale=provided_dynamic_scale,
+                                   group_list_type=1,
+                                   w1_scale_bias=w1_scale_bias,
+                                   w2_scale_bias=w2_scale_bias,
+                                   topk_scales=None,
+                                   with_quant=True,
+                                   fusion=True)
+
+        mock_get_forward_context.assert_called()
+        mock_npu_grouped_matmul.assert_called_once()
+        mock_npu_grouped_matmul_swiglu_quant.assert_called_once()
+
+        self.assertTrue(mock_forward_context.with_quant)
+        self.assertEqual(result.shape, hidden_states.shape)
+        self.assertEqual(result.dtype, torch.bfloat16)
Lines changed: 49 additions & 0 deletions
@@ -0,0 +1,49 @@
+from unittest.mock import Mock, patch
+
+import torch
+
+from tests.ut.base import TestBase
+from vllm_ascend.quantization.w8a8_dynamic import \
+    AscendW8A8DynamicFusedMoEMethod
+
+
+class TestAscendW8A8FusedMoEMethod(TestBase):
+    num_experts = 8
+    hidden_size = 128
+    intermediate_size = 128
+
+    @patch("torch.distributed.get_rank")
+    @patch("vllm_ascend.quantization.w8a8_dynamic.get_mc2_group")
+    @patch("vllm_ascend.quantization.w8a8_dynamic.get_ascend_config")
+    @patch("vllm_ascend.quantization.w8a8_dynamic.get_ep_group")
+    def setUp(self, mock_get_ep_group, mock_get_ascend_config,
+              mock_get_mc2_group, mock_get_rank):
+        mock_ep_group = Mock()
+        mock_get_ep_group.return_value = mock_ep_group
+        mock_ascend_config = Mock()
+        mock_ascend_config.torchair_graph_config = Mock(enabled=False)
+        mock_get_ascend_config.return_value = mock_ascend_config
+        mock_mc2_group = Mock(device_group=0)
+        mock_get_mc2_group.return_value = mock_mc2_group
+        mock_rank = Mock()
+        mock_get_rank.return_value = mock_rank
+
+        self.quant_method = AscendW8A8DynamicFusedMoEMethod()
+
+    def test_get_weight(self):
+        param_dict = self.quant_method.get_weight(self.num_experts,
+                                                  self.intermediate_size,
+                                                  self.hidden_size,
+                                                  torch.bfloat16)
+        self.assertEqual(param_dict["w13_weight"].dtype, torch.int8)
+        self.assertEqual(
+            param_dict["w13_weight"].shape,
+            (self.num_experts, 2 * self.intermediate_size, self.hidden_size))
+
+    def test_get_dynamic_quant_param(self):
+        param_dict = self.quant_method.get_dynamic_quant_param(
+            self.num_experts, self.intermediate_size, self.hidden_size,
+            torch.bfloat16)
+        self.assertEqual(param_dict["w13_weight_scale"].dtype, torch.bfloat16)
+        self.assertEqual(param_dict["w13_weight_scale"].shape,
+                         (self.num_experts, 2 * self.intermediate_size, 1))

vllm_ascend/ops/fused_moe.py

Lines changed: 4 additions & 2 deletions
@@ -70,7 +70,8 @@ def unified_fused_experts_eager(hidden_states: torch.Tensor,
                                 shared_dequant_scale: Optional[Any] = None,
                                 mc2_mask: Optional[torch.Tensor] = None,
                                 apply_router_weight_on_input: bool = False,
-                                with_quant: bool = False):
+                                with_quant: bool = False,
+                                fusion_mlp: bool = False):
     token_dispatcher = get_forward_context().token_dispatcher

     results = token_dispatcher.token_dispatch(
@@ -100,7 +101,8 @@ def unified_fused_experts_eager(hidden_states: torch.Tensor,
        w1_scale_bias=w1_scale_bias,
        w2_scale_bias=w2_scale_bias,
        topk_scales=results.get("topk_scales"),
-       with_quant=with_quant)
+       with_quant=with_quant,
+       fusion=fusion_mlp)
     final_hidden_states = token_dispatcher.token_combine(expert_output)
     return final_hidden_states

vllm_ascend/ops/layers/moe_mlp.py

Lines changed: 96 additions & 48 deletions
@@ -18,22 +18,52 @@

 import torch
 import torch_npu
+from torch.nn.functional import pad
 from vllm.forward_context import get_forward_context

 from vllm_ascend.ascend_forward_context import FusedMoEState
 from vllm_ascend.utils import dispose_tensor, is_310p


+def cumsum_group_list(group_list: torch.Tensor,
+                      group_list_type: int,
+                      active_num: int = 0,
+                      expert_num: int = 0) -> torch.Tensor:
+    if group_list_type not in [0, 1, 2]:
+        raise ValueError(
+            f"group_list_type should be in [0, 1, 2], but received {group_list_type}"
+        )
+
+    if group_list_type == 0:
+        return group_list
+    if group_list_type == 1:
+        return group_list.cumsum(dim=0)
+
+    experts = pad(group_list[:, 0], (1, 0))
+    tokens = pad(group_list[:, 1].cumsum(dim=0), (1, 0))
+    cumsum_group_list = torch.full(size=(expert_num, ),
+                                   fill_value=active_num,
+                                   dtype=group_list.dtype,
+                                   device=group_list.device)
+
+    for i, (start, end) in enumerate(zip(experts[:-1], experts[1:])):
+        if end > start:
+            cumsum_group_list[start:end] = tokens[i]
+
+    return cumsum_group_list
+
+
 def quant_apply_mlp(hidden_states: torch.Tensor,
                     w1: torch.Tensor,
                     w1_scale: torch.Tensor,
                     w2: torch.Tensor,
                     w2_scale: torch.Tensor,
                     group_list: torch.Tensor,
-                    dynamic_scale: torch.Tensor = None,
                     group_list_type: int = 1,
+                    dynamic_scale: torch.Tensor = None,
                     w1_scale_bias: torch.Tensor = None,
-                    w2_scale_bias: torch.Tensor = None) -> torch.Tensor:
+                    w2_scale_bias: torch.Tensor = None,
+                    fusion: bool = False) -> torch.Tensor:
     if dynamic_scale is None:
         unquantized_hidden_states = hidden_states
         hidden_states, pertoken_scale = torch_npu.npu_dynamic_quant(
@@ -49,31 +79,38 @@ def quant_apply_mlp(hidden_states: torch.Tensor,

     is_mc2 = get_forward_context().fused_moe_state == FusedMoEState.MC2
     if w1_scale_bias is None and is_mc2:
-        w1_scale = w1_scale.to(torch.float32)
-
-        # gmm1: gate_up_proj
-        hidden_states = torch_npu.npu_grouped_matmul(
-            x=[hidden_states],
-            weight=[w1],
-            split_item=3,
-            group_list_type=group_list_type,
-            group_type=0,
-            group_list=group_list,
-            output_dtype=torch.int32)[0]
-
-        # act_fn: swiglu
-        hidden_states, swiglu_out_scale = torch_npu.npu_dequant_swiglu_quant(
-            x=hidden_states,
-            weight_scale=w1_scale,
-            activation_scale=pertoken_scale,
-            bias=None,
-            quant_scale=None,
-            quant_offset=None,
-            group_index=group_list,
-            activate_left=True,
-            quant_mode=1,
-        )
-
+        if w1_scale.dtype != torch.float32:
+            w1_scale = w1_scale.to(torch.float32)
+        if fusion:
+            # gmm1: gate_up_proj & act_fn: swiglu
+            hidden_states, swiglu_out_scale, _ = torch_npu.npu_grouped_matmul_swiglu_quant(
+                x=hidden_states,
+                weight=w1,
+                group_list=cumsum_group_list(group_list, group_list_type),
+                weight_scale=w1_scale,
+                x_scale=pertoken_scale)
+        else:
+            # gmm1: gate_up_proj
+            hidden_states = torch_npu.npu_grouped_matmul(
+                x=[hidden_states],
+                weight=[w1],
+                split_item=3,
+                group_list_type=group_list_type,
+                group_type=0,
+                group_list=group_list,
+                output_dtype=torch.int32)[0]
+            # act_fn: swiglu
+            hidden_states, swiglu_out_scale = torch_npu.npu_dequant_swiglu_quant(
+                x=hidden_states,
+                weight_scale=w1_scale,
+                activation_scale=pertoken_scale,
+                bias=None,
+                quant_scale=None,
+                quant_offset=None,
+                group_index=group_list,
+                activate_left=True,
+                quant_mode=1,
+            )
     # gmm2: down_proj
     hidden_states = torch_npu.npu_grouped_matmul(
         x=[hidden_states],
@@ -92,29 +129,37 @@ def quant_apply_mlp(hidden_states: torch.Tensor,
                [group_list[:1],
                 torch.diff(group_list, dim=0)])
            group_list_type = 1
-        bias1 = [w1_scale_bias]
+        bias1 = [w1_scale_bias] if not fusion else w1_scale_bias
         bias2 = [w2_scale_bias]
        # TODO w4a8 scene: dynamic acquisition of dtype in the future
        _output_dtype = torch.bfloat16

-        # gmm1: gate_up_proj
-        hidden_states = torch_npu.npu_grouped_matmul(
-            x=[hidden_states],
-            weight=[w1],
-            scale=[w1_scale],
-            bias=bias1,
-            per_token_scale=[pertoken_scale],
-            split_item=2,
-            group_list_type=group_list_type,
-            group_type=0,
-            group_list=group_list,
-            output_dtype=_output_dtype)[0]
-
-        # act_fn: swiglu
-        hidden_states = torch_npu.npu_swiglu(hidden_states)
-        hidden_states, swiglu_out_scale = torch_npu.npu_dynamic_quant(
-            hidden_states)
-
+        if fusion:
+            # gmm1: gate_up_proj & act_fn: swiglu
+            hidden_states, swiglu_out_scale, _ = torch_npu.npu_grouped_matmul_swiglu_quant(
+                x=hidden_states,
+                weight=w1,
+                bias=bias1,
+                group_list=cumsum_group_list(group_list, group_list_type),
+                weight_scale=w1_scale,
+                x_scale=pertoken_scale)
+        else:
+            # gmm1: gate_up_proj
+            hidden_states = torch_npu.npu_grouped_matmul(
+                x=[hidden_states],
+                weight=[w1],
+                scale=[w1_scale.to(w2_scale.dtype)],
+                bias=bias1,
+                per_token_scale=[pertoken_scale],
+                split_item=2,
+                group_list_type=group_list_type,
+                group_type=0,
+                group_list=group_list,
+                output_dtype=_output_dtype)[0]
+            # act_fn: swiglu
+            hidden_states = torch_npu.npu_swiglu(hidden_states)
+            hidden_states, swiglu_out_scale = torch_npu.npu_dynamic_quant(
+                hidden_states)
     # gmm2: down_proj
     hidden_states = torch_npu.npu_grouped_matmul(
         x=[hidden_states],
@@ -127,6 +172,7 @@ def quant_apply_mlp(hidden_states: torch.Tensor,
            group_type=0,
            group_list=group_list,
            output_dtype=_output_dtype)[0]
+
     return hidden_states


@@ -178,7 +224,8 @@ def unified_apply_mlp(hidden_states: torch.Tensor,
                      w1_scale_bias: torch.Tensor = None,
                      w2_scale_bias: torch.Tensor = None,
                      topk_scales: Optional[torch.Tensor] = None,
-                      with_quant: bool = False) -> torch.Tensor:
+                      with_quant: bool = False,
+                      fusion: bool = False) -> torch.Tensor:
     if with_quant:
         return quant_apply_mlp(hidden_states=hidden_states,
                                w1=w1,
@@ -189,7 +236,8 @@ def unified_apply_mlp(hidden_states: torch.Tensor,
                                dynamic_scale=dynamic_scale,
                                group_list_type=group_list_type,
                                w1_scale_bias=w1_scale_bias,
-                               w2_scale_bias=w2_scale_bias)
+                               w2_scale_bias=w2_scale_bias,
+                               fusion=fusion)
     else:
         return unquant_apply_mlp(hidden_states=hidden_states,
                                  w1=w1,

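For readers unfamiliar with the three `group_list` encodings that the new `cumsum_group_list()` helper normalizes, here is a minimal CPU-only sketch; the helper is plain PyTorch, so no NPU is needed to try it. It assumes this patched `vllm_ascend` is importable; the expert and token counts are made-up illustration values, and `active_num` is the total number of routed tokens.

```python
import torch

from vllm_ascend.ops.layers.moe_mlp import cumsum_group_list

# 4 experts, 6 routed tokens (illustrative values only).
counts = torch.tensor([2, 0, 3, 1], dtype=torch.int64)      # type 1: tokens per expert
cumulative = torch.tensor([2, 2, 5, 6], dtype=torch.int64)  # type 0: running totals

assert torch.equal(cumsum_group_list(cumulative, 0), cumulative)  # already cumulative, passthrough
assert torch.equal(cumsum_group_list(counts, 1), cumulative)      # per-expert counts -> cumulative

# type 2: (expert index, token count) pairs, mirroring the new unit test above.
pairs = torch.stack([torch.arange(4, dtype=torch.int64), counts], dim=1)
assert torch.equal(
    cumsum_group_list(pairs, 2, active_num=6, expert_num=4), cumulative)
```

The fused kernel expects the cumulative form, which is why `quant_apply_mlp()` wraps `group_list` with `cumsum_group_list()` before calling `npu_grouped_matmul_swiglu_quant`.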