@@ -292,36 +292,43 @@ __global__ void tokens_zip_kernel(
292
292
x_offset < num_full_vec * vecSize;
293
293
x_offset += thread_stride) {
294
294
float2 sum = {0 .0f , 0 .0f };
295
+ __nv_bfloat162 raw = {0 ,0 };
296
+ int aggreg_cnt = 0 ;
295
297
__nv_bfloat162 *out_ptr = reinterpret_cast <__nv_bfloat162 *>(
296
298
&zipped_tokens[this_row * token_length + x_offset]);
297
299
#pragma unroll
298
300
for (int expert = 0 ; expert < num_experts; ++expert) {
299
301
const int fetch_row = local_row_fetchlist[expert];
300
302
if (fetch_row < 0 ) continue ;
303
+ aggreg_cnt ++;
301
304
// 手动类型提升
305
+ raw = *reinterpret_cast <const __nv_bfloat162 *>(
306
+ &unzipped_tokens[fetch_row * token_length + x_offset]);
302
307
float2 token_vec =
303
- __bfloat1622float2 (*reinterpret_cast <const __nv_bfloat162 *>(
304
- &unzipped_tokens[fetch_row * token_length + x_offset]));
308
+ __bfloat1622float2 (raw);
305
309
sum.x = __fadd_rn (token_vec.x , sum.x );
306
310
sum.y = __fadd_rn (token_vec.y , sum.y );
307
311
}
308
- // 类型下降为原有精度
309
- *out_ptr = __float22bfloat162_rn (sum);
312
+ // 选择性类型下降为原有精度
313
+ *out_ptr = (aggreg_cnt > 1 ) ? __float22bfloat162_rn (sum) : raw ;
310
314
}
311
315
312
316
// 剩余元素处理
313
317
for (int i = num_full_vec * vecSize + threadIdx .x ; i < token_length;
314
318
i += blockDim .x ) {
315
319
float sum = 0 .0f ;
320
+ __nv_bfloat16 raw = 0 ;
321
+ int aggreg_cnt = 0 ;
316
322
#pragma unroll
317
323
for (int expert = 0 ; expert < num_experts; ++expert) {
318
324
int fetch_row = local_row_fetchlist[expert];
319
325
if (fetch_row < 0 ) continue ;
320
- float token_val =
321
- __bfloat162float (unzipped_tokens[fetch_row * token_length + i]);
326
+ aggreg_cnt ++;
327
+ raw = unzipped_tokens[fetch_row * token_length + i];
328
+ float token_val = __bfloat162float (raw);
322
329
sum = __fadd_rn (token_val, sum);
323
330
}
324
- zipped_tokens[this_row * token_length + i] = __float2bfloat16_rn (sum);
331
+ zipped_tokens[this_row * token_length + i] = (aggreg_cnt > 1 )? __float2bfloat16_rn (sum) : raw ;
325
332
}
326
333
} else {
327
334
// ------------------------ BF16 intrinsics 加权累加 -----------------------
@@ -358,6 +365,55 @@ __global__ void tokens_zip_kernel(
358
365
}
359
366
}
360
367
}
368
// Float32 specialization of tokens_zip_kernel: for each zipped output row,
// gathers the per-expert source rows from `unzipped_tokens`, accumulates them
// into `zipped_tokens`, and scatters the row's top-k routing probabilities
// into `zipped_probs_topk`.
//
// Launch layout: one block per zipped output row (blockIdx.x == row index);
// threads within the block stride over the token feature dimension.
// No shared memory is used.
template <int topk, int num_experts>
__global__ void tokens_zip_kernel(
    const float *__restrict__ unzipped_tokens,
    const int *__restrict__ zipped_expertwise_rowmap,
    const int *__restrict__ expert_routemap_topk,
    const float *__restrict__ unzipped_token_probs,
    float *__restrict__ zipped_tokens,
    float *__restrict__ zipped_probs_topk,
    const int total_zipped_tokens_num,
    const int token_length) {
  const int this_row = blockIdx.x;
  if (this_row >= total_zipped_tokens_num) return;

  // ----------------------- Build the per-expert fetch list -----------------
  // local_row_fetchlist[e] is the source row in unzipped_tokens contributed by
  // expert e to this output row; a negative value means expert e does not
  // route to this row (see the `fetch_row < 0` skip below).
  int local_row_fetchlist[num_experts];
#pragma unroll
  for (int expert = 0; expert < num_experts; ++expert) {
    local_row_fetchlist[expert] =
        zipped_expertwise_rowmap[this_row * num_experts + expert];
  }

  // ----------------------- Scatter the top-k probabilities -----------------
  // Only thread 0 writes: every thread of the block would compute identical
  // values, so the blockDim.x-fold redundant global stores are skipped. topk
  // is a small compile-time constant, so one thread handling all k is cheap.
  if (threadIdx.x == 0) {
#pragma unroll
    for (int k = 0; k < topk; ++k) {
      const int expert_idx = expert_routemap_topk[this_row * topk + k];
      if (expert_idx < 0) [[likely]]
        continue;
      // NOTE(review): assumes 0 <= expert_idx < num_experts and that the
      // mapped fetch row is non-negative whenever expert_idx is valid —
      // confirm against the producer of expert_routemap_topk /
      // zipped_expertwise_rowmap; an inconsistent map reads out of bounds.
      const int expert_fetch_row = local_row_fetchlist[expert_idx];
      zipped_probs_topk[this_row * topk + k] =
          unzipped_token_probs[expert_fetch_row];
    }
  }

  // ----------------------- Accumulate the token row ------------------------
  // Pure float path: unlike the bf16 kernel, no precision promotion/demotion
  // is needed, and elements are processed scalar-wise (no vector loads).
  // Adjacent threads touch adjacent columns, so global accesses coalesce.
  const int thread_stride = blockDim.x;
  for (int x_offset = threadIdx.x; x_offset < token_length;
       x_offset += thread_stride) {
    float sum = 0.0f;
#pragma unroll
    for (int expert = 0; expert < num_experts; ++expert) {
      const int fetch_row = local_row_fetchlist[expert];
      if (fetch_row < 0) continue;  // expert does not contribute to this row
      sum += unzipped_tokens[fetch_row * token_length + x_offset];
    }
    zipped_tokens[this_row * token_length + x_offset] = sum;
  }
}
361
417
362
418
// ---------------------------- Dispatch ---------------------------------
363
419
void dispatch_tokens_unzip (const paddle::Tensor &X,
@@ -435,17 +491,6 @@ void dispatch_tokens_unzip(const paddle::Tensor &X,
435
491
#undef HANDLE_PROB_TYPE
436
492
}
437
493
438
- /*
439
- dispatch_tokens_zip(unzipped_tokens,
440
- zipped_expertwise_rowmap,
441
- expert_routemap_topk,
442
- unzipped_token_probs,
443
- zipped_tokens,
444
- zipped_probs_topk,
445
- total_zipped_tokens_num,
446
- num_experts,
447
- cols);
448
- */
449
494
void dispatch_tokens_zip (const paddle::Tensor &unzipped_tokens,
450
495
const paddle::Tensor &zipped_expertwise_rowmap,
451
496
const paddle::Tensor &expert_routemap_topk,
@@ -462,15 +507,27 @@ void dispatch_tokens_zip(const paddle::Tensor &unzipped_tokens,
462
507
463
508
// Map data types to C++ types
464
509
if (topk == 8 && num_experts == 4 ) {
465
- tokens_zip_kernel<8 , 4 ><<<grid, block, 0 , unzipped_tokens.stream()>>> (
466
- unzipped_tokens.data <phi::bfloat16>(),
467
- zipped_expertwise_rowmap.data <int >(),
468
- expert_routemap_topk.data <int >(),
469
- unzipped_token_probs.data <float >(),
470
- zipped_tokens.data <phi::bfloat16>(),
471
- zipped_probs_topk.data <float >(),
472
- total_zipped_tokens_num,
473
- token_length);
510
+ if (unzipped_tokens.dtype () == paddle::DataType::BFLOAT16){
511
+ tokens_zip_kernel<8 , 4 ><<<grid, block, 0 , unzipped_tokens.stream()>>> (
512
+ unzipped_tokens.data <phi::bfloat16>(),
513
+ zipped_expertwise_rowmap.data <int >(),
514
+ expert_routemap_topk.data <int >(),
515
+ unzipped_token_probs.data <float >(),
516
+ zipped_tokens.data <phi::bfloat16>(),
517
+ zipped_probs_topk.data <float >(),
518
+ total_zipped_tokens_num,
519
+ token_length);
520
+ }else if (unzipped_tokens.dtype () == paddle::DataType::FLOAT32){
521
+ tokens_zip_kernel<8 , 4 ><<<grid, block, 0 , unzipped_tokens.stream()>>> (
522
+ unzipped_tokens.data <float >(),
523
+ zipped_expertwise_rowmap.data <int >(),
524
+ expert_routemap_topk.data <int >(),
525
+ unzipped_token_probs.data <float >(),
526
+ zipped_tokens.data <float >(),
527
+ zipped_probs_topk.data <float >(),
528
+ total_zipped_tokens_num,
529
+ token_length);
530
+ }
474
531
}
475
532
}
476
533
@@ -538,7 +595,7 @@ std::vector<paddle::Tensor> tokens_zip(
538
595
const paddle::Tensor &unzipped_token_probs,
539
596
const int &total_zipped_tokens_num,
540
597
const int &num_experts) {
541
- PD_CHECK (unzipped_tokens.dtype () == paddle::DataType::BFLOAT16);
598
+ PD_CHECK (unzipped_tokens.dtype () == paddle::DataType::BFLOAT16 || unzipped_tokens. dtype () == paddle::DataType::FLOAT32 );
542
599
const int rows = unzipped_tokens.shape ()[0 ]; // seqlen
543
600
const int cols = unzipped_tokens.shape ()[1 ]; // 一般为7168
544
601
const int topk = expert_routemap_topk.shape ()[1 ]; // 一般为8
0 commit comments