 import torch
 import torch_npu
+from torch.nn.functional import pad

 from vllm.forward_context import get_forward_context
 from vllm_ascend.ascend_forward_context import FusedMoEState
 from vllm_ascend.utils import dispose_tensor, is_310p

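+# cumsum_group_list converts a grouped-matmul group_list between the layouts
+# the torch_npu kernels accept: 0 = cumulative token offsets, 1 = per-expert
+# token counts, 2 = (expert, token-count) pairs. The fused kernel added below
+# consumes the cumulative (type 0) layout.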
+def cumsum_group_list(group_list: torch.Tensor,
+                      group_list_type: int,
+                      active_num: int = 0,
+                      expert_num: int = 0) -> torch.Tensor:
+    if group_list_type not in [0, 1, 2]:
+        raise ValueError(
+            f"group_list_type should be in [0, 1, 2], but received {group_list_type}"
+        )
+
+    if group_list_type == 0:
+        return group_list
+    if group_list_type == 1:
+        return group_list.cumsum(dim=0)
+
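+    # group_list_type == 2: rows are (expert, token-count) pairs. Decode them
+    # into dense cumulative offsets, one entry per expert; entries not covered
+    # by any pair keep the fill value active_num.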
+    experts = pad(group_list[:, 0], (1, 0))
+    tokens = pad(group_list[:, 1].cumsum(dim=0), (1, 0))
+    cumsum_group_list = torch.full(size=(expert_num, ),
+                                   fill_value=active_num,
+                                   dtype=group_list.dtype,
+                                   device=group_list.device)
+
+    for i, (start, end) in enumerate(zip(experts[:-1], experts[1:])):
+        if end > start:
+            cumsum_group_list[start:end] = tokens[i]
+
+    return cumsum_group_list
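+
+# Illustrative example (hypothetical values): per-expert counts [2, 0, 3]
+# with group_list_type=1 become cumulative offsets [2, 2, 5], i.e. the
+# group_list_type=0 layout.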
+
+
 def quant_apply_mlp(hidden_states: torch.Tensor,
                     w1: torch.Tensor,
                     w1_scale: torch.Tensor,
                     w2: torch.Tensor,
                     w2_scale: torch.Tensor,
                     group_list: torch.Tensor,
-                    dynamic_scale: torch.Tensor = None,
                     group_list_type: int = 1,
+                    dynamic_scale: torch.Tensor = None,
                     w1_scale_bias: torch.Tensor = None,
-                    w2_scale_bias: torch.Tensor = None) -> torch.Tensor:
+                    w2_scale_bias: torch.Tensor = None,
+                    fusion: bool = False) -> torch.Tensor:
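+    # fusion=True routes gmm1 and the swiglu activation through the fused
+    # npu_grouped_matmul_swiglu_quant kernel instead of separate kernels.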
     if dynamic_scale is None:
         unquantized_hidden_states = hidden_states
         hidden_states, pertoken_scale = torch_npu.npu_dynamic_quant(
@@ -49,31 +79,38 @@ def quant_apply_mlp(hidden_states: torch.Tensor,

     is_mc2 = get_forward_context().fused_moe_state == FusedMoEState.MC2
     if w1_scale_bias is None and is_mc2:
-        w1_scale = w1_scale.to(torch.float32)
-
-        # gmm1: gate_up_proj
-        hidden_states = torch_npu.npu_grouped_matmul(
-            x=[hidden_states],
-            weight=[w1],
-            split_item=3,
-            group_list_type=group_list_type,
-            group_type=0,
-            group_list=group_list,
-            output_dtype=torch.int32)[0]
-
-        # act_fn: swiglu
-        hidden_states, swiglu_out_scale = torch_npu.npu_dequant_swiglu_quant(
-            x=hidden_states,
-            weight_scale=w1_scale,
-            activation_scale=pertoken_scale,
-            bias=None,
-            quant_scale=None,
-            quant_offset=None,
-            group_index=group_list,
-            activate_left=True,
-            quant_mode=1,
-        )
-
+        if w1_scale.dtype != torch.float32:
+            w1_scale = w1_scale.to(torch.float32)
+        if fusion:
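+            # npu_grouped_matmul_swiglu_quant fuses gmm1, the swiglu
+            # activation, and output quantization into a single kernel; it
+            # expects a cumulative group_list, hence cumsum_group_list.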
+            # gmm1: gate_up_proj & act_fn: swiglu
+            hidden_states, swiglu_out_scale, _ = torch_npu.npu_grouped_matmul_swiglu_quant(
+                x=hidden_states,
+                weight=w1,
+                group_list=cumsum_group_list(group_list, group_list_type),
+                weight_scale=w1_scale,
+                x_scale=pertoken_scale)
+        else:
+            # gmm1: gate_up_proj
+            hidden_states = torch_npu.npu_grouped_matmul(
+                x=[hidden_states],
+                weight=[w1],
+                split_item=3,
+                group_list_type=group_list_type,
+                group_type=0,
+                group_list=group_list,
+                output_dtype=torch.int32)[0]
+            # act_fn: swiglu
+            hidden_states, swiglu_out_scale = torch_npu.npu_dequant_swiglu_quant(
+                x=hidden_states,
+                weight_scale=w1_scale,
+                activation_scale=pertoken_scale,
+                bias=None,
+                quant_scale=None,
+                quant_offset=None,
+                group_index=group_list,
+                activate_left=True,
+                quant_mode=1,
+            )
         # gmm2: down_proj
         hidden_states = torch_npu.npu_grouped_matmul(
             x=[hidden_states],
@@ -92,29 +129,37 @@ def quant_apply_mlp(hidden_states: torch.Tensor,
                 [group_list[:1],
                  torch.diff(group_list, dim=0)])
             group_list_type = 1
-        bias1 = [w1_scale_bias]
+        bias1 = [w1_scale_bias] if not fusion else w1_scale_bias
         bias2 = [w2_scale_bias]
         # TODO(w4a8): determine the output dtype dynamically in the future
         _output_dtype = torch.bfloat16

-        # gmm1: gate_up_proj
-        hidden_states = torch_npu.npu_grouped_matmul(
-            x=[hidden_states],
-            weight=[w1],
-            scale=[w1_scale],
-            bias=bias1,
-            per_token_scale=[pertoken_scale],
-            split_item=2,
-            group_list_type=group_list_type,
-            group_type=0,
-            group_list=group_list,
-            output_dtype=_output_dtype)[0]
-
-        # act_fn: swiglu
-        hidden_states = torch_npu.npu_swiglu(hidden_states)
-        hidden_states, swiglu_out_scale = torch_npu.npu_dynamic_quant(
-            hidden_states)
-
+        if fusion:
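+            # Here bias1 is a plain tensor rather than a list (see the bias1
+            # assignment above), matching the fused kernel's bias argument.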
+            # gmm1: gate_up_proj & act_fn: swiglu
+            hidden_states, swiglu_out_scale, _ = torch_npu.npu_grouped_matmul_swiglu_quant(
+                x=hidden_states,
+                weight=w1,
+                bias=bias1,
+                group_list=cumsum_group_list(group_list, group_list_type),
+                weight_scale=w1_scale,
+                x_scale=pertoken_scale)
+        else:
+            # gmm1: gate_up_proj
+            hidden_states = torch_npu.npu_grouped_matmul(
+                x=[hidden_states],
+                weight=[w1],
+                scale=[w1_scale.to(w2_scale.dtype)],
+                bias=bias1,
+                per_token_scale=[pertoken_scale],
+                split_item=2,
+                group_list_type=group_list_type,
+                group_type=0,
+                group_list=group_list,
+                output_dtype=_output_dtype)[0]
+            # act_fn: swiglu
+            hidden_states = torch_npu.npu_swiglu(hidden_states)
+            hidden_states, swiglu_out_scale = torch_npu.npu_dynamic_quant(
+                hidden_states)
         # gmm2: down_proj
         hidden_states = torch_npu.npu_grouped_matmul(
             x=[hidden_states],
@@ -127,6 +172,7 @@ def quant_apply_mlp(hidden_states: torch.Tensor,
             group_type=0,
             group_list=group_list,
             output_dtype=_output_dtype)[0]
+
     return hidden_states

@@ -178,7 +224,8 @@ def unified_apply_mlp(hidden_states: torch.Tensor,
                       w1_scale_bias: torch.Tensor = None,
                       w2_scale_bias: torch.Tensor = None,
                       topk_scales: Optional[torch.Tensor] = None,
-                      with_quant: bool = False) -> torch.Tensor:
+                      with_quant: bool = False,
+                      fusion: bool = False) -> torch.Tensor:
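+    # fusion is only forwarded to the quantized path below; the unquantized
+    # path is unchanged.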
     if with_quant:
         return quant_apply_mlp(hidden_states=hidden_states,
                                w1=w1,
@@ -189,7 +236,8 @@ def unified_apply_mlp(hidden_states: torch.Tensor,
                                dynamic_scale=dynamic_scale,
                                group_list_type=group_list_type,
                                w1_scale_bias=w1_scale_bias,
-                               w2_scale_bias=w2_scale_bias)
+                               w2_scale_bias=w2_scale_bias,
+                               fusion=fusion)
     else:
         return unquant_apply_mlp(hidden_states=hidden_states,
                                  w1=w1,