
Commit e07bd93

Add arbitrary padding to handle extreme imbalance case.
1 parent 670cbd9 commit e07bd93

File tree

2 files changed, +36 -15 lines changed


slm/model_zoo/gpt-3/external_ops/token_dispatcher_utils/tokens_stable_unzip.cu

Lines changed: 29 additions & 9 deletions
@@ -20,6 +20,11 @@
 #ifndef MAX_NUM_EXPERTS
 #define MAX_NUM_EXPERTS 32
 #endif
+
+typedef struct __align__(16) {
+  int data[MAX_NUM_EXPERTS];
+} expert_base_offset;
+
 // Multi-stage algorithm: control how many rows each block handles to trade off the extra overhead.
 // First parse the routemap to update the number of tokens each expert has received so far, then check the prefix sum handed over by the previous block and update it for the next block.
 // At that point the destination row indices are known, so the copy work starts immediately and runs until the task is fully finished.
@@ -29,28 +34,28 @@ __global__ void tokens_unzip_stable_kernel(
     const routemap_T *__restrict__ routemap_topk,
     const probs_T *__restrict__ probs_topk,
     const float *__restrict__ XScale,
+    const expert_base_offset expert_base_offset,
     X_T *__restrict__ X_unzipped,
     int *__restrict__ zipped_expertwise_rowmap,
     probs_T *__restrict__ probs_unzipped,
     float *__restrict__ XScale_unzipped,
     int *global_expertwise_block_cumsum,
     const int total_zipped_tokens_num,
-    const int max_tokens_per_expert,
     const int token_length,
     const int scale_length,
     const int num_experts,
     const int topk) {
   const int block_row_base = blockIdx.x * CUMSUM_BLOCK_SIZE;
   int cumsum_offset[MAX_NUM_EXPERTS];
-  int expert_offset[MAX_NUM_EXPERTS];
+  int local_expert_offsets[MAX_NUM_EXPERTS];
   int local_cumsum[MAX_NUM_EXPERTS];
 #pragma unroll
   for (int i = 0; i < num_experts; i++) {
     cumsum_offset[i] =
         (blockIdx.x == 0)
             ? 0
             : CUMSUM_INVALID_TAG;  // Every block except block 0 starts with an invalid value, because the atomic busy-wait relies on it.
-    expert_offset[i] = i * max_tokens_per_expert;
+    local_expert_offsets[i] = expert_base_offset.data[i];
     local_cumsum[i] = 0;
   }
   const int base_row_idx = blockIdx.x * CUMSUM_BLOCK_SIZE;
@@ -80,7 +85,7 @@ __global__ void tokens_unzip_stable_kernel(
       const int expert = routemap_topk[row * topk + k];
       if (expert == -1) continue;
       local_expert_rowmap[internal_row][expert] =
-          local_cumsum[expert] + expert_offset[expert];
+          local_cumsum[expert] + local_expert_offsets[expert];
       local_expert_probs[internal_row][expert] = probs_topk[row * topk + k];
       local_cumsum[expert] += 1;
     }
@@ -149,6 +154,7 @@ void dispatch_tokens_unzip_stable(
     const paddle::Tensor &expert_routemap_topk,
     const paddle::Tensor &expert_prob_topk,
     const paddle::optional<paddle::Tensor> &XScale,
+    const expert_base_offset &expert_offsets,
     paddle::Tensor &X_unzipped,
     paddle::Tensor &zipped_expertwise_rowmap,
     paddle::Tensor &token_prob_unzipped,
@@ -158,7 +164,6 @@ void dispatch_tokens_unzip_stable(
     const int token_length,
     const int topk,
     const int num_experts,
-    const int max_tokens_per_expert,
     const int scale_length) {
   dim3 grid, block;
   grid.x =
@@ -177,13 +182,13 @@ void dispatch_tokens_unzip_stable(
       GET_DATA(expert_routemap_topk, INT_T), \
       GET_DATA(expert_prob_topk, PROB_T), \
       XScale ? XScale->data<float>() : nullptr, \
+      expert_offsets, \
       GET_DATA(X_unzipped, TOKEN_T), \
       GET_DATA(zipped_expertwise_rowmap, INT_T), \
       GET_DATA(token_prob_unzipped, PROB_T), \
       XScale_unzipped.data<float>(), \
       global_expertwise_block_cumsum.data<int>(), \
       total_zipped_tokens_num, \
-      max_tokens_per_expert, \
       token_length, \
       scale_length, \
       num_experts, \
@@ -228,7 +233,8 @@ std::vector<paddle::Tensor> tokens_unzip_stable(
     const paddle::Tensor &expert_prob_topk,
     const int &topk,
     const int &num_experts,
-    const int &max_tokens_per_expert_in) {
+    const std::vector<int> &tokens_per_expert,
+    const int padding_multiplex) {
   // --------------------- Input checks and parsing --------------------
   PD_CHECK(X.dtype() == paddle::DataType::BFLOAT16 ||
            X.dtype() == paddle::DataType::FLOAT8_E4M3FN);
@@ -241,9 +247,23 @@ std::vector<paddle::Tensor> tokens_unzip_stable(
   const int rows = X.shape()[0];  // usually seqlen
   const int cols = X.shape()[1];  // usually 7168
   const int quanted_cols = (XScale) ? XScale->shape()[1] : 0;
+  /*
   const int max_tokens_per_expert =
       ((max_tokens_per_expert_in + 127) / 128) * 128;
   const int output_rows = num_experts * max_tokens_per_expert;
+  */
+  expert_base_offset expert_offset;
+  int tokens_cumulated = 0;
+  for (int i = 0; i < MAX_NUM_EXPERTS; i++) {
+    if (i < num_experts) {
+      expert_offset.data[i] = tokens_cumulated;
+      tokens_cumulated += ((tokens_per_expert[i] + padding_multiplex - 1) / padding_multiplex) * padding_multiplex;
+    } else {
+      expert_offset.data[i] = 0;
+    }
+  }
+
+  const int output_rows = tokens_cumulated;
   //------------------------ Output buffer allocation ------------------------
   paddle::Tensor X_unzipped, XScale_unzipped, zipped_expertwise_rowmap,
       token_prob_unzipped;
@@ -317,6 +337,7 @@ std::vector<paddle::Tensor> tokens_unzip_stable(
       expert_routemap_topk,
       expert_prob_topk,
       XScale,
+      expert_offset,
       X_unzipped,
       zipped_expertwise_rowmap,
       token_prob_unzipped,
@@ -326,7 +347,6 @@ std::vector<paddle::Tensor> tokens_unzip_stable(
       cols,
       topk,
       num_experts,
-      max_tokens_per_expert,
       quanted_cols);
   return {X_unzipped,
           zipped_expertwise_rowmap,
@@ -343,7 +363,7 @@ PD_BUILD_OP(tokens_unzip_stable)
             "zipped_expertwise_rowmap",
             "token_prob_unzipped",
             paddle::Optional("XScale_unzipped")})
-    .Attrs({"topk: int", "num_experts: int", "max_tokens_per_expert: int"})
+    .Attrs({"topk: int", "num_experts: int", "tokens_per_expert: std::vector<int>", "padding_multiplex: int"})
     .SetKernelFn(PD_KERNEL(tokens_unzip_stable));
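The host-side loop added in tokens_unzip_stable replaces the old uniform num_experts * max_tokens_per_expert layout with per-expert base offsets: each expert's token count is rounded up to a multiple of padding_multiplex, and the running (exclusive prefix) sum of the padded counts gives that expert's starting row in the unzipped buffer, with the final sum becoming output_rows. A minimal Python sketch of the same arithmetic, for reference only (the helper name expert_base_offsets is illustrative and not part of the commit):

def expert_base_offsets(tokens_per_expert, padding_multiplex):
    # Round each expert's token count up to a multiple of padding_multiplex,
    # then take the exclusive prefix sum of the padded counts.
    offsets, cursor = [], 0
    for count in tokens_per_expert:
        offsets.append(cursor)
        cursor += ((count + padding_multiplex - 1) // padding_multiplex) * padding_multiplex
    return offsets, cursor  # cursor == total number of output rows

With a hypothetical, heavily imbalanced routing of tokens_per_expert = [1200, 5, 3, 40] and padding_multiplex = 128, the padded counts are [1280, 128, 128, 128], the base offsets are [0, 1280, 1408, 1536], and the buffer needs 1664 rows, whereas the old layout would pad every expert to 1280 rows and allocate 4 * 1280 = 5120.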

tests/ops/test_unzip_zip.py

Lines changed: 7 additions & 6 deletions
@@ -51,15 +51,15 @@ def fabricate_dispatch_result(
         valid_experts, bins=num_experts, min=0, max=num_experts - 1
     )
     expert_counts = paddle.cast(expert_counts, "int32")
-    print("expert counts: ", expert_counts.numpy())
-    max_tokens_per_expert = expert_counts.max().item()
+    expert_counts = list(expert_counts)
+    print("expert counts: ", expert_counts)
 
     return (
         tokens,
         tokens_scale,
         dispatched_indices,
         dispatched_probs,
-        max_tokens_per_expert,
+        expert_counts,
     )
 
 
@@ -75,7 +75,7 @@ def test_unzip_zip():
     SEQLEN = 16384
     TOKEN_LEN = 7168
     for dt in ["bfloat16"]:
-        for expert_num in [2, 4, 8, 16, 32]:
+        for expert_num in [4, 8, 16, 32]:
             for topk in [4, 8, 12]:
                 print("###################################")
                 print(
@@ -88,7 +88,7 @@ def test_unzip_zip():
                     tokens_scale,
                     dispatched_indices,
                     dispatched_probs,
-                    max_tokens_per_expert,
+                    expert_tokens_count,
                 ) = fabricate_dispatch_result(
                     SEQLEN,
                     TOKEN_LEN,
@@ -111,7 +111,8 @@ def test_unzip_zip():
                     dispatched_probs,
                     topk=topk,
                     num_experts=expert_num,
-                    max_tokens_per_expert=max_tokens_per_expert,
+                    tokens_per_expert=expert_tokens_count,
+                    padding_multiplex=128,
                 )
                 tokens_recovered, probs_recovered = TDU.tokens_zip(
                     (unzipped_tokens * unzipped_probs.unsqueeze(-1)).astype("bfloat16"),
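On the caller side, the visible change in the test is the final keyword arguments: the single max_tokens_per_expert value is gone and the op now takes the raw per-expert counts plus a padding multiple. A hedged sketch of the call shape, assuming the Python wrapper is named TDU.tokens_unzip_stable and returns the outputs in the order registered by PD_BUILD_OP; only the keyword arguments below are confirmed by this diff:

# Assumed wrapper name and positional/return order (not shown in this diff);
# tokens_per_expert and padding_multiplex are the arguments added by this commit.
unzipped_tokens, zipped_expertwise_rowmap, unzipped_probs, xscale_unzipped = TDU.tokens_unzip_stable(
    tokens,
    tokens_scale,
    dispatched_indices,
    dispatched_probs,
    topk=topk,
    num_experts=expert_num,
    tokens_per_expert=expert_tokens_count,  # list of per-expert token counts
    padding_multiplex=128,                  # each expert's slice is padded to a multiple of 128 rows
)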
