From ae3549e6ee74b5a5428af9ebe5bb5bca1746cd03 Mon Sep 17 00:00:00 2001 From: chen2016013 Date: Mon, 12 May 2025 15:44:46 +0800 Subject: [PATCH 1/2] merge --- .../deep_gemm/include/deep_gemm/fp8_gemm.cuh | 292 +++++-- .../deep_gemm/include/deep_gemm/mma_utils.cuh | 801 ++---------------- .../deep_gemm/include/deep_gemm/scheduler.cuh | 33 +- .../deep_gemm/include/deep_gemm/tma_utils.cuh | 11 +- .../fp8/deep_gemm/include/deep_gemm/utils.cuh | 5 + ops/csrc/fp8/deep_gemm/jit/__init__.py | 4 +- ops/csrc/fp8/deep_gemm/jit/compiler.py | 31 +- ops/csrc/fp8/deep_gemm/jit/interleave_ffma.py | 2 +- ops/csrc/fp8/deep_gemm/jit_kernels/gemm.py | 155 ++-- .../deep_gemm/jit_kernels/m_grouped_gemm.py | 59 +- ops/csrc/fp8/setup.py | 2 +- ops/csrc/fp8/tests/test_core.py | 16 +- 12 files changed, 448 insertions(+), 963 deletions(-) diff --git a/ops/csrc/fp8/deep_gemm/include/deep_gemm/fp8_gemm.cuh b/ops/csrc/fp8/deep_gemm/include/deep_gemm/fp8_gemm.cuh index cc74f21bce49..15786c2c216d 100644 --- a/ops/csrc/fp8/deep_gemm/include/deep_gemm/fp8_gemm.cuh +++ b/ops/csrc/fp8/deep_gemm/include/deep_gemm/fp8_gemm.cuh @@ -39,17 +39,34 @@ enum class Layout { ColMajor }; +__device__ __host__ constexpr int get_num_math_warpgroups(int block_m) { + return block_m == 64 ? 1 : 2; +} + template __device__ __host__ constexpr int get_num_threads_per_sm(int block_m) { DG_STATIC_ASSERT(kNumMathThreadsPerGroup == 128, "Only support 128 threads per math group"); - return (block_m == 64 ? 1 : 2) * kNumMathThreadsPerGroup + kNumTMAThreads; + return get_num_math_warpgroups(block_m) * kNumMathThreadsPerGroup + kNumTMAThreads; +} + +template +__device__ __host__ void outer_launch_k_iterations(const auto& inner_launch_k_iterations, const auto& func, int num_former_iters) { + if (num_former_iters == kNumFormerIters) { + inner_launch_k_iterations(func, cute::Int{}); + return; + } + + if constexpr (kNumFormerIters + kGap <= kEnd) + outer_launch_k_iterations(inner_launch_k_iterations, func, num_former_iters); } template __global__ void __launch_bounds__(get_num_threads_per_sm(BLOCK_M), 1) fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout, @@ -61,15 +78,16 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout, #if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 900)) or defined(__CLION_IDE__) // Scaling checks DG_STATIC_ASSERT(BLOCK_K == 128, "Only support per-128-channel FP8 scaling"); - DG_STATIC_ASSERT(ceil_div(BLOCK_N, BLOCK_K) == 1, "Too much B scales in a single block"); + DG_STATIC_ASSERT(ceil_div(BLOCK_N, BLOCK_K) == 1 or (constexpr_gcd(BLOCK_N, BLOCK_K) == BLOCK_N - BLOCK_K), "Too much B scales in a single block"); // Types using WGMMA = typename FP8MMASelector::type; using Barrier = cutlass::arch::ClusterTransactionBarrier; + DG_STATIC_ASSERT(BLOCK_M % WGMMA::M == 0, "Invalid block size"); // Shared memory static constexpr int kMustUseUniformedScaleB = (BLOCK_K % BLOCK_N == 0); - static constexpr uint32_t SMEM_D_SIZE = BLOCK_M * BLOCK_N * sizeof(__nv_bfloat16); + static constexpr uint32_t SMEM_D_SIZE = BLOCK_M * (BLOCK_N + BLOCK_N_PADDING) * sizeof(__nv_bfloat16); static constexpr uint32_t SMEM_A_SIZE_PER_STAGE = BLOCK_M * BLOCK_K * sizeof(__nv_fp8_e4m3); static constexpr uint32_t SMEM_B_SIZE_PER_STAGE = BLOCK_N * BLOCK_K * sizeof(__nv_fp8_e4m3); static constexpr uint32_t SMEM_SCALES_A_SIZE_PER_STAGE = BLOCK_M * sizeof(float); @@ -89,7 +107,11 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout, cute::prefetch_tma_descriptor(reinterpret_cast(&tensor_map_a)); 
cute::prefetch_tma_descriptor(reinterpret_cast(&tensor_map_b)); cute::prefetch_tma_descriptor(reinterpret_cast(&tensor_map_scales_a)); - cute::prefetch_tma_descriptor(reinterpret_cast(&tensor_map_d)); + + // `tensor_map_d` is only used in swizzling mode + // For the `kSwizzleDMode == 0 and BLOCK_N_PADDING == 0` case, it will be treated as padding mode + if constexpr (kSwizzleDMode > 0) + cute::prefetch_tma_descriptor(reinterpret_cast(&tensor_map_d)); } __syncwarp(); @@ -128,6 +150,8 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout, // Initialize barriers DG_STATIC_ASSERT(kNumTMAMulticast <= 32, "Too many TMA multicast"); if (threadIdx.x == kNumMathThreads) { + // NOTES: we always use `lane_idx` to arrive for the `lane_idx`-th CTA in the cluster, + // even with TMA multicast disabled, we want to make the behavior aligned #pragma unroll for (int i = 0; i < kNumStages; ++ i) { full_barriers[i]->init(1); @@ -145,15 +169,23 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout, // For pipeline unrolling struct DivisibleK {}; struct NotDivisibleK {}; - auto launch_k_iterations = [](const auto& func) { - if constexpr (SHAPE_K % kFullKOfAllStages == 0) { - for (int k_iter = 0; k_iter < kNumIterations; ++ k_iter) - func(k_iter, DivisibleK{}); - } else { - for (int k_iter = 0; k_iter < kNumIterations - 1; ++ k_iter) - func(k_iter, DivisibleK{}); - func(kNumIterations - 1, NotDivisibleK{}); - } + auto launch_k_iterations = [](const auto& func, int num_former_iters) { + constexpr bool kShouldOptimize = BLOCK_K / constexpr_gcd(BLOCK_K, BLOCK_N) <= 4 and not kMustUseUniformedScaleB; + constexpr int kGap = constexpr_gcd(BLOCK_K, BLOCK_N) / 8; + constexpr int kEnd = kShouldOptimize ? BLOCK_K / 8 : 0; + + // NOTES: for too-many branches (> 5), we disable this optimization + // Otherwise, the compiler must know the dynamic variable `num_former_iters`'s real value + outer_launch_k_iterations<0, kGap, kEnd>([](const auto& func, auto num_former_iters_type) { + if constexpr (SHAPE_K % kFullKOfAllStages == 0) { + for (int k_iter = 0; k_iter < kNumIterations; ++ k_iter) + func(k_iter, DivisibleK{}, num_former_iters_type); + } else { + for (int k_iter = 0; k_iter < kNumIterations - 1; ++ k_iter) + func(k_iter, DivisibleK{}, num_former_iters_type); + func(kNumIterations - 1, NotDivisibleK{}, num_former_iters_type); + } + }, func, kShouldOptimize ? num_former_iters : 0); }; // Register reconfigurations @@ -162,7 +194,7 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout, // Block scheduler uint32_t m_block_idx, n_block_idx; - auto scheduler = Scheduler(shape_m, grouped_layout); + auto scheduler = Scheduler(shape_m, grouped_layout); if (threadIdx.x >= kNumMathThreads) { // TMA warp-group for loading data @@ -172,28 +204,34 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout, if (threadIdx.x == kNumMathThreads) { // Persistently schedule over blocks while (scheduler.get_next_block(m_block_idx, n_block_idx)) { - launch_k_iterations([&](int k_iter, auto type) { + launch_k_iterations([&](int k_iter, auto type, auto _) { constexpr bool kHasDivisibleStages = std::is_same_v; constexpr int kNumInnerStages = kHasDivisibleStages ? kNumStages : (SHAPE_K % kFullKOfAllStages) / BLOCK_K; DG_STATIC_ASSERT(kNumInnerStages != 0, "Invalid number of inner stages"); + // Assign TMA multicast number into A and B + constexpr int kNumTMAMulticastOnA = kIsTMAMulticastOnA ? 
kNumTMAMulticast : 1; + constexpr int kNumTMAMulticastOnB = kIsTMAMulticastOnA ? 1 : kNumTMAMulticast; + + // NOTES: unrolling and `kNumInnerStages` are vital for performance, NVCC will try to eliminate all + // shared memory pointers, e.g. `full_barriers` registers, if all the access indices are constant #pragma unroll for (uint32_t s = 0; s < kNumInnerStages; ++ s) { // Wait consumer release empty_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter + 1) & 1); - // Issue TMA A with broadcasting + // Issue TMA A auto& full_barrier = *full_barriers[s]; int k_idx = k_iter * kFullKOfAllStages + s * BLOCK_K; - tma_copy(&tensor_map_a, reinterpret_cast(&full_barrier), - smem_a[s], k_idx, scheduler.get_global_idx(shape_m, BLOCK_M, m_block_idx)); - tma_copy(&tensor_map_scales_a, reinterpret_cast(&full_barrier), - smem_scales_a[s], m_block_idx * BLOCK_M, - scheduler.get_global_idx(SHAPE_K_SCALES, 1, k_idx / BLOCK_K)); - - // Issue TMA B without broadcasting - tma_copy(&tensor_map_b, reinterpret_cast(&full_barrier), - smem_b[s], k_idx, scheduler.get_global_idx(SHAPE_N, BLOCK_N, n_block_idx, m_block_idx)); + tma_copy(&tensor_map_a, reinterpret_cast(&full_barrier), + smem_a[s], k_idx, scheduler.get_global_idx(shape_m, BLOCK_M, m_block_idx)); + tma_copy(&tensor_map_scales_a, reinterpret_cast(&full_barrier), + smem_scales_a[s], m_block_idx * BLOCK_M, + scheduler.get_global_idx(SHAPE_K_SCALES, 1, k_idx / BLOCK_K)); + + // Issue TMA B + tma_copy(&tensor_map_b, reinterpret_cast(&full_barrier), + smem_b[s], k_idx, scheduler.get_global_idx(SHAPE_N, BLOCK_N, n_block_idx, m_block_idx)); full_barrier.arrive_and_expect_tx(SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE + SMEM_SCALES_A_SIZE_PER_STAGE); } @@ -203,7 +241,7 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout, empty_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter + 1) & 1); full_barriers[s]->arrive(); } - }); + }, 0); } // To safely deconstruct distributed shared barriers, we need another round of empty waits @@ -244,7 +282,9 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout, cutlass::arch::NamedBarrier(kNumMathThreads).sync(); // Accumulation for WGMMA or CUDA promotion - float accum[WGMMA::kNumAccum], final_accum[WGMMA::kNumAccum] = {0}; + constexpr int WAVE_BLOCK_M = WGMMA::M * get_num_math_warpgroups(BLOCK_M); + DG_STATIC_ASSERT(BLOCK_M % WAVE_BLOCK_M == 0, "Invalid block sizes"); + float accum[WGMMA::kNumAccum], final_accum[WGMMA::kNumAccum * (BLOCK_M / WAVE_BLOCK_M)] = {0}; // Empty barrier arrival auto empty_barrier_arrive = [&](int s) { @@ -256,7 +296,7 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout, }; // Launch MMAs - launch_k_iterations([&](int k_iter, auto type) { + launch_k_iterations([&](int k_iter, auto type, auto num_former_iters_type) { constexpr bool kHasDivisibleStages = std::is_same_v; constexpr int kNumInnerStages = kHasDivisibleStages ? 
kNumStages : (SHAPE_K % kFullKOfAllStages) / BLOCK_K; DG_STATIC_ASSERT(kNumInnerStages != 0, "Invalid number of inner stages"); @@ -272,42 +312,54 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout, // Wait TMA arrivals full_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter) & 1); - // Read A scales - // NOTES: all shared memory read must be prior to `warpgroup_arrive` to avoid next scheduled block polluting the results - auto scale_a_0 = ld_shared(smem_scales_a[s] + r_0), scale_a_1 = ld_shared(smem_scales_a[s] + r_1); - - // Commit WGMMA instructions + // TODO: remove some useless computation for unaligned Ms #pragma unroll - for (int i = 0; i < WGMMA::kNumAccum; ++ i) - warpgroup_fence_operand(accum[i]); - warpgroup_arrive(); - #pragma unroll - for (int k = 0; k < BLOCK_K / WGMMA::K; ++ k) { - auto desc_a = make_smem_desc(smem_a[s] + math_wg_idx * WGMMA::M * BLOCK_K + k * WGMMA::K, 1); - auto desc_b = make_smem_desc(smem_b[s] + k * WGMMA::K, 1); - WGMMA::wgmma(desc_a, desc_b, accum, k); - } - warpgroup_commit_batch(); - #pragma unroll - for (int i = 0; i < WGMMA::kNumAccum; ++ i) - warpgroup_fence_operand(accum[i]); - warpgroup_wait<0>(); - - // Notify barrier arrival - empty_barrier_arrive(s); - - // Promote with scales - float scale_0_0 = scale_a_0 * scale_b_0, scale_1_0 = scale_a_1 * scale_b_0; - float scale_0_1, scale_1_1; - if constexpr (not kMustUseUniformedScaleB) - scale_0_1 = scale_a_0 * scale_b_1, scale_1_1 = scale_a_1 * scale_b_1; - #pragma unroll - for (int i = 0; i < WGMMA::kNumAccum / 4; ++ i) { - bool predicate = kMustUseUniformedScaleB or i < num_former_iters; - final_accum[i * 4 + 0] += (predicate ? scale_0_0 : scale_0_1) * accum[i * 4 + 0]; - final_accum[i * 4 + 1] += (predicate ? scale_0_0 : scale_0_1) * accum[i * 4 + 1]; - final_accum[i * 4 + 2] += (predicate ? scale_1_0 : scale_1_1) * accum[i * 4 + 2]; - final_accum[i * 4 + 3] += (predicate ? 
scale_1_0 : scale_1_1) * accum[i * 4 + 3]; + for (uint32_t local_idx = 0; local_idx < BLOCK_M / WAVE_BLOCK_M; ++ local_idx) { + auto m_offset = local_idx * WAVE_BLOCK_M; + + // Read A scales + // NOTES: all shared memory read must be prior to `warpgroup_arrive` to avoid next scheduled block polluting the results + auto scale_a_0 = ld_shared(smem_scales_a[s] + r_0 + m_offset); + auto scale_a_1 = ld_shared(smem_scales_a[s] + r_1 + m_offset); + + // Commit WGMMA instructions + #pragma unroll + for (int i = 0; i < WGMMA::kNumAccum; ++ i) + warpgroup_fence_operand(accum[i]); + warpgroup_arrive(); + #pragma unroll + for (int k = 0; k < BLOCK_K / WGMMA::K; ++ k) { + auto desc_a = make_smem_desc(smem_a[s] + (math_wg_idx * WGMMA::M + m_offset) * BLOCK_K + k * WGMMA::K, 1); + auto desc_b = make_smem_desc(smem_b[s] + k * WGMMA::K, 1); + WGMMA::wgmma(desc_a, desc_b, accum, k); + } + warpgroup_commit_batch(); + #pragma unroll + for (int i = 0; i < WGMMA::kNumAccum; ++ i) + warpgroup_fence_operand(accum[i]); + warpgroup_wait<0>(); + + // Notify barrier arrival at the last warpgroup wave + if (local_idx == BLOCK_M / WAVE_BLOCK_M - 1) + empty_barrier_arrive(s); + + // Promote with scales + // NOTES: making it as predicates is very important for performance, comparing to two loops + float scale_0_0 = scale_a_0 * scale_b_0, scale_1_0 = scale_a_1 * scale_b_0; + float scale_0_1, scale_1_1; + if constexpr (not kMustUseUniformedScaleB) + scale_0_1 = scale_a_0 * scale_b_1, scale_1_1 = scale_a_1 * scale_b_1; + + auto shifted_accum = final_accum + WGMMA::kNumAccum * local_idx; + #pragma unroll + for (int i = 0; i < WGMMA::kNumAccum / 4; ++ i) { + // NOTES: for unrolled `num_former_iters` cases, we expect the compiler to automatically make it a constant + bool predicate = kMustUseUniformedScaleB or i < num_former_iters; + shifted_accum[i * 4 + 0] += (predicate ? scale_0_0 : scale_0_1) * accum[i * 4 + 0]; + shifted_accum[i * 4 + 1] += (predicate ? scale_0_0 : scale_0_1) * accum[i * 4 + 1]; + shifted_accum[i * 4 + 2] += (predicate ? scale_1_0 : scale_1_1) * accum[i * 4 + 2]; + shifted_accum[i * 4 + 3] += (predicate ? scale_1_0 : scale_1_1) * accum[i * 4 + 3]; + } } } @@ -317,34 +369,81 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout, full_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter) & 1); empty_barrier_arrive(s); } - }); - - // Write back to shared memory using STSM + }, num_former_iters); + + // TMA checks + constexpr uint32_t kNumElemBytes = sizeof(nv_bfloat16); + constexpr uint32_t TMA_D_BLOCK_N = kSwizzleDMode == 0 ? 
BLOCK_N : (kSwizzleDMode / kNumElemBytes); + constexpr uint32_t WGMMA_M_PER_WARP = WGMMA::M / 4; + DG_STATIC_ASSERT(BLOCK_M % 8 == 0, "Invalid swizzling atom"); + DG_STATIC_ASSERT(BLOCK_N % TMA_D_BLOCK_N == 0 and BLOCK_N / TMA_D_BLOCK_N <= 32, + "Unaligned TMA store or too many TMA store instructions"); + DG_STATIC_ASSERT(TMA_D_BLOCK_N % 8 == 0, "Invalid TMA block N"); + DG_STATIC_ASSERT(static_cast(kSwizzleDMode > 0) + static_cast(BLOCK_N_PADDING > 0) <= 1, + "Swizzling and padding are not compatible"); + + // Write back to shared memory using STSM and issue TMA stores DG_STATIC_ASSERT(WGMMA::kNumAccum % 4 == 0, "Invalid STSM x2 vectorization"); #pragma unroll - for (auto i = 0; i < WGMMA::kNumAccum / 8; ++ i) { - SM90_U32x4_STSM_N::copy( - __float22bfloat162_rn({final_accum[i * 8 + 0], final_accum[i * 8 + 1]}), - __float22bfloat162_rn({final_accum[i * 8 + 2], final_accum[i * 8 + 3]}), - __float22bfloat162_rn({final_accum[i * 8 + 4], final_accum[i * 8 + 5]}), - __float22bfloat162_rn({final_accum[i * 8 + 6], final_accum[i * 8 + 7]}), - smem_d + (warp_idx * 16 + lane_idx % 16) * BLOCK_N + i * 16 + 8 * (lane_idx / 16) - ); - } - if constexpr (WGMMA::kNumAccum % 8 != 0) { - SM90_U32x2_STSM_N::copy( - __float22bfloat162_rn({final_accum[WGMMA::kNumAccum / 8 * 8 + 0], final_accum[WGMMA::kNumAccum / 8 * 8 + 1]}), - __float22bfloat162_rn({final_accum[WGMMA::kNumAccum / 8 * 8 + 2], final_accum[WGMMA::kNumAccum / 8 * 8 + 3]}), - smem_d + (warp_idx * 16 + lane_idx % 16) * BLOCK_N + WGMMA::kNumAccum / 8 * 16 - ); + for (uint32_t local_idx = 0; local_idx < BLOCK_M / WAVE_BLOCK_M; ++ local_idx) { + auto m_offset = local_idx * WAVE_BLOCK_M; + auto shifted_accum = final_accum + WGMMA::kNumAccum * local_idx; + #pragma unroll + for (auto i = 0; i < WGMMA::kNumAccum / 4; ++ i) { + // Swizzle or padding into the correct address + uint8_t* smem_ptr = nullptr; + if constexpr (kSwizzleDMode > 0) { + // Calculate the swizzling atom offset and in-atom offset + constexpr int kNumBankGroupBytes = 16; + auto atom_offset = i / (TMA_D_BLOCK_N / 8), in_atom_offset = i % (TMA_D_BLOCK_N / 8); + + // Calculate the index of the bank group to be written in the atom + auto bank_group_index = in_atom_offset + lane_idx * (kSwizzleDMode / kNumBankGroupBytes); + + // Reshape the atom in another view and swizzle + // - original: `(BLOCK_M, kSwizzleDMode / kNumBankGroupBytes)` + // - new: `(BLOCK_M * kSwizzleDMode / kNumBankGroupBytes / 8, 8)` + constexpr bool kHasShortcut = (kSwizzleDMode / kNumBankGroupBytes) == 8; + auto row = kHasShortcut ? (in_atom_offset / 8 + lane_idx) : (bank_group_index / 8); + auto col = kHasShortcut ? 
(in_atom_offset) : (bank_group_index % 8); + col ^= row % (kSwizzleDMode / 16); + + // Add back into the base pointer + // NOTES: think twice before modifying this, as changes may affect the number of instructions + smem_ptr = reinterpret_cast(smem_d) + // Base pointer + warp_idx * (WGMMA_M_PER_WARP * kSwizzleDMode) + // Warp offset + m_offset * kSwizzleDMode + // Wave offset + atom_offset * BLOCK_M * kSwizzleDMode + // Swizzle atom offset (constants) + row * (kNumBankGroupBytes * 8) + col * kNumBankGroupBytes; // In-atom offset + } else { + // No swizzling, just padding + // NOTES: padding must be zero for BF16 output + DG_STATIC_ASSERT(BLOCK_N_PADDING == 0, "Padding must be zero for BF16 output"); + smem_ptr = reinterpret_cast(smem_d + (m_offset + warp_idx * WGMMA_M_PER_WARP + lane_idx) * (BLOCK_N + BLOCK_N_PADDING) + i * 8); + } + + // NOTES: only 16 lanes' addresses are used + SM90_U32x2_STSM_N::copy( + __float22bfloat162_rn({shifted_accum[i * 4 + 0], shifted_accum[i * 4 + 1]}), + __float22bfloat162_rn({shifted_accum[i * 4 + 2], shifted_accum[i * 4 + 3]}), + smem_ptr + ); + } } cute::tma_store_fence(); cutlass::arch::NamedBarrier(kNumMathThreads).sync(); // Use TMA store to write back to global memory - if (threadIdx.x == 0) { - cute::SM90_TMA_STORE_2D::copy(&tensor_map_d, smem_d, n_block_idx * BLOCK_N, + // TODO: compatible with FP32 output + DG_STATIC_ASSERT(kNumMathThreads >= BLOCK_N / TMA_D_BLOCK_N, "Too many TMA blocks"); + if (threadIdx.x < BLOCK_N / TMA_D_BLOCK_N) { + auto in_block_n_offset = threadIdx.x * TMA_D_BLOCK_N; + auto smem_ptr = smem_d + in_block_n_offset * BLOCK_M; + cute::SM90_TMA_STORE_2D::copy(&tensor_map_d, smem_ptr, + n_block_idx * BLOCK_N + in_block_n_offset, scheduler.get_global_idx(shape_m, BLOCK_M, m_block_idx)); + + // Wait TMA to be finished cute::tma_store_arrive(); cute::tma_store_wait<0>(); } @@ -359,8 +458,10 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout, template class Gemm { private: @@ -380,9 +481,13 @@ public: // NOTES: we must use 4 warps to do TMA, because `setmaxnreg.aligned` requires 4 warps constexpr uint32_t kNumTMAThreads = 128; constexpr uint32_t kNumMathThreadsPerGroup = 128; - auto kernel = fp8_gemm_kernel; + auto kernel = fp8_gemm_kernel; DG_HOST_ASSERT(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size) == cudaSuccess); // Cluster launch @@ -422,10 +527,17 @@ public: template static CUtensorMap make_2d_tma_d_desc(T* global_address, uint32_t shape_m) { + auto swizzle_mode = CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_NONE; + if constexpr (kSwizzleDMode == 32) swizzle_mode = CU_TENSOR_MAP_SWIZZLE_32B; + if constexpr (kSwizzleDMode == 64) swizzle_mode = CU_TENSOR_MAP_SWIZZLE_64B; + if constexpr (kSwizzleDMode == 128) swizzle_mode = CU_TENSOR_MAP_SWIZZLE_128B; + + // Swizzling requires the inner box dim less or equal than `kSwizzleDMode` bytes + // So `BLOCK_N * sizeof(T) / kSwizzleDMode` TMA stores are required return make_2d_tma_desc(global_address, Layout::RowMajor, shape_m * (kGemmType == GemmType::GroupedMasked ? kNumGroups : 1), SHAPE_N, - min(BLOCK_M, shape_m), BLOCK_N, - CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_NONE); + BLOCK_M, kSwizzleDMode == 0 ? 
BLOCK_N : kSwizzleDMode / sizeof(T), + swizzle_mode); } template diff --git a/ops/csrc/fp8/deep_gemm/include/deep_gemm/mma_utils.cuh b/ops/csrc/fp8/deep_gemm/include/deep_gemm/mma_utils.cuh index c1b9dc5148e5..e10a4af150da 100644 --- a/ops/csrc/fp8/deep_gemm/include/deep_gemm/mma_utils.cuh +++ b/ops/csrc/fp8/deep_gemm/include/deep_gemm/mma_utils.cuh @@ -20,742 +20,13 @@ #include +#include +#include + #include "utils.cuh" namespace deep_gemm { -struct SM90_64x16x32_F32E4M3E4M3_SS { - __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, - float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07, - bool scale_d) { - asm volatile("{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %10, 0;\n" - "wgmma.mma_async.sync.aligned.m64n16k32.f32.e4m3.e4m3" - "{%0, %1, %2, %3, %4, %5, %6, %7}," - " %8," - " %9," - " p , 1, 1;\n" - "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07) - : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d))); - } - - __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) { - wgmma(desc_a, desc_b, - d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7], - scale_d); - } - - static constexpr int M = 64; - static constexpr int N = 16; - static constexpr int K = 32; - static constexpr int kNumAccum = M * N / 128; -}; - -struct SM90_64x24x32_F32E4M3E4M3_SS { - __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, - float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07, - float& d08, float& d09, float& d10, float& d11, - bool scale_d) { - asm volatile("{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %14, 0;\n" - "wgmma.mma_async.sync.aligned.m64n24k32.f32.e4m3.e4m3" - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11}," - " %12," - " %13," - " p , 1, 1;\n" - "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), - "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11) - : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d))); - } - - __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) { - wgmma(desc_a, desc_b, - d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7], - d[8], d[9], d[10], d[11], - scale_d); - } - - static constexpr int M = 64; - static constexpr int N = 24; - static constexpr int K = 32; - static constexpr int kNumAccum = M * N / 128; -}; - -struct SM90_64x32x32_F32E4M3E4M3_SS { - __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, - float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07, - float& d08, float& d09, float& d10, float& d11, float& d12, float& d13, float& d14, float& d15, - bool scale_d) { - asm volatile("{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %18, 0;\n" - "wgmma.mma_async.sync.aligned.m64n32k32.f32.e4m3.e4m3" - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15}," - " %16," - " %17," - " p , 1, 1;\n" - "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), - "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15) - : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d))); - } - - __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) { - wgmma(desc_a, desc_b, - d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7], - d[8], d[9], d[10], d[11], d[12], d[13], d[14], d[15], - scale_d); - } - - 
static constexpr int M = 64; - static constexpr int N = 32; - static constexpr int K = 32; - static constexpr int kNumAccum = M * N / 128; -}; - -struct SM90_64x40x32_F32E4M3E4M3_SS { - __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, - float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07, - float& d08, float& d09, float& d10, float& d11, float& d12, float& d13, float& d14, float& d15, - float& d16, float& d17, float& d18, float& d19, - bool scale_d) { - asm volatile("{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %22, 0;\n" - "wgmma.mma_async.sync.aligned.m64n40k32.f32.e4m3.e4m3" - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19}," - " %20," - " %21," - " p , 1, 1;\n" - "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), - "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), - "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19) - : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d))); - } - - __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) { - wgmma(desc_a, desc_b, - d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7], - d[8], d[9], d[10], d[11], d[12], d[13], d[14], d[15], - d[16], d[17], d[18], d[19], - scale_d); - } - - static constexpr int M = 64; - static constexpr int N = 40; - static constexpr int K = 32; - static constexpr int kNumAccum = M * N / 128; -}; - -struct SM90_64x48x32_F32E4M3E4M3_SS { - __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, - float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07, - float& d08, float& d09, float& d10, float& d11, float& d12, float& d13, float& d14, float& d15, - float& d16, float& d17, float& d18, float& d19, float& d20, float& d21, float& d22, float& d23, - bool scale_d) { - asm volatile("{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %26, 0;\n" - "wgmma.mma_async.sync.aligned.m64n48k32.f32.e4m3.e4m3" - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23}," - " %24," - " %25," - " p , 1, 1;\n" - "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), - "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), - "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23) - : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d))); - } - - __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) { - wgmma(desc_a, desc_b, - d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7], - d[8], d[9], d[10], d[11], d[12], d[13], d[14], d[15], - d[16], d[17], d[18], d[19], d[20], d[21], d[22], d[23], - scale_d); - } - - static constexpr int M = 64; - static constexpr int N = 48; - static constexpr int K = 32; - static constexpr int kNumAccum = M * N / 128; -}; - -struct SM90_64x56x32_F32E4M3E4M3_SS { - __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, - float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07, - float& d08, float& d09, float& d10, float& d11, float& d12, float& d13, float& d14, float& d15, - float& d16, float& d17, float& d18, float& d19, float& d20, float& d21, float& d22, float& d23, - float& d24, float& d25, float& d26, float& d27, - bool scale_d) { - asm volatile("{\n" - ".reg 
.pred p;\n" - "setp.ne.b32 p, %30, 0;\n" - "wgmma.mma_async.sync.aligned.m64n56k32.f32.e4m3.e4m3" - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27}, " - " %28," - " %29," - " p , 1, 1;\n" - "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), - "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), - "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), - "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27) - : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d))); - } - - __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) { - wgmma(desc_a, desc_b, - d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7], - d[8], d[9], d[10], d[11], d[12], d[13], d[14], d[15], - d[16], d[17], d[18], d[19], d[20], d[21], d[22], d[23], - d[24], d[25], d[26], d[27], - scale_d); - } - - static constexpr int M = 64; - static constexpr int N = 56; - static constexpr int K = 32; - static constexpr int kNumAccum = M * N / 128; -}; - -struct SM90_64x64x32_F32E4M3E4M3_SS { - __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, - float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07, - float& d08, float& d09, float& d10, float& d11, float& d12, float& d13, float& d14, float& d15, - float& d16, float& d17, float& d18, float& d19, float& d20, float& d21, float& d22, float& d23, - float& d24, float& d25, float& d26, float& d27, float& d28, float& d29, float& d30, float& d31, - bool scale_d) { - asm volatile("{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %34, 0;\n" - "wgmma.mma_async.sync.aligned.m64n64k32.f32.e4m3.e4m3" - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31}, " - " %32," - " %33," - " p , 1, 1;\n" - "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), - "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), - "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), - "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31) - : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d))); - } - - __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) { - wgmma(desc_a, desc_b, - d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7], - d[8], d[9], d[10], d[11], d[12], d[13], d[14], d[15], - d[16], d[17], d[18], d[19], d[20], d[21], d[22], d[23], - d[24], d[25], d[26], d[27], d[28], d[29], d[30], d[31], - scale_d); - } - - static constexpr int M = 64; - static constexpr int N = 64; - static constexpr int K = 32; - static constexpr int kNumAccum = M * N / 128; -}; - -struct SM90_64x72x32_F32E4M3E4M3_SS { - __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, - float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07, - float& d08, float& d09, float& d10, float& d11, float& d12, float& d13, float& d14, float& d15, - float& d16, float& d17, float& d18, float& d19, float& d20, float& d21, float& d22, float& d23, - float& d24, float& d25, float& d26, float& d27, float& d28, float& d29, float& d30, float& d31, - float& d32, float& d33, float& d34, float& d35, - bool scale_d) { - 
asm volatile("{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %38, 0;\n" - "wgmma.mma_async.sync.aligned.m64n72k32.f32.e4m3.e4m3" - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35}, " - " %36," - " %37," - " p , 1, 1;\n" - "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), - "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), - "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), - "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), - "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35) - : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d))); - } - - __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) { - wgmma(desc_a, desc_b, - d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7], - d[8], d[9], d[10], d[11], d[12], d[13], d[14], d[15], - d[16], d[17], d[18], d[19], d[20], d[21], d[22], d[23], - d[24], d[25], d[26], d[27], d[28], d[29], d[30], d[31], - d[32], d[33], d[34], d[35], - scale_d); - } - - static constexpr int M = 64; - static constexpr int N = 72; - static constexpr int K = 32; - static constexpr int kNumAccum = M * N / 128; -}; - -struct SM90_64x80x32_F32E4M3E4M3_SS { - __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, - float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07, - float& d08, float& d09, float& d10, float& d11, float& d12, float& d13, float& d14, float& d15, - float& d16, float& d17, float& d18, float& d19, float& d20, float& d21, float& d22, float& d23, - float& d24, float& d25, float& d26, float& d27, float& d28, float& d29, float& d30, float& d31, - float& d32, float& d33, float& d34, float& d35, float& d36, float& d37, float& d38, float& d39, - bool scale_d) { - asm volatile("{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %42, 0;\n" - "wgmma.mma_async.sync.aligned.m64n80k32.f32.e4m3.e4m3" - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39}, " - " %40," - " %41," - " p , 1, 1;\n" - "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), - "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), - "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), - "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), - "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39) - : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d))); - } - - __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) { - wgmma(desc_a, desc_b, - d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7], - d[8], d[9], d[10], d[11], d[12], d[13], d[14], d[15], - d[16], d[17], d[18], d[19], d[20], d[21], d[22], d[23], - d[24], d[25], d[26], d[27], d[28], d[29], d[30], d[31], - d[32], d[33], d[34], d[35], d[36], d[37], d[38], d[39], - scale_d); - } - - static constexpr int M = 64; - static constexpr int N = 80; - static constexpr int K = 32; - static constexpr int kNumAccum = M * N / 128; -}; - -struct SM90_64x88x32_F32E4M3E4M3_SS { - __device__ static void 
wgmma(uint64_t const& desc_a, uint64_t const& desc_b, - float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07, - float& d08, float& d09, float& d10, float& d11, float& d12, float& d13, float& d14, float& d15, - float& d16, float& d17, float& d18, float& d19, float& d20, float& d21, float& d22, float& d23, - float& d24, float& d25, float& d26, float& d27, float& d28, float& d29, float& d30, float& d31, - float& d32, float& d33, float& d34, float& d35, float& d36, float& d37, float& d38, float& d39, - float& d40, float& d41, float& d42, float& d43, - bool scale_d) { - asm volatile("{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %46, 0;\n" - "wgmma.mma_async.sync.aligned.m64n88k32.f32.e4m3.e4m3" - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43}, " - " %44," - " %45," - " p , 1, 1;\n" - "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), - "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), - "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), - "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), - "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), - "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43) - : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d))); - } - - __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) { - wgmma(desc_a, desc_b, - d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7], - d[8], d[9], d[10], d[11], d[12], d[13], d[14], d[15], - d[16], d[17], d[18], d[19], d[20], d[21], d[22], d[23], - d[24], d[25], d[26], d[27], d[28], d[29], d[30], d[31], - d[32], d[33], d[34], d[35], d[36], d[37], d[38], d[39], - d[40], d[41], d[42], d[43], - scale_d); - } - - static constexpr int M = 64; - static constexpr int N = 88; - static constexpr int K = 32; - static constexpr int kNumAccum = M * N / 128; -}; - -struct SM90_64x96x32_F32E4M3E4M3_SS { - __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, - float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07, - float& d08, float& d09, float& d10, float& d11, float& d12, float& d13, float& d14, float& d15, - float& d16, float& d17, float& d18, float& d19, float& d20, float& d21, float& d22, float& d23, - float& d24, float& d25, float& d26, float& d27, float& d28, float& d29, float& d30, float& d31, - float& d32, float& d33, float& d34, float& d35, float& d36, float& d37, float& d38, float& d39, - float& d40, float& d41, float& d42, float& d43, float& d44, float& d45, float& d46, float& d47, - bool scale_d) { - asm volatile("{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %50, 0;\n" - "wgmma.mma_async.sync.aligned.m64n96k32.f32.e4m3.e4m3" - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47}, " - " %48," - " %49," - " p , 1, 1;\n" - "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), - "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), - "+f"(d16), 
"+f"(d17), "+f"(d18), "+f"(d19), "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), - "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), - "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), - "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47) - : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d))); - } - - __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) { - wgmma(desc_a, desc_b, - d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7], - d[8], d[9], d[10], d[11], d[12], d[13], d[14], d[15], - d[16], d[17], d[18], d[19], d[20], d[21], d[22], d[23], - d[24], d[25], d[26], d[27], d[28], d[29], d[30], d[31], - d[32], d[33], d[34], d[35], d[36], d[37], d[38], d[39], - d[40], d[41], d[42], d[43], d[44], d[45], d[46], d[47], - scale_d); - } - - static constexpr int M = 64; - static constexpr int N = 96; - static constexpr int K = 32; - static constexpr int kNumAccum = M * N / 128; -}; - -struct SM90_64x104x32_F32E4M3E4M3_SS { - __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, - float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07, - float& d08, float& d09, float& d10, float& d11, float& d12, float& d13, float& d14, float& d15, - float& d16, float& d17, float& d18, float& d19, float& d20, float& d21, float& d22, float& d23, - float& d24, float& d25, float& d26, float& d27, float& d28, float& d29, float& d30, float& d31, - float& d32, float& d33, float& d34, float& d35, float& d36, float& d37, float& d38, float& d39, - float& d40, float& d41, float& d42, float& d43, float& d44, float& d45, float& d46, float& d47, - float& d48, float& d49, float& d50, float& d51, - bool scale_d) { - asm volatile("{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %54, 0;\n" - "wgmma.mma_async.sync.aligned.m64n104k32.f32.e4m3.e4m3" - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51}, " - " %52," - " %53," - " p , 1, 1;\n" - "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), - "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), - "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), - "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), - "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), - "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), - "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51) - : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d))); - } - - __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) { - wgmma(desc_a, desc_b, - d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7], - d[8], d[9], d[10], d[11], d[12], d[13], d[14], d[15], - d[16], d[17], d[18], d[19], d[20], d[21], d[22], d[23], - d[24], d[25], d[26], d[27], d[28], d[29], d[30], d[31], - d[32], d[33], d[34], d[35], d[36], d[37], d[38], d[39], - d[40], d[41], d[42], d[43], d[44], d[45], d[46], d[47], - d[48], d[49], d[50], d[51], - scale_d); - } - - static constexpr int M = 64; - static constexpr int N = 104; - static constexpr int K = 32; - static constexpr 
int kNumAccum = M * N / 128; -}; - -struct SM90_64x112x32_F32E4M3E4M3_SS { - __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, - float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07, - float& d08, float& d09, float& d10, float& d11, float& d12, float& d13, float& d14, float& d15, - float& d16, float& d17, float& d18, float& d19, float& d20, float& d21, float& d22, float& d23, - float& d24, float& d25, float& d26, float& d27, float& d28, float& d29, float& d30, float& d31, - float& d32, float& d33, float& d34, float& d35, float& d36, float& d37, float& d38, float& d39, - float& d40, float& d41, float& d42, float& d43, float& d44, float& d45, float& d46, float& d47, - float& d48, float& d49, float& d50, float& d51, float& d52, float& d53, float& d54, float& d55, - bool scale_d) { - asm volatile("{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %58, 0;\n" - "wgmma.mma_async.sync.aligned.m64n112k32.f32.e4m3.e4m3" - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55}, " - " %56," - " %57," - " p , 1, 1;\n" - "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), - "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), - "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), - "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), - "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), - "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), - "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55) - : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d))); - } - - __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) { - wgmma(desc_a, desc_b, - d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7], - d[8], d[9], d[10], d[11], d[12], d[13], d[14], d[15], - d[16], d[17], d[18], d[19], d[20], d[21], d[22], d[23], - d[24], d[25], d[26], d[27], d[28], d[29], d[30], d[31], - d[32], d[33], d[34], d[35], d[36], d[37], d[38], d[39], - d[40], d[41], d[42], d[43], d[44], d[45], d[46], d[47], - d[48], d[49], d[50], d[51], d[52], d[53], d[54], d[55], - scale_d); - } - - static constexpr int M = 64; - static constexpr int N = 112; - static constexpr int K = 32; - static constexpr int kNumAccum = M * N / 128; -}; - -struct SM90_64x120x32_F32E4M3E4M3_SS { - __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, - float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07, - float& d08, float& d09, float& d10, float& d11, float& d12, float& d13, float& d14, float& d15, - float& d16, float& d17, float& d18, float& d19, float& d20, float& d21, float& d22, float& d23, - float& d24, float& d25, float& d26, float& d27, float& d28, float& d29, float& d30, float& d31, - float& d32, float& d33, float& d34, float& d35, float& d36, float& d37, float& d38, float& d39, - float& d40, float& d41, float& d42, float& d43, float& d44, float& d45, float& d46, float& d47, - float& d48, float& d49, float& d50, float& d51, float& d52, float& d53, float& d54, float& d55, - float& d56, float& 
d57, float& d58, float& d59, - bool scale_d) { - asm volatile("{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %62, 0;\n" - "wgmma.mma_async.sync.aligned.m64n120k32.f32.e4m3.e4m3" - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59}, " - " %60," - " %61," - " p , 1, 1;\n" - "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), - "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), - "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), - "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), - "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), - "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), - "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), - "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59) - : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d))); - } - - __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) { - wgmma(desc_a, desc_b, - d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7], - d[8], d[9], d[10], d[11], d[12], d[13], d[14], d[15], - d[16], d[17], d[18], d[19], d[20], d[21], d[22], d[23], - d[24], d[25], d[26], d[27], d[28], d[29], d[30], d[31], - d[32], d[33], d[34], d[35], d[36], d[37], d[38], d[39], - d[40], d[41], d[42], d[43], d[44], d[45], d[46], d[47], - d[48], d[49], d[50], d[51], d[52], d[53], d[54], d[55], - d[56], d[57], d[58], d[59], - scale_d); - } - - static constexpr int M = 64; - static constexpr int N = 120; - static constexpr int K = 32; - static constexpr int kNumAccum = M * N / 128; -}; - -struct SM90_64x128x32_F32E4M3E4M3_SS { - __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, - float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07, - float& d08, float& d09, float& d10, float& d11, float& d12, float& d13, float& d14, float& d15, - float& d16, float& d17, float& d18, float& d19, float& d20, float& d21, float& d22, float& d23, - float& d24, float& d25, float& d26, float& d27, float& d28, float& d29, float& d30, float& d31, - float& d32, float& d33, float& d34, float& d35, float& d36, float& d37, float& d38, float& d39, - float& d40, float& d41, float& d42, float& d43, float& d44, float& d45, float& d46, float& d47, - float& d48, float& d49, float& d50, float& d51, float& d52, float& d53, float& d54, float& d55, - float& d56, float& d57, float& d58, float& d59, float& d60, float& d61, float& d62, float& d63, - bool scale_d) { - asm volatile("{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %66, 0;\n" - "wgmma.mma_async.sync.aligned.m64n128k32.f32.e4m3.e4m3" - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63}, " - " %64," - " %65," - " p , 1, 1;\n" - "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), 
"+f"(d07), - "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), - "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), - "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), - "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), - "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), - "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), - "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63) - : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d))); - } - - __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) { - wgmma(desc_a, desc_b, - d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7], - d[8], d[9], d[10], d[11], d[12], d[13], d[14], d[15], - d[16], d[17], d[18], d[19], d[20], d[21], d[22], d[23], - d[24], d[25], d[26], d[27], d[28], d[29], d[30], d[31], - d[32], d[33], d[34], d[35], d[36], d[37], d[38], d[39], - d[40], d[41], d[42], d[43], d[44], d[45], d[46], d[47], - d[48], d[49], d[50], d[51], d[52], d[53], d[54], d[55], - d[56], d[57], d[58], d[59], d[60], d[61], d[62], d[63], - scale_d); - } - - static constexpr int M = 64; - static constexpr int N = 128; - static constexpr int K = 32; - static constexpr int kNumAccum = M * N / 128; -}; - -struct SM90_64x192x32_F32E4M3E4M3_SS { - __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, - float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07, - float& d08, float& d09, float& d10, float& d11, float& d12, float& d13, float& d14, float& d15, - float& d16, float& d17, float& d18, float& d19, float& d20, float& d21, float& d22, float& d23, - float& d24, float& d25, float& d26, float& d27, float& d28, float& d29, float& d30, float& d31, - float& d32, float& d33, float& d34, float& d35, float& d36, float& d37, float& d38, float& d39, - float& d40, float& d41, float& d42, float& d43, float& d44, float& d45, float& d46, float& d47, - float& d48, float& d49, float& d50, float& d51, float& d52, float& d53, float& d54, float& d55, - float& d56, float& d57, float& d58, float& d59, float& d60, float& d61, float& d62, float& d63, - float& d64, float& d65, float& d66, float& d67, float& d68, float& d69, float& d70, float& d71, - float& d72, float& d73, float& d74, float& d75, float& d76, float& d77, float& d78, float& d79, - float& d80, float& d81, float& d82, float& d83, float& d84, float& d85, float& d86, float& d87, - float& d88, float& d89, float& d90, float& d91, float& d92, float& d93, float& d94, float& d95, - bool scale_d) { - asm volatile("{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %98, 0;\n" - "wgmma.mma_async.sync.aligned.m64n192k32.f32.e4m3.e4m3" - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95}, " - " %96," - " %97," - " p , 1, 1;\n" - "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), 
"+f"(d06), "+f"(d07), - "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), - "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), - "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), - "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), - "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), - "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), - "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), - "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), - "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), - "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), - "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91), "+f"(d92), "+f"(d93), "+f"(d94), "+f"(d95) - : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d))); - } - - __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) { - wgmma(desc_a, desc_b, - d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7], - d[8], d[9], d[10], d[11], d[12], d[13], d[14], d[15], - d[16], d[17], d[18], d[19], d[20], d[21], d[22], d[23], - d[24], d[25], d[26], d[27], d[28], d[29], d[30], d[31], - d[32], d[33], d[34], d[35], d[36], d[37], d[38], d[39], - d[40], d[41], d[42], d[43], d[44], d[45], d[46], d[47], - d[48], d[49], d[50], d[51], d[52], d[53], d[54], d[55], - d[56], d[57], d[58], d[59], d[60], d[61], d[62], d[63], - d[64], d[65], d[66], d[67], d[68], d[69], d[70], d[71], - d[72], d[73], d[74], d[75], d[76], d[77], d[78], d[79], - d[80], d[81], d[82], d[83], d[84], d[85], d[86], d[87], - d[88], d[89], d[90], d[91], d[92], d[93], d[94], d[95], - scale_d); - } - - static constexpr int M = 64; - static constexpr int N = 192; - static constexpr int K = 32; - static constexpr int kNumAccum = M * N / 128; -}; - template struct SM90_U32x2_STSM_N { __device__ __forceinline__ static void @@ -777,15 +48,15 @@ struct SM90_U32x4_STSM_N { } }; -__device__ void warpgroup_arrive() { +__forceinline__ __device__ void warpgroup_arrive() { asm volatile("wgmma.fence.sync.aligned;\n" ::: "memory"); } -__device__ void warpgroup_commit_batch() { +__forceinline__ __device__ void warpgroup_commit_batch() { asm volatile("wgmma.commit_group.sync.aligned;\n" ::: "memory"); } -__device__ void warpgroup_fence_operand(float& reg) { +__forceinline__ __device__ void warpgroup_fence_operand(float& reg) { asm volatile("" : "+f"(reg) :: "memory"); } @@ -876,25 +147,53 @@ __device__ GmmaDescriptor make_smem_desc(PointerType smem_ptr, int layout_type, return desc; } +template +struct FP8MMA { + + template + __forceinline__ __device__ static void call_fma_impl(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d, std::index_sequence) { + using namespace cute::SM90::GMMA; + MMA::fma(desc_a, desc_b, d[Idx]..., (scale_d ? 
ScaleOut::One : ScaleOut::Zero)); + } + + __forceinline__ __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) { + call_fma_impl(desc_a, desc_b, d, scale_d, std::make_index_sequence{}); + } + + static constexpr int M = 64; + static constexpr int N = N_; + static constexpr int K = 32; + static constexpr int kNumAccum = M * N / 128; +}; + template struct FP8MMASelector { + + static constexpr auto select_mma() { + using namespace cute::SM90::GMMA; + if constexpr (N == 16) return MMA_64x16x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 24) return MMA_64x24x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 32) return MMA_64x32x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 40) return MMA_64x40x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 48) return MMA_64x48x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 56) return MMA_64x56x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 64) return MMA_64x64x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 72) return MMA_64x72x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 80) return MMA_64x80x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 88) return MMA_64x88x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 96) return MMA_64x96x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 104) return MMA_64x104x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 112) return MMA_64x112x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 120) return MMA_64x120x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 128) return MMA_64x128x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 144) return MMA_64x144x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 152) return MMA_64x152x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 160) return MMA_64x160x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 192) return MMA_64x192x32_F32E4M3E4M3_SS_TN(); + } + static constexpr auto select_type() { - if constexpr (N == 16) return SM90_64x16x32_F32E4M3E4M3_SS(); - if constexpr (N == 24) return SM90_64x24x32_F32E4M3E4M3_SS(); - if constexpr (N == 32) return SM90_64x32x32_F32E4M3E4M3_SS(); - if constexpr (N == 40) return SM90_64x40x32_F32E4M3E4M3_SS(); - if constexpr (N == 48) return SM90_64x48x32_F32E4M3E4M3_SS(); - if constexpr (N == 56) return SM90_64x56x32_F32E4M3E4M3_SS(); - if constexpr (N == 64) return SM90_64x64x32_F32E4M3E4M3_SS(); - if constexpr (N == 72) return SM90_64x72x32_F32E4M3E4M3_SS(); - if constexpr (N == 80) return SM90_64x80x32_F32E4M3E4M3_SS(); - if constexpr (N == 88) return SM90_64x88x32_F32E4M3E4M3_SS(); - if constexpr (N == 96) return SM90_64x96x32_F32E4M3E4M3_SS(); - if constexpr (N == 104) return SM90_64x104x32_F32E4M3E4M3_SS(); - if constexpr (N == 112) return SM90_64x112x32_F32E4M3E4M3_SS(); - if constexpr (N == 120) return SM90_64x120x32_F32E4M3E4M3_SS(); - if constexpr (N == 128) return SM90_64x128x32_F32E4M3E4M3_SS(); - if constexpr (N == 192) return SM90_64x192x32_F32E4M3E4M3_SS(); + return FP8MMA(); } using type = decltype(select_type()); diff --git a/ops/csrc/fp8/deep_gemm/include/deep_gemm/scheduler.cuh b/ops/csrc/fp8/deep_gemm/include/deep_gemm/scheduler.cuh index 35cbcf1e3a1e..e7ffa0cc1fd9 100644 --- a/ops/csrc/fp8/deep_gemm/include/deep_gemm/scheduler.cuh +++ b/ops/csrc/fp8/deep_gemm/include/deep_gemm/scheduler.cuh @@ -30,9 +30,10 @@ enum class GemmType { #pragma ide diagnostic ignored "cppcoreguidelines-pro-type-member-init" template + uint32_t kNum1DBlocksPerGroup = 16> struct Scheduler { int current_iter = -1; uint32_t num_aligned_m_blocks; @@ -61,16 +62,27 @@ struct Scheduler { } __device__ __forceinline__ void get_swizzled_block_idx(const uint32_t num_m_blocks, int 
block_idx, uint32_t& m_block_idx, uint32_t& n_block_idx) { - DG_STATIC_ASSERT(kNumNBlocksPerGroup % kNumTMAMulticast == 0, "Invalid group size"); + DG_STATIC_ASSERT(kNum1DBlocksPerGroup % kNumTMAMulticast == 0, "Invalid group size"); // Swizzle for better L2 usages - auto num_blocks_per_group = num_m_blocks * kNumNBlocksPerGroup; - auto group_idx = block_idx / num_blocks_per_group; - auto first_n_block_idx = group_idx * kNumNBlocksPerGroup; - auto num_n_blocks_in_group = min(kNumNBlocksPerGroup, kNumNBlocks - first_n_block_idx); - auto in_group_idx = block_idx % num_blocks_per_group; - m_block_idx = in_group_idx / num_n_blocks_in_group; - n_block_idx = first_n_block_idx + in_group_idx % num_n_blocks_in_group; + // TODO: unify these 2 branches + if constexpr (kIsTMAMulticastOnA) { + auto num_blocks_per_group = num_m_blocks * kNum1DBlocksPerGroup; + auto group_idx = block_idx / num_blocks_per_group; + auto first_n_block_idx = group_idx * kNum1DBlocksPerGroup; + auto num_n_blocks_in_group = min(kNum1DBlocksPerGroup, kNumNBlocks - first_n_block_idx); + auto in_group_idx = block_idx % num_blocks_per_group; + m_block_idx = in_group_idx / num_n_blocks_in_group; + n_block_idx = first_n_block_idx + in_group_idx % num_n_blocks_in_group; + } else { + auto num_blocks_per_group = kNumNBlocks * kNum1DBlocksPerGroup; + auto group_idx = block_idx / num_blocks_per_group; + auto first_m_block_idx = group_idx * kNum1DBlocksPerGroup; + auto num_m_blocks_in_group = min(kNum1DBlocksPerGroup, num_m_blocks - first_m_block_idx); + auto in_group_idx = block_idx % num_blocks_per_group; + m_block_idx = first_m_block_idx + in_group_idx % num_m_blocks_in_group; + n_block_idx = in_group_idx / num_m_blocks_in_group; + } } template @@ -116,6 +128,7 @@ struct Scheduler { return true; } }; + #pragma clang diagnostic pop } // namespace deep_gemm \ No newline at end of file diff --git a/ops/csrc/fp8/deep_gemm/include/deep_gemm/tma_utils.cuh b/ops/csrc/fp8/deep_gemm/include/deep_gemm/tma_utils.cuh index f0a0e4f17623..782be1310e9f 100644 --- a/ops/csrc/fp8/deep_gemm/include/deep_gemm/tma_utils.cuh +++ b/ops/csrc/fp8/deep_gemm/include/deep_gemm/tma_utils.cuh @@ -58,7 +58,7 @@ constexpr CUtensorMapDataType get_CUtensorMapDataType() { } } -PFN_cuTensorMapEncodeTiled get_cuTensorMapEncodeTiled() { +inline PFN_cuTensorMapEncodeTiled get_cuTensorMapEncodeTiled() { // Get pointer to `cuTensorMapEncodeTiled` cudaDriverEntryPointQueryResult driver_status; void* cuTensorMapEncodeTiled_ptr = nullptr; @@ -81,16 +81,15 @@ CUtensorMap make_2d_tma_copy_desc(T* global_address, uint64_t gmem_dim[2], uint64_t stride_in_bytes, uint32_t smem_dim[2], CUtensorMapSwizzle swizzle_type, PFN_cuTensorMapEncodeTiled encode_func = nullptr) { - CUtensorMap tensor_map{}; - constexpr uint32_t rank = 2; - uint64_t global_stride[rank - 1] = {stride_in_bytes}; - uint32_t elem_strides[rank] = {1, 1}; + CUtensorMap tensor_map = {}; + uint64_t global_stride[1] = {stride_in_bytes}; + uint32_t elem_strides[2] = {1, 1}; if (encode_func == nullptr) encode_func = get_cuTensorMapEncodeTiled(); auto result = encode_func( - &tensor_map, get_CUtensorMapDataType::type>(), rank, + &tensor_map, get_CUtensorMapDataType>(), 2, global_address, gmem_dim, global_stride, smem_dim, elem_strides, CUtensorMapInterleave::CU_TENSOR_MAP_INTERLEAVE_NONE, swizzle_type, CUtensorMapL2promotion::CU_TENSOR_MAP_L2_PROMOTION_L2_256B, diff --git a/ops/csrc/fp8/deep_gemm/include/deep_gemm/utils.cuh b/ops/csrc/fp8/deep_gemm/include/deep_gemm/utils.cuh index c21d16e513c2..71ae2c0541ce 100644 --- 
a/ops/csrc/fp8/deep_gemm/include/deep_gemm/utils.cuh +++ b/ops/csrc/fp8/deep_gemm/include/deep_gemm/utils.cuh @@ -63,4 +63,9 @@ do { template __device__ __host__ constexpr T ceil_div(T a, T b) { return (a + b - 1) / b; +} + +template +__device__ __host__ constexpr T constexpr_gcd(T a, T b) { + return b == 0 ? a : constexpr_gcd(b, a % b); } \ No newline at end of file diff --git a/ops/csrc/fp8/deep_gemm/jit/__init__.py b/ops/csrc/fp8/deep_gemm/jit/__init__.py index cb04fd0007f8..34adc145dcf3 100644 --- a/ops/csrc/fp8/deep_gemm/jit/__init__.py +++ b/ops/csrc/fp8/deep_gemm/jit/__init__.py @@ -16,6 +16,6 @@ # Copyright (c) 2025 DeepSeek # Licensed under the MIT License - https://github.com/deepseek-ai/DeepEP/blob/main/LICENSE -from .compiler import build, get_nvcc_compiler -from .runtime import Runtime +from .compiler import get_nvcc_compiler, build from .template import cpp_format, generate +from .runtime import Runtime diff --git a/ops/csrc/fp8/deep_gemm/jit/compiler.py b/ops/csrc/fp8/deep_gemm/jit/compiler.py index ea6714be1f5a..354cc7789113 100644 --- a/ops/csrc/fp8/deep_gemm/jit/compiler.py +++ b/ops/csrc/fp8/deep_gemm/jit/compiler.py @@ -16,8 +16,8 @@ # Copyright (c) 2025 DeepSeek # Licensed under the MIT License - https://github.com/deepseek-ai/DeepEP/blob/main/LICENSE -import functools import hashlib +import functools import os import re import subprocess @@ -75,9 +75,7 @@ def get_nvcc_compiler() -> Tuple[str, str]: match = version_pattern.search(os.popen(f"{path} --version").read()) version = match.group(1) assert match, f"Cannot get the version of NVCC compiler {path}" - assert ( - version >= least_version_required - ), f"NVCC {path} version {version} is lower than {least_version_required}" + assert version >= least_version_required, f"NVCC {path} version {version} is lower than {least_version_required}" return path, version raise RuntimeError("Cannot find any available NVCC compiler") @@ -117,18 +115,13 @@ def put(path, data, is_binary=False): def build(name: str, arg_defs: tuple, code: str) -> Runtime: # Compiler flags - nvcc_flags = [ - "-std=c++17", - "-shared", - "-O3", - "--expt-relaxed-constexpr", - "--expt-extended-lambda", - "-gencode=arch=compute_90a,code=sm_90a", - "--ptxas-options=--register-usage-level=10" + (",--verbose" if "DG_PTXAS_VERBOSE" in os.environ else ""), - # Suppress some unnecessary warnings, such as unused variables for certain `constexpr` branch cases - "--diag-suppress=177,174,940", - ] - cxx_flags = ["-fPIC", "-O3", "-Wno-deprecated-declarations", "-Wno-abi"] + cpp_standard = int(os.getenv("DG_NVCC_OVERRIDE_CPP_STANDARD", 20)) + nvcc_flags = [f"-std=c++{cpp_standard}", "-shared", "-O3", "--expt-relaxed-constexpr", "--expt-extended-lambda", + "-gencode=arch=compute_90a,code=sm_90a", + "--ptxas-options=--register-usage-level=10" + (",--verbose" if "DG_PTXAS_VERBOSE" in os.environ else ""), + # Suppress some unnecessary warnings, such as unused variables for certain `constexpr` branch cases + "--diag-suppress=39,174,177,940"] + cxx_flags = ["-fPIC", "-O3", "-Wno-deprecated-declarations", "-Wno-abi", "-fconcepts"] flags = [*nvcc_flags, f'--compiler-options={",".join(cxx_flags)}'] include_dirs = [get_jit_include_dir()] @@ -155,8 +148,12 @@ def build(name: str, arg_defs: tuple, code: str) -> Runtime: # Compile into a temporary SO file so_path = f"{path}/kernel.so" tmp_so_path = f"{make_tmp_dir()}/nvcc.tmp.{str(uuid.uuid4())}.{hash_to_hex(so_path)}.so" + # Compile - command = [get_nvcc_compiler()[0], src_path, "-o", tmp_so_path, *flags, *[f"-I{d}" for d in 
include_dirs]] + command = [get_nvcc_compiler()[0], + src_path, "-o", tmp_so_path, + *flags, + *[f"-I{d}" for d in include_dirs]] if os.getenv("DG_JIT_DEBUG", None) or os.getenv("DG_JIT_PRINT_NVCC_COMMAND", False): print(f"Compiling JIT runtime {name} with command {command}") return_code = subprocess.check_call(command) diff --git a/ops/csrc/fp8/deep_gemm/jit/interleave_ffma.py b/ops/csrc/fp8/deep_gemm/jit/interleave_ffma.py index 0a5919b6b87b..53a5ce8604eb 100644 --- a/ops/csrc/fp8/deep_gemm/jit/interleave_ffma.py +++ b/ops/csrc/fp8/deep_gemm/jit/interleave_ffma.py @@ -93,7 +93,7 @@ def parse_registers(line): def modify_segment(m, name, ffma_lines): - num_lines = len(ffma_lines) + num_lines = (len(ffma_lines) * 9 // 16) // 2 * 2 assert num_lines % 2 == 0 le_bytes, new_le_bytes = [], [] diff --git a/ops/csrc/fp8/deep_gemm/jit_kernels/gemm.py b/ops/csrc/fp8/deep_gemm/jit_kernels/gemm.py index 3a674f71129a..35106c6bc134 100644 --- a/ops/csrc/fp8/deep_gemm/jit_kernels/gemm.py +++ b/ops/csrc/fp8/deep_gemm/jit_kernels/gemm.py @@ -16,22 +16,16 @@ # Copyright (c) 2025 DeepSeek # Licensed under the MIT License - https://github.com/deepseek-ai/DeepEP/blob/main/LICENSE -import functools +import math +from functools import lru_cache from typing import Tuple - import paddle from paddle import Tensor - from .tuner import jit_tuner -from .utils import ( - ceil_div, - get_col_major_tma_aligned_tensor, - get_m_alignment_for_contiguous_layout, - get_num_sms, -) +from .utils import get_num_sms, ceil_div, get_col_major_tma_aligned_tensor, get_m_alignment_for_contiguous_layout # C++ code templates -includes = ('"deep_gemm/fp8_gemm.cuh"',) +includes = ('"deep_gemm/fp8_gemm.cuh"', ) template = """ using namespace deep_gemm; @@ -39,32 +33,58 @@ constexpr auto N = {N}, K = {K}; constexpr auto BLOCK_M = {BLOCK_M}; constexpr auto BLOCK_N = {BLOCK_N}; +constexpr auto BLOCK_K = 128; +constexpr auto BLOCK_N_PADDING = {BLOCK_N_PADDING}; +constexpr auto kSwizzleDMode = {SWIZZLE_D_MODE}; +constexpr auto kNumGroups = 1; constexpr auto kNumStages = {NUM_STAGES}; constexpr auto kNumTMAMulticast = {NUM_TMA_MULTICAST}; +constexpr auto kIsTMAMulticastOnA = {IS_TMA_MULTICAST_ON_A}; // Make a templated GEMM -using GemmType = Gemm; +using gemm_t = Gemm; // Launch kernel -auto tma_a_desc = GemmType::make_2d_tma_a_desc(lhs, m); -auto tma_b_desc = GemmType::make_2d_tma_b_desc(rhs); -auto tma_scales_a_desc = GemmType::make_2d_tma_scales_a_desc(lhs_scales, m); -auto tma_d_desc = GemmType::make_2d_tma_d_desc(out, m); -GemmType::run(out, rhs_scales, nullptr, - m, - tma_a_desc, tma_b_desc, tma_scales_a_desc, tma_d_desc, - stream, num_sms, smem_size); +auto tma_a_desc = gemm_t::make_2d_tma_a_desc(lhs, m); +auto tma_b_desc = gemm_t::make_2d_tma_b_desc(rhs); +auto tma_scales_a_desc = gemm_t::make_2d_tma_scales_a_desc(lhs_scales, m); +auto tma_d_desc = gemm_t::make_2d_tma_d_desc(out, m); +gemm_t::run(out, rhs_scales, nullptr, + m, + tma_a_desc, tma_b_desc, tma_scales_a_desc, tma_d_desc, + stream, num_sms, smem_size); """ -def is_tma_multicast_legal(n: int, block_n: int, num_tma_multicast: int, num_sms: int) -> bool: +def is_tma_multicast_legal(shape_dim: int, block_dim: int, num_tma_multicast: int, num_sms: int) -> bool: if num_tma_multicast == 1: return True - return (n % (block_n * num_tma_multicast) == 0) and num_sms % num_tma_multicast == 0 + return (shape_dim % (block_dim * num_tma_multicast) == 0) and num_sms % num_tma_multicast == 0 + + +def get_swizzle_mode(block_n: int) -> int: + # TODO: remove some candidates if slow + elem_size = 
2 + for mode_bytes in (128, 64, 32): + if (block_n * elem_size) % mode_bytes == 0: + return mode_bytes + return 0 + + +def get_block_n_padding_for_smem_d(block_n: int) -> int: + # NOTES: padding is for solving bank conflicts, but wastes shared memory space + elem_size, requirement = 2, (4, 8) + bank_stride = (block_n * elem_size) // 4 + padding = (requirement[0] - bank_stride) % requirement[1] + return (((padding + requirement[1]) if padding < 0 else padding) * 4) // elem_size -def get_smem_size(num_stages: int, k: int, block_m: int, block_n: int, block_k: int = 128) -> int: - smem_d = block_m * block_n * 2 +def get_smem_config(num_stages: int, k: int, block_m: int, block_n: int, block_k: int = 128) -> Tuple[int, int, int]: + # Try swizzle first, as it does not waste shared memory + swizzle_mode = get_swizzle_mode(block_n) + block_n_padding = get_block_n_padding_for_smem_d(block_n) if swizzle_mode == 0 else 0 + + smem_d = block_m * (block_n + block_n_padding) * 2 smem_a_per_stage = block_m * block_k smem_scales_a_per_stage = block_m * 4 smem_b_per_stage = block_n * block_k @@ -78,18 +98,22 @@ def get_smem_size(num_stages: int, k: int, block_m: int, block_n: int, block_k: smem_size += num_stages * smem_b_per_stage smem_size += ceil_div(smem_scales_b * (1 if block_k % block_n == 0 else 2), 8) * 8 smem_size += smem_barrier - return smem_size + + # Swizzle and padding are not compatible + assert int(swizzle_mode > 0) + int(block_n_padding > 0) <= 1 + + return smem_size, swizzle_mode, block_n_padding -def get_best_configs( - m: int, n: int, k: int, num_groups: int, num_sms: int, is_grouped_contiguous: bool = False -) -> Tuple[int, int, int, int, int]: +@lru_cache(maxsize=None) +def get_best_configs(m: int, n: int, k: int, num_groups: int, num_sms: int, + is_grouped_contiguous: bool = False, is_grouped_masked: bool = False) -> \ + Tuple[int, int, int, int, Tuple[int, bool], Tuple[int, int, int]]: if not is_grouped_contiguous: - # TODO: for some cases, smaller M block is better, add them into tuning space - block_ms = (64 if m <= 64 else 128,) + block_ms = (64, 128, 256) else: - block_ms = (get_m_alignment_for_contiguous_layout(),) - block_ns = tuple(range(16, 129, 8)) + block_ms = (get_m_alignment_for_contiguous_layout(), ) + block_ns = tuple(range(16, 129, 8)) + (144, 160, ) fix_wave_saturate = lambda x: num_sms if x == 0 else x get_num_waves = lambda bm, bn: (ceil_div(ceil_div(m, bm) * ceil_div(n, bn) * num_groups, num_sms) if bm else None) @@ -98,7 +122,9 @@ def get_best_configs( # Decide block sizes by waves best_block_m, best_block_n = None, None for block_m in block_ms: - for block_n in block_ns: + # NOTES: the block sizes can not be too large, so at least one dim less than 128 + for block_n in filter(lambda bn: block_m <= 128 or bn <= 128, block_ns): + success = False num_waves, best_num_waves = get_num_waves(block_m, block_n), get_num_waves(best_block_m, best_block_n) if best_block_m is None or best_block_n is None: @@ -109,49 +135,76 @@ def get_best_configs( # Check last wave utilization util = get_last_wave_util(block_m, block_n) best_util = get_last_wave_util(best_block_m, best_block_n) - success = util > best_util or ( - util == best_util - and (block_m > best_block_m or (block_m == best_block_m and block_n < best_block_n)) - ) + success = util > best_util + if util == best_util: + # Case 1: same `block_m`, smaller `block_n` (wasted) + success |= block_m == best_block_m and block_n < best_block_n + # Case 2: same `block_n`, smaller `block_m` (wasted) + success |= block_n == 
best_block_n and block_m < best_block_m + # Case 3: different for both `block_m` and `block_n`, `block_n` larger is better + success |= block_m != best_block_m and block_n > best_block_n best_block_m, best_block_n = (block_m, block_n) if success else (best_block_m, best_block_n) assert best_block_m is not None and best_block_n is not None # Always pick the longest one # NOTES: for double B scales, the best number of stages may be reduced - best_num_stages, best_smem_size, sm90_capacity = None, None, 232448 - for num_stages in (6, 5, 4) if 128 % best_block_n != 0 else (8, 7, 6, 5, 4): - best_smem_size = get_smem_size(num_stages, k, best_block_m, best_block_n) - if best_smem_size <= sm90_capacity: + best_num_stages, best_smem_config, sm90_capacity = None, None, 232448 + stage_candidates = (8, 7, 6, 5, 4, 3) + if 128 % best_block_n != 0 and 128 // math.gcd(128, best_block_n) <= 4: + # Unrolling both stages and `num_former_iters` will cause large code size + stage_candidates = (4, 3) + for num_stages in stage_candidates: + best_smem_config = get_smem_config(num_stages, k, best_block_m, best_block_n) + if best_smem_config[0] <= sm90_capacity: best_num_stages = num_stages break + assert best_smem_config is not None assert best_num_stages is not None - # Decide the number of TMA multicast - best_num_tma_multicast = 1 - if m >= 1024 and is_tma_multicast_legal(n, best_block_n, 2, num_sms) and num_groups == 1: - best_num_tma_multicast = 2 + # Decide the number of TMA multicast and whether broadcast on A + best_tma_multicast_config = (1, 1) + + # Try to multicast on the larger block side first + is_dense_gemm = (not is_grouped_contiguous) and (not is_grouped_masked) + is_multicast_legal = { + 'A': is_tma_multicast_legal(n, best_block_n, 2, num_sms), + 'B': is_tma_multicast_legal(m, best_block_m, 2, num_sms) and is_dense_gemm, + } + for i in ('A', 'B') if best_block_m > best_block_n else ('B', 'A'): + if m >= 512 and is_multicast_legal[i]: + best_tma_multicast_config = (2, int(i == 'A')) + break - return best_block_m, best_block_n, best_num_stages, best_num_tma_multicast, best_smem_size + # Recompute the minimal number of SMs required + # NOTES: less L2 cache usage and less GPU frequency drop + num_waves = get_num_waves(best_block_m, best_block_n) + num_min_sms = ceil_div(ceil_div(m, best_block_m) * ceil_div(n, best_block_n) * num_groups, num_waves) + num_min_sms = ceil_div(num_min_sms, best_tma_multicast_config[0]) * best_tma_multicast_config[0] + assert num_min_sms <= num_sms + return num_min_sms, best_block_m, best_block_n, best_num_stages, best_tma_multicast_config, best_smem_config -@functools.lru_cache() +@lru_cache() def auto_tuning_with_compilation(m, n, k, num_sms): global includes, template if num_sms is None: num_sms = get_num_sms() - block_m, block_n, num_stages, num_tma_multicast, smem_size = get_best_configs(m, n, k, 1, num_sms) + num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_config = get_best_configs(m, n, k, 1, num_sms) runtime = jit_tuner.compile_and_tune( m, n, k, name="gemm_fp8_fp8_bf16_nt", keys={ + 'SWIZZLE_D_MODE': smem_config[1], + 'BLOCK_N_PADDING': smem_config[2], "BLOCK_M": block_m, "BLOCK_N": block_n, "K": k, "N": n, "NUM_STAGES": num_stages, - "NUM_TMA_MULTICAST": num_tma_multicast, + "NUM_TMA_MULTICAST": tma_multicast_config[0], + 'IS_TMA_MULTICAST_ON_A': tma_multicast_config[1], }, space=(), includes=includes, @@ -168,7 +221,7 @@ def auto_tuning_with_compilation(m, n, k, num_sms): ), template=template, ) - return runtime, num_sms, smem_size + return 
runtime, num_sms, smem_config def gemm_fp8_fp8_bf16_nt(lhs: Tuple[Tensor, Tensor], rhs: Tuple[Tensor, Tensor], out: Tensor, num_sms=112) -> None: @@ -211,7 +264,7 @@ def gemm_fp8_fp8_bf16_nt(lhs: Tuple[Tensor, Tensor], rhs: Tuple[Tensor, Tensor], # Do nothing if `m` is zero if m == 0: return - runtime, num_sms, smem_size = auto_tuning_with_compilation(m, n, k, num_sms) - args = (lhs, lhs_scales, rhs, rhs_scales, out, m, paddle.device.current_stream().stream_base, num_sms, smem_size) + runtime, num_sms, smem_config = auto_tuning_with_compilation(m, n, k, num_sms) + args = (lhs, lhs_scales, rhs, rhs_scales, out, m, paddle.device.current_stream().stream_base, num_sms, smem_config[0]) # Run the kernel. runtime(*args) diff --git a/ops/csrc/fp8/deep_gemm/jit_kernels/m_grouped_gemm.py b/ops/csrc/fp8/deep_gemm/jit_kernels/m_grouped_gemm.py index fef0505be553..58ff7b833874 100644 --- a/ops/csrc/fp8/deep_gemm/jit_kernels/m_grouped_gemm.py +++ b/ops/csrc/fp8/deep_gemm/jit_kernels/m_grouped_gemm.py @@ -22,13 +22,12 @@ import paddle from paddle import Tensor - -from .gemm import get_best_configs +from .gemm import get_best_configs, get_block_n_padding_for_smem_d from .tuner import jit_tuner from .utils import get_col_major_tma_aligned_tensor, get_num_sms # C++ code templates -includes = ('"deep_gemm/fp8_gemm.cuh"',) +includes = ('"deep_gemm/fp8_gemm.cuh"', ) template = """ using namespace deep_gemm; @@ -36,21 +35,26 @@ constexpr auto N = {N}, K = {K}; constexpr auto BLOCK_M = {BLOCK_M}; constexpr auto BLOCK_N = {BLOCK_N}; +constexpr auto BLOCK_K = 128; +constexpr auto BLOCK_N_PADDING = {BLOCK_N_PADDING}; +constexpr auto kSwizzleDMode = {SWIZZLE_D_MODE}; +constexpr auto kNumGroups = {NUM_GROUPS}; constexpr auto kNumStages = {NUM_STAGES}; constexpr auto kNumTMAMulticast = {NUM_TMA_MULTICAST}; +constexpr auto kIsTMAMulticastOnA = {IS_TMA_MULTICAST_ON_A}; // Make a templated grouped GEMM -using GemmType = Gemm; +using gemm_t = Gemm; // Launch kernel -auto tma_a_desc = GemmType::make_2d_tma_a_desc(lhs, m); -auto tma_b_desc = GemmType::make_2d_tma_b_desc(rhs); -auto tma_scales_a_desc = GemmType::make_2d_tma_scales_a_desc(lhs_scales, m); -auto tma_d_desc = GemmType::make_2d_tma_d_desc(out, m); -GemmType::run(out, rhs_scales, grouped_layout, - m, - tma_a_desc, tma_b_desc, tma_scales_a_desc, tma_d_desc, - stream, num_sms, smem_size); +auto tma_a_desc = gemm_t::make_2d_tma_a_desc(lhs, m); +auto tma_b_desc = gemm_t::make_2d_tma_b_desc(rhs); +auto tma_scales_a_desc = gemm_t::make_2d_tma_scales_a_desc(lhs_scales, m); +auto tma_d_desc = gemm_t::make_2d_tma_d_desc(out, m); +gemm_t::run(out, rhs_scales, grouped_layout, + m, + tma_a_desc, tma_b_desc, tma_scales_a_desc, tma_d_desc, + stream, num_sms, smem_size); """ @@ -59,9 +63,7 @@ def auto_tuning_with_compilation_grouped_gemm_contiguous(m, n, k, num_groups, nu global includes, template if num_sms is None: num_sms = get_num_sms() - block_m, block_n, num_stages, num_tma_multicast, smem_size = get_best_configs( - m, n, k, 1, num_sms, is_grouped_contiguous=True - ) + num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_config = get_best_configs(m, n, k, 1, num_sms, is_grouped_contiguous=True) runtime = jit_tuner.compile_and_tune( m, n, @@ -70,12 +72,15 @@ def auto_tuning_with_compilation_grouped_gemm_contiguous(m, n, k, num_groups, nu keys={ "BLOCK_M": block_m, "BLOCK_N": block_n, + "SWIZZLE_D_MODE": smem_config[1], + "BLOCK_N_PADDING": smem_config[2], "GEMM_TYPE": "GroupedContiguous", "K": k, "N": n, "NUM_GROUPS": num_groups, "NUM_STAGES": num_stages, - 
"NUM_TMA_MULTICAST": num_tma_multicast, + "NUM_TMA_MULTICAST": tma_multicast_config[0], + "IS_TMA_MULTICAST_ON_A": tma_multicast_config[1], }, space=(), includes=includes, @@ -94,7 +99,7 @@ def auto_tuning_with_compilation_grouped_gemm_contiguous(m, n, k, num_groups, nu ), template=template, ) - return runtime, num_sms, smem_size + return runtime, num_sms, smem_config def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous( @@ -148,7 +153,7 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous( return # Auto-tuning with compilation global includes, template - runtime, num_sms, smem_size = auto_tuning_with_compilation_grouped_gemm_contiguous(m, n, k, num_groups, num_sms) + runtime, num_sms, smem_config = auto_tuning_with_compilation_grouped_gemm_contiguous(m, n, k, num_groups, num_sms) args = ( lhs, @@ -161,7 +166,7 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous( num_groups, paddle.device.current_stream().stream_base, num_sms, - smem_size, + smem_config[0], ) runtime(*args) @@ -172,12 +177,14 @@ def auto_tuning_with_compilation_grouped_gemm_masked(m, expected_m, n, k, num_gr global includes, template if num_sms is None: num_sms = get_num_sms() - block_m, block_n, num_stages, num_tma_multicast, smem_size = get_best_configs( + num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_config = get_best_configs( expected_m, n, k, num_groups, num_sms ) # Extra checks for TMA store if num_groups > 1 and m > block_m: + while m % block_m != 0 and block_m > 128: + block_m = block_m // 2 assert ( m % block_m == 0 ), f"For masked grouped GEMM, shape M should be multiple of the block M (current block M: {block_m})" @@ -189,9 +196,12 @@ def auto_tuning_with_compilation_grouped_gemm_masked(m, expected_m, n, k, num_gr "K": k, "BLOCK_M": block_m, "BLOCK_N": block_n, + 'SWIZZLE_D_MODE': smem_config[1], + 'BLOCK_N_PADDING': smem_config[2], "NUM_GROUPS": num_groups, "NUM_STAGES": num_stages, - "NUM_TMA_MULTICAST": num_tma_multicast, + "NUM_TMA_MULTICAST": tma_multicast_config[0], + 'IS_TMA_MULTICAST_ON_A': tma_multicast_config[1], "GEMM_TYPE": "GroupedMasked", }, space=(), @@ -211,7 +221,7 @@ def auto_tuning_with_compilation_grouped_gemm_masked(m, expected_m, n, k, num_gr template=template, ) - return runtime, num_sms, smem_size + return runtime, num_sms, smem_config def m_grouped_gemm_fp8_fp8_bf16_nt_masked( @@ -261,8 +271,7 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_masked( # LHS scales must be transposed for TMA load, but not for RHS scales lhs_scales = get_col_major_tma_aligned_tensor(lhs_scales) assert rhs_scales.is_contiguous() - - runtime, num_sms, smem_size = auto_tuning_with_compilation_grouped_gemm_masked( + runtime, num_sms, smem_config = auto_tuning_with_compilation_grouped_gemm_masked( m, expected_m, n, k, num_groups, num_sms ) @@ -276,7 +285,7 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_masked( m, paddle.device.current_stream().stream_base, num_sms, - smem_size, + smem_config[0], ) # Run the kernel diff --git a/ops/csrc/fp8/setup.py b/ops/csrc/fp8/setup.py index 019b28921a8a..7678456bae9f 100644 --- a/ops/csrc/fp8/setup.py +++ b/ops/csrc/fp8/setup.py @@ -92,7 +92,7 @@ def prepare_includes(self): packages=["deep_gemm", "deep_gemm/jit", "deep_gemm/jit_kernels"], package_data={ "deep_gemm": [ - "include/deep_gemm/**/*", + "include/deep_gemm/*", "include/cute/**/*", "include/cutlass/**/*", ] diff --git a/ops/csrc/fp8/tests/test_core.py b/ops/csrc/fp8/tests/test_core.py index b10616dac2b5..fe5d4a746561 100644 --- a/ops/csrc/fp8/tests/test_core.py +++ b/ops/csrc/fp8/tests/test_core.py @@ -114,10 +114,8 @@ def 
construct_grouped( def test_gemm() -> None: print("Testing GEMM:") - for m in (64,): - for k, n in [ - (7168, 2112), - ]: + for m in (64, 128, 4096): + for k, n in [(7168, 2112), (1536, 24576), (512, 32768), (16384, 7168), (7168, 4096), (2048, 7168)]: x_fp8, y_fp8, out, ref_out = construct(m, k, n) deep_gemm.gemm_fp8_fp8_bf16_nt(x_fp8, y_fp8, out) diff = calc_diff(out, ref_out) @@ -129,7 +127,7 @@ def test_gemm() -> None: def test_m_grouped_gemm_contiguous() -> None: print("Testing grouped contiguous GEMM:") - for num_groups, m, k, n in ((4, 8192, 7168, 4096),): + for num_groups, m, k, n in ((8, 4096, 7168, 4096), (8, 4096, 2048, 7168), (4, 8192, 2048, 7168), (4, 8192, 7168, 4096), ): # TODO: make a stronger test x_fp8, y_fp8, out, ref_out = construct_grouped(num_groups, m, k, n, is_masked=False) m_indices = paddle.arange(0, num_groups, dtype=paddle.int32) @@ -144,9 +142,9 @@ def test_m_grouped_gemm_contiguous() -> None: def test_m_grouped_gemm_masked() -> None: print("Testing grouped masked GEMM:") - - for num_groups, m in ((1, 1024),): - for k, n in ((7168, 4096),): + + for num_groups, m in ((1, 1024), (2, 512), (4, 256)): + for k, n in ((7168, 4096), (2048, 7168), ): # Test correctness masked_m_candidates = list(filter(lambda candidate: candidate <= m, (64, 128, 192, 256, 320, 384))) for i in range(10): @@ -158,7 +156,7 @@ def test_m_grouped_gemm_masked() -> None: masked_m_float = paddle.cast(masked_m, "float32") masked_m_mean = paddle.mean(masked_m_float) masked_m_mean_int = paddle.cast(masked_m_mean, "int32") - expected_m = min(masked_m_mean_int + 1, m) + expected_m = min(int(masked_m_mean_int + 1), m) deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_masked(x_fp8, y_fp8, out, masked_m, expected_m) for j in range(num_groups): diff = calc_diff(out[j, : masked_m[j].item()], ref_out[j, : masked_m[j].item()]) From b0a3b9d757e03a1729f466c5bef3809dae34d075 Mon Sep 17 00:00:00 2001 From: chen2016013 <111894720+chen2016013@users.noreply.github.com> Date: Wed, 14 May 2025 16:09:31 +0800 Subject: [PATCH 2/2] Update m_grouped_gemm.py --- ops/csrc/fp8/deep_gemm/jit_kernels/m_grouped_gemm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ops/csrc/fp8/deep_gemm/jit_kernels/m_grouped_gemm.py b/ops/csrc/fp8/deep_gemm/jit_kernels/m_grouped_gemm.py index 58ff7b833874..7cbf8ce64ff4 100644 --- a/ops/csrc/fp8/deep_gemm/jit_kernels/m_grouped_gemm.py +++ b/ops/csrc/fp8/deep_gemm/jit_kernels/m_grouped_gemm.py @@ -178,7 +178,7 @@ def auto_tuning_with_compilation_grouped_gemm_masked(m, expected_m, n, k, num_gr if num_sms is None: num_sms = get_num_sms() num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_config = get_best_configs( - expected_m, n, k, num_groups, num_sms + expected_m, n, k, num_groups, num_sms, is_grouped_masked=True ) # Extra checks for TMA store
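
The scheduler hunk above replaces the N-only swizzle with a direction-aware one: CTA blocks are grouped along N when TMA multicast is on A, and along M otherwise, so blocks scheduled onto neighbouring SMs reuse the same tiles in L2. A host-side Python transcription of that index mapping may help when stepping through it; it is not part of the patch, and the function name, signature, and pure-Python form are illustrative only.

def get_swizzled_block_idx(num_m_blocks, num_n_blocks, block_idx,
                           is_multicast_on_a=True, blocks_per_1d_group=16):
    # Mirrors Scheduler::get_swizzled_block_idx: consecutive block indices are
    # packed into groups of kNum1DBlocksPerGroup along one dimension so that
    # co-scheduled SMs share A (or B) tiles and hit L2 more often.
    if is_multicast_on_a:
        # Group along N: walk all M blocks for a small window of N blocks.
        num_blocks_per_group = num_m_blocks * blocks_per_1d_group
        group_idx = block_idx // num_blocks_per_group
        first_n_block_idx = group_idx * blocks_per_1d_group
        num_n_blocks_in_group = min(blocks_per_1d_group, num_n_blocks - first_n_block_idx)
        in_group_idx = block_idx % num_blocks_per_group
        m_block_idx = in_group_idx // num_n_blocks_in_group
        n_block_idx = first_n_block_idx + in_group_idx % num_n_blocks_in_group
    else:
        # Group along M: walk all N blocks for a small window of M blocks.
        num_blocks_per_group = num_n_blocks * blocks_per_1d_group
        group_idx = block_idx // num_blocks_per_group
        first_m_block_idx = group_idx * blocks_per_1d_group
        num_m_blocks_in_group = min(blocks_per_1d_group, num_m_blocks - first_m_block_idx)
        in_group_idx = block_idx % num_blocks_per_group
        m_block_idx = first_m_block_idx + in_group_idx % num_m_blocks_in_group
        n_block_idx = in_group_idx // num_m_blocks_in_group
    return m_block_idx, n_block_idx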
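
The new shared-memory plan for the D tile tries a TMA-store swizzle first and pads BLOCK_N only when no swizzle mode divides the row size; get_smem_config asserts that the two are never combined. The two helpers below are copied verbatim from jit_kernels/gemm.py in this patch; the driver at the end is illustrative only and simply shows which BLOCK_N candidates end up swizzled versus padded.

def get_swizzle_mode(block_n: int) -> int:
    # TODO: remove some candidates if slow
    elem_size = 2
    for mode_bytes in (128, 64, 32):
        if (block_n * elem_size) % mode_bytes == 0:
            return mode_bytes
    return 0


def get_block_n_padding_for_smem_d(block_n: int) -> int:
    # NOTES: padding is for solving bank conflicts, but wastes shared memory space
    elem_size, requirement = 2, (4, 8)
    bank_stride = (block_n * elem_size) // 4
    padding = (requirement[0] - bank_stride) % requirement[1]
    return (((padding + requirement[1]) if padding < 0 else padding) * 4) // elem_size


if __name__ == "__main__":
    # Same BLOCK_N candidates as the tuning space in get_best_configs.
    for block_n in tuple(range(16, 129, 8)) + (144, 160):
        swizzle = get_swizzle_mode(block_n)
        padding = get_block_n_padding_for_smem_d(block_n) if swizzle == 0 else 0
        print(f"BLOCK_N={block_n:3d} -> swizzle={swizzle:3d}B, padding={padding} elems")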
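
get_best_configs now returns a (num_tma_multicast, is_multicast_on_A) pair rather than a bare multicast count, preferring to multicast along whichever side has the larger block and allowing multicast on B only for dense GEMMs. The sketch below condenses that decision; pick_tma_multicast is a hypothetical name used here for illustration and does not exist in the patch, but the legality test and the ordering follow the code above.

def is_tma_multicast_legal(shape_dim, block_dim, num_tma_multicast, num_sms):
    # Copied from the patch: the dimension must split evenly across the multicast
    # cluster, and the SM count must be divisible by the cluster size.
    if num_tma_multicast == 1:
        return True
    return (shape_dim % (block_dim * num_tma_multicast) == 0) and num_sms % num_tma_multicast == 0


def pick_tma_multicast(m, n, block_m, block_n, num_sms, is_dense_gemm=True):
    # Returns (num_tma_multicast, is_multicast_on_A); the default is no multicast.
    config = (1, 1)
    legal = {
        'A': is_tma_multicast_legal(n, block_n, 2, num_sms),
        'B': is_tma_multicast_legal(m, block_m, 2, num_sms) and is_dense_gemm,
    }
    # Try the side with the larger block first, and only multicast at all once
    # M is large enough (the patch lowers this gate from 1024 to 512).
    for side in (('A', 'B') if block_m > block_n else ('B', 'A')):
        if m >= 512 and legal[side]:
            config = (2, int(side == 'A'))
            break
    return config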