
Commit 85e337a

optimizing dequant performance with LOP3
1 parent 3caf310 commit 85e337a

2 files changed: +92 -23 lines

custom_ops/gpu_ops/cutlass_extensions/gemm/warp/mma_tensorop_wint2x_dequantizer.h

Lines changed: 5 additions & 5 deletions
@@ -284,21 +284,21 @@ class MmaTensorOpWin2xDequantizer<
                            static_cast<int>(sizeof(FragmentCompute)));
     }
 #endif
-
     int offset = warp_k_compute_offset * ArchMmaOperator::FragmentB::kElements;
-    const int kOutputColumns = FragmentOutput::kElements / kWarpIterationsAlongN;
+    int mapped_offset = (warp_k_compute_offset % 2) == 0 ? 0 : (-kOutputColumns + 1);
 
     CUTLASS_PRAGMA_UNROLL
-    for (int mma_n_iter = 0; mma_n_iter < kWarpIterationsAlongN; ++mma_n_iter) {
+    for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; ++mma_n_iter) {
 
       CUTLASS_PRAGMA_UNROLL
       for (int j = 0; j < kOutputColumns; ++j) {
+        // After applying LOP3 optimizations for performance, the B operand requires data rearrangement.
+        int mapped_idx = mma_n_iter * kExpansionFactor * kOutputColumns + offset + 2 * j + mapped_offset;
         ElementCompute scaled_value =
-            static_cast<ElementCompute>(unpacked_frag_[mma_n_iter * kExpansionFactor * kOutputColumns + offset + j]) * scale_frag[mma_n_iter];
+            static_cast<ElementCompute>(unpacked_frag_[mapped_idx]) * scale_frag[mma_n_iter];
         output_frag[mma_n_iter * kOutputColumns + j] = static_cast<ElementOperand>(scaled_value);
       }
     }
-
 #if 0
     if (FragmentOutput::kElements == 16) {
       CUTLASS_TRACE_DEVICE(" [stage=%d] output_frag[0:15]=[%f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f]",
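
The remapped index in this loop mirrors the output layout of the LOP3-based unpacker added in interleaved_numeric_conversion.h below: two 16-bit decode words are processed per 32-bit register there, so consecutive logical elements come out interleaved across even and odd slots. A small host-side sketch of the mapping, with kOutputColumns and the B-fragment size replaced by assumed toy values (the real ones are template parameters, and the mma_n_iter * kExpansionFactor * kOutputColumns term is dropped for brevity):

#include <cstdio>

int main() {
    const int kOutputColumns = 4;  // assumed toy value
    const int kElementsB = 4;      // stand-in for ArchMmaOperator::FragmentB::kElements
    for (int warp_k = 0; warp_k < 2; ++warp_k) {
        int offset = warp_k * kElementsB;
        // Odd k-steps read the odd interleaved slots, hence -kOutputColumns + 1.
        int mapped_offset = (warp_k % 2) == 0 ? 0 : (-kOutputColumns + 1);
        for (int j = 0; j < kOutputColumns; ++j) {
            int old_idx = offset + j;                      // pre-commit layout
            int new_idx = offset + 2 * j + mapped_offset;  // post-LOP3 layout
            std::printf("k=%d j=%d: old %d -> new %d\n", warp_k, j, old_idx, new_idx);
        }
    }
    return 0;
}

With these toy values the even k-step reads slots 0, 2, 4, 6 and the odd k-step reads 1, 3, 5, 7, i.e. the two k-steps split one interleaved 8-element fragment between them.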

custom_ops/gpu_ops/cutlass_extensions/interleaved_numeric_conversion.h

Lines changed: 87 additions & 18 deletions
@@ -438,6 +438,15 @@ struct FastInterleavedAndBiasedNumericArrayConverter<bfloat16_t, uint4b_t, N>
     }
 };
 
+template <int lut>
+__device__ inline int lop3(int a, int b, int c) {
+    int res;
+    asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
+                 : "=r"(res)
+                 : "r"(a), "r"(b), "r"(c), "n"(lut));
+    return res;
+}
+
 template <typename T>
 struct FastInterleavedAndBiasedNumericArrayConverter<T, uint2b_t, 16>
 {
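
For reference: lop3.b32 evaluates an arbitrary three-input bitwise function selected by its 8-bit immediate truth table, where the immediate is conventionally written in terms of 0xf0 (operand a), 0xcc (operand b), and 0xaa (operand c). The LUT used throughout this commit, (0xf0 & 0xcc) | 0xaa = 0xea, therefore computes (a & b) | c per bit: a mask-and-merge in a single instruction. A minimal sketch of that equivalence (the reference function name is mine, not part of the commit):

// The immediate encodes the boolean expression applied to the canonical
// operand patterns a = 0xf0, b = 0xcc, c = 0xaa.
static_assert(((0xf0 & 0xcc) | 0xaa) == 0xea, "LUT immediate used in this commit");

// Plain reference for what lop3<0xea>(a, b, c) returns in one instruction:
// keep the bits of a selected by mask b, then OR in c.
__device__ inline int lop3_0xea_reference(int a, int b, int c) {
    return (a & b) | c;
}

In the converter below, a is the packed integer q, b is MASK (selecting one 6-bit field per 16-bit half), and c is EX (the bf16 exponent bits ORed into the result).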
@@ -458,24 +467,84 @@ struct FastInterleavedAndBiasedNumericArrayConverter<T, uint2b_t, 16>
         result_type result;
         uint8_t const* in_ptr = reinterpret_cast<uint8_t const*>(&source);
 
-        CUTLASS_PRAGMA_UNROLL
-        for (int i = 0; i < 4; ++i) {
-            int32_t decode_value =
-                static_cast<int32_t>(floor(static_cast<ScaleComputeT>(in_ptr[i]) * code_scale + code_zp + 0.5f));
-
-            ScaleComputeT value_3 = static_cast<ScaleComputeT>((decode_value & kWeightMask) - kBZP);
-            decode_value >>= 3;
-            ScaleComputeT value_2 = static_cast<ScaleComputeT>((decode_value & kWeightMask) - kBZP);
-            decode_value >>= 3;
-            ScaleComputeT value_1 = static_cast<ScaleComputeT>((decode_value & kWeightMask) - kBZP);
-            decode_value >>= 3;
-            ScaleComputeT value_0 = static_cast<ScaleComputeT>((decode_value & kWeightMask) - kBZP);
-
-            result[i * 4] = static_cast<T>(value_0);
-            result[i * 4 + 1] = static_cast<T>(value_1);
-            result[i * 4 + 2] = static_cast<T>(value_2);
-            result[i * 4 + 3] = static_cast<T>(value_3);
-        }
+        int32_t decode_value0 =
+            static_cast<int32_t>(floor(static_cast<ScaleComputeT>(in_ptr[0]) * code_scale + code_zp + 0.5f));
+        int32_t decode_value1 =
+            static_cast<int32_t>(floor(static_cast<ScaleComputeT>(in_ptr[1]) * code_scale + code_zp + 0.5f));
+        int32_t decode_value2 =
+            static_cast<int32_t>(floor(static_cast<ScaleComputeT>(in_ptr[2]) * code_scale + code_zp + 0.5f));
+        int32_t decode_value3 =
+            static_cast<int32_t>(floor(static_cast<ScaleComputeT>(in_ptr[3]) * code_scale + code_zp + 0.5f));
+
+        static constexpr uint32_t MASK = 0x003F003F;
+        static constexpr uint32_t EX = 0x43004300;
+        uint32_t* h = reinterpret_cast<uint32_t*>(&result);
+        int32_t q;
+
+#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) && defined(ENABLE_BF16))
+
+        static constexpr uint32_t SUB = 0x43204320;
+
+        q = (decode_value1 << 16) | (decode_value0 & 0xFFFF);
+        int lo3 = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX);
+        q >>= 3;
+        int lo2 = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX);
+        q >>= 3;
+        int lo1 = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX);
+        q >>= 3;
+        int lo0 = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX);
+
+        asm volatile("sub.bf16x2 %0, %1, %2;\n" : "=r"(h[0]) : "r"(lo0), "r"(SUB));
+        asm volatile("sub.bf16x2 %0, %1, %2;\n" : "=r"(h[1]) : "r"(lo1), "r"(SUB));
+        asm volatile("sub.bf16x2 %0, %1, %2;\n" : "=r"(h[2]) : "r"(lo2), "r"(SUB));
+        asm volatile("sub.bf16x2 %0, %1, %2;\n" : "=r"(h[3]) : "r"(lo3), "r"(SUB));
+
+        q = (decode_value3 << 16) | (decode_value2 & 0xFFFF);
+        lo3 = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX);
+        q >>= 3;
+        lo2 = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX);
+        q >>= 3;
+        lo1 = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX);
+        q >>= 3;
+        lo0 = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX);
+
+        asm volatile("sub.bf16x2 %0, %1, %2;\n" : "=r"(h[4]) : "r"(lo0), "r"(SUB));
+        asm volatile("sub.bf16x2 %0, %1, %2;\n" : "=r"(h[5]) : "r"(lo1), "r"(SUB));
+        asm volatile("sub.bf16x2 %0, %1, %2;\n" : "=r"(h[6]) : "r"(lo2), "r"(SUB));
+        asm volatile("sub.bf16x2 %0, %1, %2;\n" : "=r"(h[7]) : "r"(lo3), "r"(SUB));
+#else
+
+        static constexpr uint32_t MUL = 0x3F803F80;
+        static constexpr uint32_t ADD = 0xC320C320;
+
+        q = (decode_value1 << 16) | (decode_value0 & 0xFFFF);
+        int lo3 = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX);
+        q >>= 3;
+        int lo2 = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX);
+        q >>= 3;
+        int lo1 = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX);
+        q >>= 3;
+        int lo0 = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX);
+
+        asm volatile("fma.rn.bf16x2 %0, %1, %2, %3;\n" : "=r"(h[0]) : "r"(lo0), "r"(MUL), "r"(ADD));
+        asm volatile("fma.rn.bf16x2 %0, %1, %2, %3;\n" : "=r"(h[1]) : "r"(lo1), "r"(MUL), "r"(ADD));
+        asm volatile("fma.rn.bf16x2 %0, %1, %2, %3;\n" : "=r"(h[2]) : "r"(lo2), "r"(MUL), "r"(ADD));
+        asm volatile("fma.rn.bf16x2 %0, %1, %2, %3;\n" : "=r"(h[3]) : "r"(lo3), "r"(MUL), "r"(ADD));
+
+        q = (decode_value3 << 16) | (decode_value2 & 0xFFFF);
+        lo3 = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX);
+        q >>= 3;
+        lo2 = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX);
+        q >>= 3;
+        lo1 = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX);
+        q >>= 3;
+        lo0 = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX);
+
+        asm volatile("fma.rn.bf16x2 %0, %1, %2, %3;\n" : "=r"(h[4]) : "r"(lo0), "r"(MUL), "r"(ADD));
+        asm volatile("fma.rn.bf16x2 %0, %1, %2, %3;\n" : "=r"(h[5]) : "r"(lo1), "r"(MUL), "r"(ADD));
+        asm volatile("fma.rn.bf16x2 %0, %1, %2, %3;\n" : "=r"(h[6]) : "r"(lo2), "r"(MUL), "r"(ADD));
+        asm volatile("fma.rn.bf16x2 %0, %1, %2, %3;\n" : "=r"(h[7]) : "r"(lo3), "r"(MUL), "r"(ADD));
+#endif
         return result;
     }
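
The magic constants implement the standard integer-to-bf16 conversion trick. MASK = 0x003F003F keeps a 6-bit field in each 16-bit half, and EX = 0x43004300 ORs that field into the mantissa of bf16 128.0 (bit pattern 0x4300), so a field value v becomes exactly 128 + v. One sub.bf16x2 against SUB = 0x43204320 (160.0 per lane) then produces v - 32 for two elements at once; the fallback branch computes the same thing as (128 + v) * 1.0 + (-160.0) with fma.rn.bf16x2, where MUL = 0x3F803F80 is 1.0 and ADD = 0xC320C320 is -160.0 per lane. Packing q = (decode_value1 << 16) | (decode_value0 & 0xFFFF) lets each lop3 and each bf16x2 instruction handle two elements at once. The bias of 32 is inferred from these constants: it should match the (decode_value & kWeightMask) - kBZP computation in the removed scalar loop, assuming kWeightMask = 0x3F and kBZP = 32, neither of which is visible in this diff. A host-side check of the lane arithmetic (a bf16 value is the high half of the equivalent fp32 bit pattern):

#include <cstdint>
#include <cstdio>
#include <cstring>

// Widen a bf16 bit pattern to fp32 by placing it in the high 16 bits.
static float bf16_bits_to_float(uint16_t bits) {
    uint32_t u = static_cast<uint32_t>(bits) << 16;
    float f;
    std::memcpy(&f, &u, sizeof(f));
    return f;
}

int main() {
    for (int v = 0; v < 64; ++v) {  // every possible 6-bit field value
        // OR the field into bf16 128.0: 0x4300 | v encodes exactly 128 + v.
        float biased = bf16_bits_to_float(static_cast<uint16_t>(0x4300 | v));
        float decoded = biased - bf16_bits_to_float(0x4320);  // minus 160.0
        if (decoded != static_cast<float>(v - 32)) {
            std::printf("mismatch at v=%d\n", v);
            return 1;
        }
    }
    std::printf("all 64 field values decode to v - 32\n");
    return 0;
}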
