
Commit ab69953
Author: Zhaowu Pan
Add FP32 support for zip op, optimize precision in bf16. (#10433)
1 parent d3ea4ec commit ab69953
File tree

2 files changed (+86 -29 lines)

slm/model_zoo/gpt-3/external_ops/token_dispatcher_utils/tokens_unzip_and_zip.cu

Lines changed: 85 additions & 28 deletions
@@ -292,36 +292,43 @@ __global__ void tokens_zip_kernel(
          x_offset < num_full_vec * vecSize;
          x_offset += thread_stride) {
       float2 sum = {0.0f, 0.0f};
+      __nv_bfloat162 raw = {0, 0};
+      int aggreg_cnt = 0;
       __nv_bfloat162 *out_ptr = reinterpret_cast<__nv_bfloat162 *>(
           &zipped_tokens[this_row * token_length + x_offset]);
 #pragma unroll
       for (int expert = 0; expert < num_experts; ++expert) {
         const int fetch_row = local_row_fetchlist[expert];
         if (fetch_row < 0) continue;
+        aggreg_cnt++;
         // manual type promotion
+        raw = *reinterpret_cast<const __nv_bfloat162 *>(
+            &unzipped_tokens[fetch_row * token_length + x_offset]);
         float2 token_vec =
-            __bfloat1622float2(*reinterpret_cast<const __nv_bfloat162 *>(
-                &unzipped_tokens[fetch_row * token_length + x_offset]));
+            __bfloat1622float2(raw);
         sum.x = __fadd_rn(token_vec.x, sum.x);
         sum.y = __fadd_rn(token_vec.y, sum.y);
       }
-      // downcast back to the original precision
-      *out_ptr = __float22bfloat162_rn(sum);
+      // selectively downcast back to the original precision
+      *out_ptr = (aggreg_cnt > 1) ? __float22bfloat162_rn(sum) : raw;
     }
 
     // handle the remaining (tail) elements
     for (int i = num_full_vec * vecSize + threadIdx.x; i < token_length;
          i += blockDim.x) {
       float sum = 0.0f;
+      __nv_bfloat16 raw = 0;
+      int aggreg_cnt = 0;
 #pragma unroll
       for (int expert = 0; expert < num_experts; ++expert) {
         int fetch_row = local_row_fetchlist[expert];
         if (fetch_row < 0) continue;
-        float token_val =
-            __bfloat162float(unzipped_tokens[fetch_row * token_length + i]);
+        aggreg_cnt++;
+        raw = unzipped_tokens[fetch_row * token_length + i];
+        float token_val = __bfloat162float(raw);
         sum = __fadd_rn(token_val, sum);
       }
-      zipped_tokens[this_row * token_length + i] = __float2bfloat16_rn(sum);
+      zipped_tokens[this_row * token_length + i] = (aggreg_cnt > 1) ? __float2bfloat16_rn(sum) : raw;
     }
   } else {
     // ------------------------ BF16 intrinsics weighted accumulation -----------------------
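Note on the bf16 precision change above: accumulation still happens in fp32, but the new aggreg_cnt bookkeeping lets the kernel write the original bf16 bits back untouched whenever only a single expert row contributed, instead of taking a redundant fp32 -> bf16 rounding round trip. A minimal, self-contained sketch of that write-back rule (the kernel name and host harness are illustrative, not part of this patch):

// selective_downcast_demo.cu -- sums n bf16 values in fp32 and only rounds
// back to bf16 when more than one value actually contributed.
#include <cstdio>
#include <cuda_bf16.h>

__global__ void selective_downcast_demo(const __nv_bfloat16 *in, int n,
                                        __nv_bfloat16 *out) {
  float sum = 0.0f;
  __nv_bfloat16 raw = in[0];
  int aggreg_cnt = 0;
  for (int i = 0; i < n; ++i) {
    ++aggreg_cnt;
    raw = in[i];
    sum = __fadd_rn(__bfloat162float(raw), sum);
  }
  // Single contributor: pass the original bf16 bits through untouched.
  *out = (aggreg_cnt > 1) ? __float2bfloat16_rn(sum) : raw;
}

int main() {
  const int n = 4;
  __nv_bfloat16 h_in[n], h_out;
  for (int i = 0; i < n; ++i) h_in[i] = __float2bfloat16(0.1f * (i + 1));
  __nv_bfloat16 *d_in, *d_out;
  cudaMalloc(&d_in, n * sizeof(__nv_bfloat16));
  cudaMalloc(&d_out, sizeof(__nv_bfloat16));
  cudaMemcpy(d_in, h_in, n * sizeof(__nv_bfloat16), cudaMemcpyHostToDevice);
  selective_downcast_demo<<<1, 1>>>(d_in, n, d_out);
  cudaMemcpy(&h_out, d_out, sizeof(__nv_bfloat16), cudaMemcpyDeviceToHost);
  printf("zipped value: %f\n", __bfloat162float(h_out));
  cudaFree(d_in);
  cudaFree(d_out);
  return 0;
}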
@@ -358,6 +365,55 @@ __global__ void tokens_zip_kernel(
     }
   }
 }
+template <int topk, int num_experts>
+__global__ void tokens_zip_kernel(
+    const float *__restrict__ unzipped_tokens,
+    const int *__restrict__ zipped_expertwise_rowmap,
+    const int *__restrict__ expert_routemap_topk,
+    const float *__restrict__ unzipped_token_probs,
+    float *__restrict__ zipped_tokens,
+    float *__restrict__ zipped_probs_topk,
+    const int total_zipped_tokens_num,
+    const int token_length) {
+  const int this_row = blockIdx.x;
+  if (this_row >= total_zipped_tokens_num) return;
+  int local_row_fetchlist[num_experts];
+
+  // ------------------------- initialize the fetch task table ------------------------
+#pragma unroll
+  for (int expert = 0; expert < num_experts; ++expert) {
+    const int fetch_row =
+        zipped_expertwise_rowmap[this_row * num_experts + expert];
+    local_row_fetchlist[expert] = fetch_row;
+  }
+
+#pragma unroll
+  for (int k = 0; k < topk; ++k) {
+    const int expert_idx = expert_routemap_topk[this_row * topk + k];
+    if (expert_idx < 0) [[likely]]
+      continue;
+    const int expert_fetch_row = local_row_fetchlist[expert_idx];
+    zipped_probs_topk[this_row * topk + k] =
+        unzipped_token_probs[expert_fetch_row];
+  }
+
+  const int thread_stride = blockDim.x;
+
+  // ------------------------ manual mixed precision ---------------------------------
+  // vectorized copy over the aligned region
+  for (int x_offset = threadIdx.x; x_offset < token_length;
+       x_offset += thread_stride) {
+    float sum = 0.0f;
+#pragma unroll
+    for (int expert = 0; expert < num_experts; ++expert) {
+      const int fetch_row = local_row_fetchlist[expert];
+      if (fetch_row < 0) continue;
+      // manual type promotion
+      sum += unzipped_tokens[fetch_row * token_length + x_offset];
+    }
+    zipped_tokens[this_row * token_length + x_offset] = sum;
+  }
+}
 
 // ---------------------------- Dispatch ---------------------------------
 void dispatch_tokens_unzip(const paddle::Tensor &X,
@@ -435,17 +491,6 @@ void dispatch_tokens_unzip(const paddle::Tensor &X,
 #undef HANDLE_PROB_TYPE
 }
 
-/*
-dispatch_tokens_zip(unzipped_tokens,
-                    zipped_expertwise_rowmap,
-                    expert_routemap_topk,
-                    unzipped_token_probs,
-                    zipped_tokens,
-                    zipped_probs_topk,
-                    total_zipped_tokens_num,
-                    num_experts,
-                    cols);
-*/
 void dispatch_tokens_zip(const paddle::Tensor &unzipped_tokens,
                          const paddle::Tensor &zipped_expertwise_rowmap,
                          const paddle::Tensor &expert_routemap_topk,
@@ -462,15 +507,27 @@ void dispatch_tokens_zip(const paddle::Tensor &unzipped_tokens,
 
   // Map data types to C++ types
   if (topk == 8 && num_experts == 4) {
-    tokens_zip_kernel<8, 4><<<grid, block, 0, unzipped_tokens.stream()>>>(
-        unzipped_tokens.data<phi::bfloat16>(),
-        zipped_expertwise_rowmap.data<int>(),
-        expert_routemap_topk.data<int>(),
-        unzipped_token_probs.data<float>(),
-        zipped_tokens.data<phi::bfloat16>(),
-        zipped_probs_topk.data<float>(),
-        total_zipped_tokens_num,
-        token_length);
+    if (unzipped_tokens.dtype() == paddle::DataType::BFLOAT16) {
+      tokens_zip_kernel<8, 4><<<grid, block, 0, unzipped_tokens.stream()>>>(
+          unzipped_tokens.data<phi::bfloat16>(),
+          zipped_expertwise_rowmap.data<int>(),
+          expert_routemap_topk.data<int>(),
+          unzipped_token_probs.data<float>(),
+          zipped_tokens.data<phi::bfloat16>(),
+          zipped_probs_topk.data<float>(),
+          total_zipped_tokens_num,
+          token_length);
+    } else if (unzipped_tokens.dtype() == paddle::DataType::FLOAT32) {
+      tokens_zip_kernel<8, 4><<<grid, block, 0, unzipped_tokens.stream()>>>(
+          unzipped_tokens.data<float>(),
+          zipped_expertwise_rowmap.data<int>(),
+          expert_routemap_topk.data<int>(),
+          unzipped_token_probs.data<float>(),
+          zipped_tokens.data<float>(),
+          zipped_probs_topk.data<float>(),
+          total_zipped_tokens_num,
+          token_length);
+    }
   }
 }

@@ -538,7 +595,7 @@ std::vector<paddle::Tensor> tokens_zip(
     const paddle::Tensor &unzipped_token_probs,
     const int &total_zipped_tokens_num,
     const int &num_experts) {
-  PD_CHECK(unzipped_tokens.dtype() == paddle::DataType::BFLOAT16);
+  PD_CHECK(unzipped_tokens.dtype() == paddle::DataType::BFLOAT16 || unzipped_tokens.dtype() == paddle::DataType::FLOAT32);
   const int rows = unzipped_tokens.shape()[0];  // seqlen
   const int cols = unzipped_tokens.shape()[1];  // typically 7168
   const int topk = expert_routemap_topk.shape()[1];  // typically 8
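Note on the FP32 path added above: the new float kernel is a plain overload of tokens_zip_kernel, so the dispatcher only needs to branch on the runtime dtype and pass float pointers; ordinary C++ overload resolution then picks the matching instantiation, and accumulation runs natively in fp32 with no promotion or demotion. A minimal sketch of that dispatch pattern (zip_like, launch_zip_like, and DType are illustrative names, not the repo's API):

// dtype_dispatch_sketch.cu -- one kernel name overloaded on element type,
// with a host-side branch choosing which overload to launch.
#include <cuda_bf16.h>

template <int num_experts>
__global__ void zip_like(const __nv_bfloat16 *in, __nv_bfloat16 *out, int n) {
  // bf16 overload: stand-in body; the real kernel promotes to fp32,
  // accumulates, and selectively downcasts on write-back.
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) out[i] = in[i];
}

template <int num_experts>
__global__ void zip_like(const float *in, float *out, int n) {
  // fp32 overload: stand-in body; the real kernel accumulates natively in fp32.
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) out[i] = in[i];
}

enum class DType { kBF16, kFP32 };

void launch_zip_like(DType dtype, const void *in, void *out, int n,
                     cudaStream_t stream) {
  const int block = 256;
  const int grid = (n + block - 1) / block;
  if (dtype == DType::kBF16) {
    // bf16 pointers select the bf16 overload.
    zip_like<4><<<grid, block, 0, stream>>>(
        static_cast<const __nv_bfloat16 *>(in),
        static_cast<__nv_bfloat16 *>(out), n);
  } else if (dtype == DType::kFP32) {
    // float pointers select the fp32 overload.
    zip_like<4><<<grid, block, 0, stream>>>(static_cast<const float *>(in),
                                            static_cast<float *>(out), n);
  }
}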

tests/ops/test_tokens_zip.py

Lines changed: 1 addition & 1 deletion
@@ -113,7 +113,7 @@ def verify_tokens_unzip():
     unzipped_tokens, zipped_expertwise_rowmap, unzipped_probs, unzipped_expert_idx = TDU.tokens_unzip(tokens_zipped,routemap_topk, probs_topk,total_unzipped_tokens_num=total_unzipped_tokens_num, topk=topk, num_experts=expert_num)
 
     # this op
-    zipped_tokens, zipped_probs_topk = TDU.tokens_zip(unzipped_tokens, zipped_expertwise_rowmap, routemap_topk, unzipped_probs, total_zipped_tokens=seqlen, num_experts=expert_num)
+    zipped_tokens, zipped_probs_topk = TDU.tokens_zip(unzipped_tokens.astype("float32"), zipped_expertwise_rowmap, routemap_topk, unzipped_probs, total_zipped_tokens=seqlen, num_experts=expert_num)
     # ------------------------- forward verification ------------------------
     print("-------- Tokens unzipped by customed op: ------------")
     print(unzipped_tokens)
