@@ -438,6 +438,15 @@ struct FastInterleavedAndBiasedNumericArrayConverter<bfloat16_t, uint4b_t, N>
   }
 };
 
+template <int lut>
+__device__ inline int lop3(int a, int b, int c) {
+  int res;
+  asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
+               : "=r"(res)
+               : "r"(a), "r"(b), "r"(c), "n"(lut));
+  return res;
+}
+
 template <typename T>
 struct FastInterleavedAndBiasedNumericArrayConverter<T, uint2b_t, 16>
 {
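Note on the helper above: `lop3.b32` evaluates an arbitrary three-input bitwise function, and the 8-bit immediate `lut` is that function's truth table, conventionally obtained by evaluating the desired expression on the canonical constants a = 0xF0, b = 0xCC, c = 0xAA. A minimal host-side sketch (illustrative only, not part of the patch) of how the LUT used throughout this diff is formed:

#include <cstdint>
#include <cstdio>

int main() {
  constexpr uint8_t a = 0xF0, b = 0xCC, c = 0xAA;
  // The template argument used at every lop3 call site below:
  constexpr uint8_t kSelectLut = (a & b) | c;  // == 0xEA
  std::printf("immLut = 0x%02X\n", static_cast<unsigned>(kSelectLut));
  return 0;
}

With that LUT, each `lop3<0xEA>(q, MASK, EX)` call in the diff computes `(q & MASK) | EX` in a single instruction, folding an AND and an OR together.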
@@ -458,24 +467,84 @@ struct FastInterleavedAndBiasedNumericArrayConverter<T, uint2b_t, 16>
     result_type result;
     uint8_t const *in_ptr = reinterpret_cast<uint8_t const *>(&source);
 
-    CUTLASS_PRAGMA_UNROLL
-    for (int i = 0; i < 4; ++i) {
-      int32_t decode_value =
-          static_cast<int32_t>(floor(static_cast<ScaleComputeT>(in_ptr[i]) * code_scale + code_zp + 0.5f));
-
-      ScaleComputeT value_3 = static_cast<ScaleComputeT>((decode_value & kWeightMask) - kBZP);
-      decode_value >>= 3;
-      ScaleComputeT value_2 = static_cast<ScaleComputeT>((decode_value & kWeightMask) - kBZP);
-      decode_value >>= 3;
-      ScaleComputeT value_1 = static_cast<ScaleComputeT>((decode_value & kWeightMask) - kBZP);
-      decode_value >>= 3;
-      ScaleComputeT value_0 = static_cast<ScaleComputeT>((decode_value & kWeightMask) - kBZP);
-
-      result[i * 4] = static_cast<T>(value_0);
-      result[i * 4 + 1] = static_cast<T>(value_1);
-      result[i * 4 + 2] = static_cast<T>(value_2);
-      result[i * 4 + 3] = static_cast<T>(value_3);
-    }
+    int32_t decode_value0 =
+        static_cast<int32_t>(floor(static_cast<ScaleComputeT>(in_ptr[0]) * code_scale + code_zp + 0.5f));
+    int32_t decode_value1 =
+        static_cast<int32_t>(floor(static_cast<ScaleComputeT>(in_ptr[1]) * code_scale + code_zp + 0.5f));
+    int32_t decode_value2 =
+        static_cast<int32_t>(floor(static_cast<ScaleComputeT>(in_ptr[2]) * code_scale + code_zp + 0.5f));
+    int32_t decode_value3 =
+        static_cast<int32_t>(floor(static_cast<ScaleComputeT>(in_ptr[3]) * code_scale + code_zp + 0.5f));
+
+    static constexpr uint32_t MASK = 0x003F003F;
+    static constexpr uint32_t EX = 0x43004300;
+    uint32_t *h = reinterpret_cast<uint32_t *>(&result);
+    int32_t q;
+
+#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) && defined(ENABLE_BF16))
+
+    static constexpr uint32_t SUB = 0x43204320;
+
+    q = (decode_value1 << 16) | (decode_value0 & 0xFFFF);
+    int lo3 = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX);
+    q >>= 3;
+    int lo2 = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX);
+    q >>= 3;
+    int lo1 = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX);
+    q >>= 3;
+    int lo0 = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX);
+
+    asm volatile("sub.bf16x2 %0, %1, %2;\n" : "=r"(h[0]) : "r"(lo0), "r"(SUB));
+    asm volatile("sub.bf16x2 %0, %1, %2;\n" : "=r"(h[1]) : "r"(lo1), "r"(SUB));
+    asm volatile("sub.bf16x2 %0, %1, %2;\n" : "=r"(h[2]) : "r"(lo2), "r"(SUB));
+    asm volatile("sub.bf16x2 %0, %1, %2;\n" : "=r"(h[3]) : "r"(lo3), "r"(SUB));
+
+    q = (decode_value3 << 16) | (decode_value2 & 0xFFFF);
+    lo3 = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX);
+    q >>= 3;
+    lo2 = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX);
+    q >>= 3;
+    lo1 = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX);
+    q >>= 3;
+    lo0 = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX);
+
+    asm volatile("sub.bf16x2 %0, %1, %2;\n" : "=r"(h[4]) : "r"(lo0), "r"(SUB));
+    asm volatile("sub.bf16x2 %0, %1, %2;\n" : "=r"(h[5]) : "r"(lo1), "r"(SUB));
+    asm volatile("sub.bf16x2 %0, %1, %2;\n" : "=r"(h[6]) : "r"(lo2), "r"(SUB));
+    asm volatile("sub.bf16x2 %0, %1, %2;\n" : "=r"(h[7]) : "r"(lo3), "r"(SUB));
+#else
+
+    static constexpr uint32_t MUL = 0x3F803F80;
+    static constexpr uint32_t ADD = 0xC320C320;
+
+    q = (decode_value1 << 16) | (decode_value0 & 0xFFFF);
+    int lo3 = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX);
+    q >>= 3;
+    int lo2 = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX);
+    q >>= 3;
+    int lo1 = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX);
+    q >>= 3;
+    int lo0 = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX);
+
+    asm volatile("fma.rn.bf16x2 %0, %1, %2, %3;\n" : "=r"(h[0]) : "r"(lo0), "r"(MUL), "r"(ADD));
+    asm volatile("fma.rn.bf16x2 %0, %1, %2, %3;\n" : "=r"(h[1]) : "r"(lo1), "r"(MUL), "r"(ADD));
+    asm volatile("fma.rn.bf16x2 %0, %1, %2, %3;\n" : "=r"(h[2]) : "r"(lo2), "r"(MUL), "r"(ADD));
+    asm volatile("fma.rn.bf16x2 %0, %1, %2, %3;\n" : "=r"(h[3]) : "r"(lo3), "r"(MUL), "r"(ADD));
+
+    q = (decode_value3 << 16) | (decode_value2 & 0xFFFF);
+    lo3 = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX);
+    q >>= 3;
+    lo2 = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX);
+    q >>= 3;
+    lo1 = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX);
+    q >>= 3;
+    lo0 = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX);
+
+    asm volatile("fma.rn.bf16x2 %0, %1, %2, %3;\n" : "=r"(h[4]) : "r"(lo0), "r"(MUL), "r"(ADD));
+    asm volatile("fma.rn.bf16x2 %0, %1, %2, %3;\n" : "=r"(h[5]) : "r"(lo1), "r"(MUL), "r"(ADD));
+    asm volatile("fma.rn.bf16x2 %0, %1, %2, %3;\n" : "=r"(h[6]) : "r"(lo2), "r"(MUL), "r"(ADD));
+    asm volatile("fma.rn.bf16x2 %0, %1, %2, %3;\n" : "=r"(h[7]) : "r"(lo3), "r"(MUL), "r"(ADD));
+#endif
 
     return result;
   }
 
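Why the magic constants work: OR-ing a 6-bit field into `EX = 0x4300` produces the bf16 pattern for `128.0f + field` (0x4300 is 128.0 in bf16, and the field lands in the low mantissa bits), so subtracting `SUB = 0x4320` (160.0 in bf16) leaves `field - 32`, matching the removed scalar path `(decode_value & kWeightMask) - kBZP` provided `kBZP` is 32, which these constants imply. The fallback branch reaches the same value via `fma.rn.bf16x2`: `MUL = 0x3F80` is 1.0 and `ADD = 0xC320` is -160.0, so `(128 + field) * 1.0 - 160 = field - 32`. A CPU-side sketch (illustrative only, assuming kBZP == 32) that checks both identities for all 64 field values:

#include <cstdint>
#include <cstdio>
#include <cstring>

// Reinterpret a 16-bit bf16 pattern as float by placing it in the high half
// of a 32-bit word (bf16 is the top 16 bits of an IEEE-754 binary32).
static float bf16_bits_to_float(uint16_t bits) {
  uint32_t u = static_cast<uint32_t>(bits) << 16;
  float f;
  std::memcpy(&f, &u, sizeof(f));
  return f;
}

int main() {
  for (int field = 0; field < 64; ++field) {
    // OR-ing the field into 0x4300 yields 128.0f + field.
    float packed = bf16_bits_to_float(static_cast<uint16_t>(0x4300 | field));
    float sub_path = packed - bf16_bits_to_float(0x4320);  // minus 160.0f
    float fma_path = packed * bf16_bits_to_float(0x3F80)   // times 1.0f ...
                     + bf16_bits_to_float(0xC320);         // ... minus 160.0f
    if (sub_path != field - 32.0f || fma_path != field - 32.0f) {
      std::printf("mismatch at field=%d\n", field);
      return 1;
    }
  }
  std::printf("all 64 fields decode to field - 32\n");
  return 0;
}

Every intermediate here (128 + field, 160, field - 32) is exactly representable in bf16, so the device-side bf16x2 rounding cannot perturb the result.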