20
20
#ifndef MAX_NUM_EXPERTS
21
21
#define MAX_NUM_EXPERTS 32
22
22
#endif
23
+
24
+ typedef struct __align__ (16 ){
25
+ int data[MAX_NUM_EXPERTS];
26
+ }expert_base_offset;
27
+
23
28
// 多阶段算法,控制每block处理的行数来权衡额外开销
24
29
// 首先解析routemap来更新专家当前所收到的token数,然后check前一个block给的前缀和并更新给下一个block
25
30
// 随后,目的行号的信息已获取,立即开始搬运工作,直至任务完全完成
@@ -29,28 +34,28 @@ __global__ void tokens_unzip_stable_kernel(
29
34
const routemap_T *__restrict__ routemap_topk,
30
35
const probs_T *__restrict__ probs_topk,
31
36
const float *__restrict__ XScale,
37
+ const expert_base_offset expert_base_offset,
32
38
X_T *__restrict__ X_unzipped,
33
39
int *__restrict__ zipped_expertwise_rowmap,
34
40
probs_T *__restrict__ probs_unzipped,
35
41
float *__restrict__ XScale_unzipped,
36
42
int *global_expertwise_block_cumsum,
37
43
const int total_zipped_tokens_num,
38
- const int max_tokens_per_expert,
39
44
const int token_length,
40
45
const int scale_length,
41
46
const int num_experts,
42
47
const int topk) {
43
48
const int block_row_base = blockIdx .x * CUMSUM_BLOCK_SIZE;
44
49
int cumsum_offset[MAX_NUM_EXPERTS];
45
- int expert_offset [MAX_NUM_EXPERTS];
50
+ int local_expert_offsets [MAX_NUM_EXPERTS];
46
51
int local_cumsum[MAX_NUM_EXPERTS];
47
52
#pragma unroll
48
53
for (int i = 0 ; i < num_experts; i++) {
49
54
cumsum_offset[i] =
50
55
(blockIdx .x == 0 )
51
56
? 0
52
57
: CUMSUM_INVALID_TAG; // 除了第0个block,其他的都以非法值初始化,因为atomic忙等要用
53
- expert_offset [i] = i * max_tokens_per_expert ;
58
+ local_expert_offsets [i] = expert_base_offset. data [i] ;
54
59
local_cumsum[i] = 0 ;
55
60
}
56
61
const int base_row_idx = blockIdx .x * CUMSUM_BLOCK_SIZE;
@@ -80,7 +85,7 @@ __global__ void tokens_unzip_stable_kernel(
80
85
const int expert = routemap_topk[row * topk + k];
81
86
if (expert == -1 ) continue ;
82
87
local_expert_rowmap[internal_row][expert] =
83
- local_cumsum[expert] + expert_offset [expert];
88
+ local_cumsum[expert] + local_expert_offsets [expert];
84
89
local_expert_probs[internal_row][expert] = probs_topk[row * topk + k];
85
90
local_cumsum[expert] += 1 ;
86
91
}
@@ -149,6 +154,7 @@ void dispatch_tokens_unzip_stable(
149
154
const paddle::Tensor &expert_routemap_topk,
150
155
const paddle::Tensor &expert_prob_topk,
151
156
const paddle::optional<paddle::Tensor> &XScale,
157
+ const expert_base_offset &expert_offsets,
152
158
paddle::Tensor &X_unzipped,
153
159
paddle::Tensor &zipped_expertwise_rowmap,
154
160
paddle::Tensor &token_prob_unzipped,
@@ -158,7 +164,6 @@ void dispatch_tokens_unzip_stable(
158
164
const int token_length,
159
165
const int topk,
160
166
const int num_experts,
161
- const int max_tokens_per_expert,
162
167
const int scale_length) {
163
168
dim3 grid, block;
164
169
grid.x =
@@ -177,13 +182,13 @@ void dispatch_tokens_unzip_stable(
177
182
GET_DATA (expert_routemap_topk, INT_T), \
178
183
GET_DATA (expert_prob_topk, PROB_T), \
179
184
XScale ? XScale->data <float >() : nullptr , \
185
+ expert_offsets, \
180
186
GET_DATA (X_unzipped, TOKEN_T), \
181
187
GET_DATA (zipped_expertwise_rowmap, INT_T), \
182
188
GET_DATA (token_prob_unzipped, PROB_T), \
183
189
XScale_unzipped.data <float >(), \
184
190
global_expertwise_block_cumsum.data <int >(), \
185
191
total_zipped_tokens_num, \
186
- max_tokens_per_expert, \
187
192
token_length, \
188
193
scale_length, \
189
194
num_experts, \
@@ -228,7 +233,8 @@ std::vector<paddle::Tensor> tokens_unzip_stable(
228
233
const paddle::Tensor &expert_prob_topk,
229
234
const int &topk,
230
235
const int &num_experts,
231
- const int &max_tokens_per_expert_in) {
236
+ const std::vector<int > &tokens_per_expert,
237
+ const int padding_multiplex) {
232
238
// --------------------- 输入检查与解析 --------------------
233
239
PD_CHECK (X.dtype () == paddle::DataType::BFLOAT16 ||
234
240
X.dtype () == paddle::DataType::FLOAT8_E4M3FN);
@@ -241,9 +247,23 @@ std::vector<paddle::Tensor> tokens_unzip_stable(
241
247
const int rows = X.shape ()[0 ]; // 一般为seqlen
242
248
const int cols = X.shape ()[1 ]; // 一般为7168
243
249
const int quanted_cols = (XScale) ? XScale->shape ()[1 ] : 0 ;
250
+ /*
244
251
const int max_tokens_per_expert =
245
252
((max_tokens_per_expert_in + 127) / 128) * 128;
246
253
const int output_rows = num_experts * max_tokens_per_expert;
254
+ */
255
+ expert_base_offset expert_offset;
256
+ int tokens_cumulated = 0 ;
257
+ for (int i = 0 ; i < MAX_NUM_EXPERTS; i++){
258
+ if (i < num_experts){
259
+ expert_offset.data [i] = tokens_cumulated;
260
+ tokens_cumulated += ((tokens_per_expert[i] + padding_multiplex - 1 ) / padding_multiplex) * padding_multiplex;
261
+ }else {
262
+ expert_offset.data [i] = 0 ;
263
+ }
264
+ }
265
+
266
+ const int output_rows = tokens_cumulated;
247
267
// ------------------------ 输出缓冲区分配 ------------------------
248
268
paddle::Tensor X_unzipped, XScale_unzipped, zipped_expertwise_rowmap,
249
269
token_prob_unzipped;
@@ -317,6 +337,7 @@ std::vector<paddle::Tensor> tokens_unzip_stable(
317
337
expert_routemap_topk,
318
338
expert_prob_topk,
319
339
XScale,
340
+ expert_offset,
320
341
X_unzipped,
321
342
zipped_expertwise_rowmap,
322
343
token_prob_unzipped,
@@ -326,7 +347,6 @@ std::vector<paddle::Tensor> tokens_unzip_stable(
326
347
cols,
327
348
topk,
328
349
num_experts,
329
- max_tokens_per_expert,
330
350
quanted_cols);
331
351
return {X_unzipped,
332
352
zipped_expertwise_rowmap,
@@ -343,7 +363,7 @@ PD_BUILD_OP(tokens_unzip_stable)
343
363
" zipped_expertwise_rowmap" ,
344
364
" token_prob_unzipped" ,
345
365
paddle::Optional (" XScale_unzipped" )})
346
- .Attrs({" topk: int" , " num_experts: int" , " max_tokens_per_expert : int" })
366
+ .Attrs({" topk: int" , " num_experts: int" ," tokens_per_expert: std::vector<int> " , " padding_multiplex : int" })
347
367
.SetKernelFn(PD_KERNEL(tokens_unzip_stable));
348
368
349
369
0 commit comments