diff --git a/deep_gemm/__init__.py b/deep_gemm/__init__.py index 8e6b2996..f3af4c55 100644 --- a/deep_gemm/__init__.py +++ b/deep_gemm/__init__.py @@ -5,6 +5,7 @@ gemm_fp8_fp8_bf16_nt, m_grouped_gemm_fp8_fp8_bf16_nt_contiguous, m_grouped_gemm_fp8_fp8_bf16_nt_masked, + m_grouped_gemm_fp8_fp8_bf16_nt_offset, wgrad_gemm_fp8_fp8_fp32_nt, k_grouped_wgrad_gemm_fp8_fp8_fp32_nt, ceil_div, diff --git a/deep_gemm/include/deep_gemm/fp8_gemm.cuh b/deep_gemm/include/deep_gemm/fp8_gemm.cuh index 5c11cd3d..d53eaa0a 100644 --- a/deep_gemm/include/deep_gemm/fp8_gemm.cuh +++ b/deep_gemm/include/deep_gemm/fp8_gemm.cuh @@ -439,6 +439,809 @@ fp8_gemm_kernel(float* scales_b, int* grouped_layout, #endif } +template +static __device__ __forceinline__ void write_result_to_gmem(__nv_bfloat16* gmem_d_this_block, + __nv_bfloat16 const* smem_d, uint32_t const m_offset, uint32_t const m_boundary, uint32_t const n_offset, + uint32_t const shape_n, uint32_t const ld_output) +{ + int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); + int lane_idx = threadIdx.x % 32; + constexpr int int4_per_tile_line = BLOCK_N * sizeof(__nv_bfloat16) / sizeof(int4); + int int4_per_global_line = shape_n * sizeof(__nv_bfloat16) / sizeof(int4); + constexpr auto num_lines = BLOCK_M; + constexpr auto num_warps = NUM_WARPS_PER_BLOCK; + int4 const* smem_d_int4 = reinterpret_cast(smem_d); + bool is_last_n_block = n_offset + BLOCK_N > shape_n; + int int4_per_line = is_last_n_block ? int4_per_global_line % int4_per_tile_line : int4_per_tile_line; + + for (int line_idx = warp_idx; line_idx < num_lines; line_idx += num_warps) + { + if (m_offset + line_idx >= m_boundary) + { + break; + } + for (int elem_idx = lane_idx; elem_idx < int4_per_line; elem_idx += 32) + { + uint64_t idx = (uint64_t) line_idx * ld_output + n_offset; + int4* g_data_addr = reinterpret_cast(&gmem_d_this_block[idx]) + elem_idx; + int4 const* s_data_addr = &smem_d_int4[line_idx * (int4_per_tile_line) + elem_idx]; + *g_data_addr = *s_data_addr; + } + __syncwarp(); + } +} + +template +__global__ void __launch_bounds__(get_num_threads_per_sm(BLOCK_M), 1) + fp8_gemm_offset_kernel(__nv_bfloat16* gmem_d, float* scales_b, int64_t* offsets, + __grid_constant__ const CUtensorMap tensor_map_a, __grid_constant__ const CUtensorMap tensor_map_b, + __grid_constant__ const CUtensorMap tensor_map_scales_a, __grid_constant__ const CUtensorMap tensor_map_d) +{ +#if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ == 900)) + // Scaling checks + DG_STATIC_ASSERT(BLOCK_K == 128, "Only support per-128-channel FP8 scaling"); + DG_STATIC_ASSERT(ceil_div(BLOCK_N, BLOCK_K) == 1, "Too much B scales in a single block"); + + InputType problem_input; + problem_input.problem_m_offsets = offsets; + + // Types + using WGMMA = typename FP8MMASelector::type; + using Barrier = cutlass::arch::ClusterTransactionBarrier; + + // Shared memory + static constexpr int kMustUseUniformedScaleB = (BLOCK_K % BLOCK_N == 0); + static constexpr uint32_t SMEM_D_SIZE = BLOCK_M * BLOCK_N * sizeof(__nv_bfloat16); + static constexpr uint32_t SMEM_A_SIZE_PER_STAGE = BLOCK_M * BLOCK_K * sizeof(__nv_fp8_e4m3); + static constexpr uint32_t SMEM_B_SIZE_PER_STAGE = BLOCK_N * BLOCK_K * sizeof(__nv_fp8_e4m3); + static constexpr uint32_t SMEM_SCALES_A_SIZE_PER_STAGE = BLOCK_M * sizeof(float); + static constexpr uint32_t SHAPE_K_SCALES = ceil_div(SHAPE_K, BLOCK_K); + static constexpr uint32_t SMEM_SCALES_B_SIZE + = ceil_div(SHAPE_K_SCALES * (kMustUseUniformedScaleB ? 
1 : 2) * sizeof(float), sizeof(Barrier)) + * sizeof(Barrier); + + // Configs + constexpr uint32_t kFullKOfAllStages = kNumStages * BLOCK_K; + constexpr uint32_t kNumThreads = get_num_threads_per_sm(BLOCK_M); + constexpr uint32_t kNumMathThreads = kNumThreads - kNumTMAThreads; + constexpr uint32_t kNumIterations = ceil_div(SHAPE_K, kFullKOfAllStages); + uint32_t const warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); + uint32_t const lane_idx = get_lane_id(); + + // Prefetch TMA descriptors at very beginning + if (threadIdx.x == kNumMathThreads) + { + cute::prefetch_tma_descriptor(reinterpret_cast(&tensor_map_a)); + cute::prefetch_tma_descriptor(reinterpret_cast(&tensor_map_b)); + cute::prefetch_tma_descriptor(reinterpret_cast(&tensor_map_scales_a)); + cute::prefetch_tma_descriptor(reinterpret_cast(&tensor_map_d)); + } + __syncwarp(); + + // Align to 1024 bytes for swizzle-128B + extern __shared__ __align__(1024) uint8_t smem_buffer[]; + DG_STATIC_ASSERT(SMEM_D_SIZE % 1024 == 0, "Shared memory of A/B must be aligned to 1024 bytes"); + + // Data on shared memory + auto smem_d = reinterpret_cast<__nv_bfloat16*>(smem_buffer); + __nv_fp8_e4m3* smem_a[kNumStages]; + __nv_fp8_e4m3* smem_b[kNumStages]; + float* smem_scales_a[kNumStages]; + float* smem_scales_b; + + // TMA Barrier for both divisible and non-divisible cases + Barrier* full_barriers[kNumStages]; + Barrier* empty_barriers[kNumStages]; + +// Fill shared memory pointers +#pragma unroll + for (int i = 0; i < kNumStages; ++i) + { + smem_a[i] = reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + SMEM_D_SIZE + i * SMEM_A_SIZE_PER_STAGE); + smem_b[i] = reinterpret_cast<__nv_fp8_e4m3*>( + smem_buffer + SMEM_D_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE); + smem_scales_a[i] = reinterpret_cast(smem_buffer + SMEM_D_SIZE + + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE) + i * SMEM_SCALES_A_SIZE_PER_STAGE); + } + smem_scales_b = reinterpret_cast(smem_buffer + SMEM_D_SIZE + + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE + SMEM_SCALES_A_SIZE_PER_STAGE)); + + // Fill barriers + auto barrier_start_ptr = reinterpret_cast(reinterpret_cast(smem_scales_b) + SMEM_SCALES_B_SIZE); +#pragma unroll + for (int i = 0; i < kNumStages; ++i) + { + full_barriers[i] = barrier_start_ptr + i; + empty_barriers[i] = barrier_start_ptr + kNumStages + i; + } + + // Initialize barriers + DG_STATIC_ASSERT(kNumTMAMulticast <= 32, "Too many TMA multicast"); + if (threadIdx.x == kNumMathThreads) + { +#pragma unroll + for (int i = 0; i < kNumStages; ++i) + { + full_barriers[i]->init(1); + empty_barriers[i]->init(kNumTMAMulticast * kNumMathThreads / 32); + } + + // Make initialized barrier visible in async proxy + cutlass::arch::fence_view_async_shared(); + (kNumTMAMulticast > 1) ? cutlass::arch::fence_barrier_init() : void(); + } + + // Synchronize all threads to make barrier visible in normal memory model + (kNumTMAMulticast > 1) ? 
cute::cluster_sync() : __syncthreads(); + + // For pipeline unrolling + struct DivisibleK + { + }; + + struct NotDivisibleK + { + }; + + auto launch_k_iterations = [](auto const& func) + { + if constexpr (SHAPE_K % kFullKOfAllStages == 0) + { + for (int k_iter = 0; k_iter < kNumIterations; ++k_iter) + func(k_iter, DivisibleK{}); + } + else + { + for (int k_iter = 0; k_iter < kNumIterations - 1; ++k_iter) + func(k_iter, DivisibleK{}); + func(kNumIterations - 1, NotDivisibleK{}); + } + }; + + // Register reconfigurations + constexpr int kNumTMARegisters = 40; + constexpr int kNumMathRegisters = 232; + + // Block scheduler + uint32_t m_block_idx, n_block_idx; + auto scheduler = SchedulerType(problem_input); + + if (threadIdx.x >= kNumMathThreads) + { + // TMA warp-group for loading data + cutlass::arch::warpgroup_reg_dealloc(); + + // NOTES: only one thread (or warp) will be used + if (threadIdx.x == kNumMathThreads) + { + // Persistently schedule over blocks + while (scheduler.get_next_block(m_block_idx, n_block_idx)) + { + launch_k_iterations( + [&](int k_iter, auto type) + { + constexpr bool kHasDivisibleStages = std::is_same_v; + constexpr int kNumInnerStages + = kHasDivisibleStages ? kNumStages : (SHAPE_K % kFullKOfAllStages) / BLOCK_K; + DG_STATIC_ASSERT(kNumInnerStages != 0, "Invalid number of inner stages"); + +#pragma unroll + for (uint32_t s = 0; s < kNumInnerStages; ++s) + { + // Wait consumer release + empty_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter + 1) & 1); + + // Issue TMA A with broadcasting + auto& full_barrier = *full_barriers[s]; + int k_idx = k_iter * kFullKOfAllStages + s * BLOCK_K; + tma_copy(&tensor_map_a, reinterpret_cast(&full_barrier), + smem_a[s], k_idx, scheduler.get_global_m_idx(m_block_idx), kNumTMAMulticast); + + tma_copy(&tensor_map_scales_a, + reinterpret_cast(&full_barrier), smem_scales_a[s], + scheduler.get_global_scales_a_idx(m_block_idx), k_idx / BLOCK_K, kNumTMAMulticast); + + // Issue TMA B without broadcasting + tma_copy(&tensor_map_b, reinterpret_cast(&full_barrier), smem_b[s], k_idx, + scheduler.get_global_n_idx(SHAPE_N, BLOCK_N, n_block_idx, m_block_idx), 1); + full_barrier.arrive_and_expect_tx( + SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE + SMEM_SCALES_A_SIZE_PER_STAGE); + } + +// Wait unaligned cases +#pragma unroll + for (uint32_t s = kNumInnerStages; s < kNumStages; ++s) + { + empty_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter + 1) & 1); + full_barriers[s]->arrive(); + } + }); + } + + // To safely deconstruct distributed shared barriers, we need another round of empty waits + if constexpr (kNumTMAMulticast > 1) + { +#pragma unroll + for (uint32_t s = 0; s < kNumStages; ++s) + empty_barriers[s]->wait((scheduler.current_iter * kNumIterations + 1) & 1); + } + } + } + else + { + // Math warp-groups for WGMMA + cutlass::arch::warpgroup_reg_alloc(); + + // NOTES: use `__shfl_sync` to encourage NVCC to use unified registers + auto const math_wg_idx = __shfl_sync(0xffffffff, threadIdx.x / kNumMathThreadsPerGroup, 0); + auto const r_0 = warp_idx * 16 + lane_idx / 4, r_1 = r_0 + 8; + + // Persistently schedule over blocks + while (scheduler.get_next_block(m_block_idx, n_block_idx)) + { + // Decide the number of scales B to load + DG_STATIC_ASSERT(SHAPE_N % 8 == 0, "Invalid shape N"); + uint32_t num_former_iters = BLOCK_N / 8, num_full_iters = num_former_iters; + if constexpr (not kMustUseUniformedScaleB) + { + num_former_iters = min(BLOCK_N, BLOCK_K - n_block_idx * BLOCK_N % BLOCK_K) / 8; + 
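+                    // NOTES (descriptive, inferred from the surrounding code): `num_former_iters` counts the
+                    // 8-column sub-tiles of this N block that lie before the next BLOCK_K-wide scale boundary
+                    // and are therefore promoted with the first B scale; the rest of the block uses the second.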
num_full_iters = min(SHAPE_N - n_block_idx * BLOCK_N, BLOCK_N) / 8; + } + uint32_t num_scales_b = SHAPE_K_SCALES * (num_former_iters >= num_full_iters ? 1 : 2); + + // Load B scales with math warp-groups + // NOTES: except the first warp, we want to overlap loading B scales with TMA stores between tasks + if (threadIdx.x >= 32) + { + auto num_previous_lines + = scheduler.get_global_scales_b_idx(ceil_div(SHAPE_N, BLOCK_K), 0, 0, m_block_idx); + ; + auto local_scales_b + = scales_b + (num_previous_lines + ((n_block_idx * BLOCK_N) / BLOCK_K)) * SHAPE_K_SCALES; +#pragma unroll + for (uint32_t i = threadIdx.x - 32; i < num_scales_b; i += kNumMathThreads - 32) + st_shared(smem_scales_b + i, __ldg(local_scales_b + i)); + } + cutlass::arch::NamedBarrier(kNumMathThreads).sync(); + + // Accumulation for WGMMA or CUDA promotion + float accum[WGMMA::kNumAccum], final_accum[WGMMA::kNumAccum] = {0}; + + // Empty barrier arrival + auto empty_barrier_arrive = [&](int s) + { + if constexpr (kNumTMAMulticast == 1) + { + lane_idx == 0 ? empty_barriers[s]->arrive() : void(); + } + else + { + lane_idx < kNumTMAMulticast ? empty_barriers[s]->arrive(lane_idx) : void(); + } + }; + + // Launch MMAs + launch_k_iterations( + [&](int k_iter, auto type) + { + constexpr bool kHasDivisibleStages = std::is_same_v; + constexpr int kNumInnerStages + = kHasDivisibleStages ? kNumStages : (SHAPE_K % kFullKOfAllStages) / BLOCK_K; + DG_STATIC_ASSERT(kNumInnerStages != 0, "Invalid number of inner stages"); + +#pragma unroll + for (int s = 0; s < kNumInnerStages; ++s) + { + // Read B scales + float scale_b_0 = ld_shared(smem_scales_b + k_iter * kNumStages + s), scale_b_1 = 1.0f; + // NOTES: even some blocks do not need to read the second row, but we still load one to align + // with other blocks + if constexpr (not kMustUseUniformedScaleB) + scale_b_1 = ld_shared(smem_scales_b + k_iter * kNumStages + s + SHAPE_K_SCALES); + + // Wait TMA arrivals + full_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter) & 1); + + // Read A scales + // NOTES: all shared memory read must be prior to `warpgroup_arrive` to avoid next scheduled + // block polluting the results + auto scale_a_0 = ld_shared(smem_scales_a[s] + r_0), + scale_a_1 = ld_shared(smem_scales_a[s] + r_1); + +// Commit WGMMA instructions +#pragma unroll + for (int i = 0; i < WGMMA::kNumAccum; ++i) + warpgroup_fence_operand(accum[i]); + warpgroup_arrive(); +#pragma unroll + for (int k = 0; k < BLOCK_K / WGMMA::K; ++k) + { + auto desc_a + = make_smem_desc(smem_a[s] + math_wg_idx * WGMMA::M * BLOCK_K + k * WGMMA::K, 1); + auto desc_b = make_smem_desc(smem_b[s] + k * WGMMA::K, 1); + WGMMA::wgmma(desc_a, desc_b, accum, k); + } + warpgroup_commit_batch(); +#pragma unroll + for (int i = 0; i < WGMMA::kNumAccum; ++i) + warpgroup_fence_operand(accum[i]); + warpgroup_wait<0>(); + + // Notify barrier arrival + empty_barrier_arrive(s); + + // Promote with scales + float scale_0_0 = scale_a_0 * scale_b_0, scale_1_0 = scale_a_1 * scale_b_0; + float scale_0_1, scale_1_1; + if constexpr (not kMustUseUniformedScaleB) + scale_0_1 = scale_a_0 * scale_b_1, scale_1_1 = scale_a_1 * scale_b_1; +#pragma unroll + for (int i = 0; i < WGMMA::kNumAccum / 4; ++i) + { + bool predicate = kMustUseUniformedScaleB or i < num_former_iters; + final_accum[i * 4 + 0] += (predicate ? scale_0_0 : scale_0_1) * accum[i * 4 + 0]; + final_accum[i * 4 + 1] += (predicate ? scale_0_0 : scale_0_1) * accum[i * 4 + 1]; + final_accum[i * 4 + 2] += (predicate ? 
scale_1_0 : scale_1_1) * accum[i * 4 + 2]; + final_accum[i * 4 + 3] += (predicate ? scale_1_0 : scale_1_1) * accum[i * 4 + 3]; + } + } + +// Wait unaligned cases +#pragma unroll + for (uint32_t s = kNumInnerStages; s < kNumStages; ++s) + { + full_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter) & 1); + empty_barrier_arrive(s); + } + }); + + // Write back to shared memory using STSM + DG_STATIC_ASSERT(WGMMA::kNumAccum % 4 == 0, "Invalid STSM x2 vectorization"); +#pragma unroll + for (auto i = 0; i < WGMMA::kNumAccum / 8; ++i) + { + SM90_U32x4_STSM_N::copy( + __float22bfloat162_rn({final_accum[i * 8 + 0], final_accum[i * 8 + 1]}), + __float22bfloat162_rn({final_accum[i * 8 + 2], final_accum[i * 8 + 3]}), + __float22bfloat162_rn({final_accum[i * 8 + 4], final_accum[i * 8 + 5]}), + __float22bfloat162_rn({final_accum[i * 8 + 6], final_accum[i * 8 + 7]}), + smem_d + (warp_idx * 16 + lane_idx % 16) * BLOCK_N + i * 16 + 8 * (lane_idx / 16)); + } + if constexpr (WGMMA::kNumAccum % 8 != 0) + { + SM90_U32x2_STSM_N::copy(__float22bfloat162_rn({final_accum[WGMMA::kNumAccum / 8 * 8 + 0], + final_accum[WGMMA::kNumAccum / 8 * 8 + 1]}), + __float22bfloat162_rn( + {final_accum[WGMMA::kNumAccum / 8 * 8 + 2], final_accum[WGMMA::kNumAccum / 8 * 8 + 3]}), + smem_d + (warp_idx * 16 + lane_idx % 16) * BLOCK_N + WGMMA::kNumAccum / 8 * 16); + } + + auto m_global_idx = scheduler.get_global_m_idx(m_block_idx); + bool cross_boundary = (m_global_idx + BLOCK_M) > scheduler.m_boundary; + cute::tma_store_fence(); + cutlass::arch::NamedBarrier(kNumMathThreads).sync(); + if (!cross_boundary) + { + // Use TMA store to write back to global memory + if (threadIdx.x == 0) + { + cute::SM90_TMA_STORE_2D::copy(&tensor_map_d, smem_d, n_block_idx * BLOCK_N, m_global_idx); + cute::tma_store_arrive(); + cute::tma_store_wait<0>(); + } + } + else + { + __nv_bfloat16* gmem_d_this_block = gmem_d + m_global_idx * SHAPE_N; + constexpr int NUM_WARPS + = (get_num_threads_per_sm(BLOCK_M) - 128) / 32; + write_result_to_gmem(gmem_d_this_block, smem_d, m_global_idx, + scheduler.m_boundary, n_block_idx * BLOCK_N, SHAPE_N, SHAPE_N); + } + __syncwarp(); + } + } +#else + if (blockIdx.x == 0 and threadIdx.x == 0) + DG_DEVICE_ASSERT(false and "This kernel only support sm_90a"); +#endif +} + +template +__global__ void __launch_bounds__(get_num_threads_per_sm(BLOCK_M), 1) + fp8_gemm_offset_kernel_swapAB(__nv_bfloat16* gmem_d, float* scales_a, int64_t* offsets, + const __grid_constant__ CUtensorMap tensor_map_a, // weight (previously act) + const __grid_constant__ CUtensorMap tensor_map_b, // act (previously weight) + const __grid_constant__ CUtensorMap tensor_map_scales_b, // act scales (previously tensor_map_scales_a) + const __grid_constant__ CUtensorMap tensor_map_d) +{ +#if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 900)) or defined(__CLION_IDE__) + // Scaling checks + DG_STATIC_ASSERT(BLOCK_K == 128, "Only support per-128-channel FP8 scaling"); + DG_STATIC_ASSERT(ceil_div(BLOCK_M, BLOCK_K) == 1, "Too much A scales in a single block"); + + InputType problem_input; + problem_input.problem_n_offsets = offsets; + + // Types + using WGMMA = typename FP8MMASelector::type; + using Barrier = cutlass::arch::ClusterTransactionBarrier; + + // Shared memory + DG_STATIC_ASSERT(BLOCK_K % BLOCK_M == 0, "BLOCK_M should be 64 or 128 and BLOCK_K should be 128"); + static constexpr uint32_t SMEM_D_SIZE = BLOCK_N * BLOCK_M * sizeof(__nv_bfloat16); + static constexpr uint32_t SMEM_A_SIZE_PER_STAGE = BLOCK_M * BLOCK_K * sizeof(__nv_fp8_e4m3); + 
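+    // NOTES: in this swapped variant the A tiles carry the weight (BLOCK_M x BLOCK_K) and the
+    // B tiles carry the activations (BLOCK_N x BLOCK_K); the activation (per-token) scales are
+    // staged per pipeline stage in `smem_scales_b`, while the weight scales are loaded once per
+    // scheduled block into `smem_scales_a`.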
static constexpr uint32_t SMEM_B_SIZE_PER_STAGE = BLOCK_N * BLOCK_K * sizeof(__nv_fp8_e4m3); + static constexpr uint32_t SMEM_SCALES_B_SIZE_PER_STAGE = BLOCK_N * sizeof(float); // B matrix (act) scales + static constexpr uint32_t SMEM_SCALES_B_SIZE_PER_STAGE_PADDED + = ceil_div(BLOCK_N * sizeof(float), 128) * 128; // B matrix (act) scales, 128B aligned + static constexpr uint32_t SHAPE_K_SCALES = ceil_div(SHAPE_K, BLOCK_K); + static constexpr uint32_t SMEM_SCALES_A_SIZE = ceil_div(SHAPE_K_SCALES * sizeof(float), sizeof(Barrier)) + * sizeof(Barrier); // renamed to A (weight) + + // Configs + constexpr uint32_t kFullKOfAllStages = kNumStages * BLOCK_K; + constexpr uint32_t kNumThreads = get_num_threads_per_sm(BLOCK_M); + constexpr uint32_t kNumMathThreads = kNumThreads - kNumTMAThreads; + constexpr uint32_t kNumIterations = ceil_div(SHAPE_K, kFullKOfAllStages); + const uint32_t warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); + const uint32_t lane_idx = get_lane_id(); + + // Prefetch TMA descriptors at very beginning + if (threadIdx.x == kNumMathThreads) + { + cute::prefetch_tma_descriptor(reinterpret_cast(&tensor_map_a)); + cute::prefetch_tma_descriptor(reinterpret_cast(&tensor_map_b)); + cute::prefetch_tma_descriptor(reinterpret_cast(&tensor_map_scales_b)); + cute::prefetch_tma_descriptor(reinterpret_cast(&tensor_map_d)); + } + __syncwarp(); + + // Align to 1024 bytes for swizzle-128B + extern __shared__ __align__(1024) uint8_t smem_buffer[]; + DG_STATIC_ASSERT(SMEM_D_SIZE % 1024 == 0, "Shared memory of A/B must be aligned to 1024 bytes"); + + // Data on shared memory + auto smem_d = reinterpret_cast<__nv_bfloat16*>(smem_buffer); + __nv_fp8_e4m3* smem_a[kNumStages]; // weight + __nv_fp8_e4m3* smem_b[kNumStages]; // act + float* smem_scales_b[kNumStages]; // act scales + float* smem_scales_a; // weight scales + + // TMA Barrier for both divisible and non-divisible cases + Barrier* full_barriers[kNumStages]; + Barrier* empty_barriers[kNumStages]; + +// Fill shared memory pointers +#pragma unroll + for (int i = 0; i < kNumStages; ++i) + { + smem_a[i] = reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + SMEM_D_SIZE + i * SMEM_A_SIZE_PER_STAGE); + smem_b[i] = reinterpret_cast<__nv_fp8_e4m3*>( + smem_buffer + SMEM_D_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE); + smem_scales_b[i] = reinterpret_cast(smem_buffer + SMEM_D_SIZE + + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE) + i * SMEM_SCALES_B_SIZE_PER_STAGE_PADDED); + } + smem_scales_a = reinterpret_cast(smem_buffer + SMEM_D_SIZE + + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE + SMEM_SCALES_B_SIZE_PER_STAGE_PADDED)); + + // Fill barriers + auto barrier_start_ptr = reinterpret_cast(reinterpret_cast(smem_scales_a) + SMEM_SCALES_A_SIZE); +#pragma unroll + for (int i = 0; i < kNumStages; ++i) + { + full_barriers[i] = barrier_start_ptr + i; + empty_barriers[i] = barrier_start_ptr + kNumStages + i; + } + + // Initialize barriers + DG_STATIC_ASSERT(kNumTMAMulticast <= 32, "Too many TMA multicast"); + if (threadIdx.x == kNumMathThreads) + { +#pragma unroll + for (int i = 0; i < kNumStages; ++i) + { + full_barriers[i]->init(1); + empty_barriers[i]->init(kNumTMAMulticast * kNumMathThreads / 32); + } + + // Make initialized barrier visible in async proxy + cutlass::arch::fence_view_async_shared(); + (kNumTMAMulticast > 1) ? cutlass::arch::fence_barrier_init() : void(); + } + + // Synchronize all threads to make barrier visible in normal memory model + (kNumTMAMulticast > 1) ? 
cute::cluster_sync() : __syncthreads(); + + // For pipeline unrolling + struct DivisibleK + { + }; + + struct NotDivisibleK + { + }; + + auto launch_k_iterations = [](auto const& func) + { + if constexpr (SHAPE_K % kFullKOfAllStages == 0) + { + for (int k_iter = 0; k_iter < kNumIterations; ++k_iter) + func(k_iter, DivisibleK{}); + } + else + { + for (int k_iter = 0; k_iter < kNumIterations - 1; ++k_iter) + func(k_iter, DivisibleK{}); + func(kNumIterations - 1, NotDivisibleK{}); + } + }; + + // Register reconfigurations + constexpr int kNumTMARegisters = 40; + constexpr int kNumMathRegisters = 232; + + // Block scheduler + uint32_t m_block_idx, n_block_idx; + auto scheduler = SchedulerType(problem_input); + + if (threadIdx.x >= kNumMathThreads) + { + // TMA warp-group for loading data + cutlass::arch::warpgroup_reg_dealloc(); + + // NOTES: only one thread (or warp) will be used + if (threadIdx.x == kNumMathThreads) + { + // Persistently schedule over blocks + while (scheduler.get_next_block(m_block_idx, n_block_idx)) + { + launch_k_iterations( + [&](int k_iter, auto type) + { + constexpr bool kHasDivisibleStages = std::is_same_v; + constexpr int kNumInnerStages + = kHasDivisibleStages ? kNumStages : (SHAPE_K % kFullKOfAllStages) / BLOCK_K; + DG_STATIC_ASSERT(kNumInnerStages != 0, "Invalid number of inner stages"); + +#pragma unroll + for (uint32_t s = 0; s < kNumInnerStages; ++s) + { + // Wait consumer release + empty_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter + 1) & 1); + + // Issue TMA A (weight) now without broadcasting + auto& full_barrier = *full_barriers[s]; + int k_idx = k_iter * kFullKOfAllStages + s * BLOCK_K; + tma_copy(&tensor_map_a, reinterpret_cast(&full_barrier), smem_a[s], k_idx, + scheduler.get_global_m_idx(SHAPE_M, BLOCK_M, m_block_idx, n_block_idx), 1); + + // Issue TMA B (act) with broadcasting + tma_copy(&tensor_map_b, reinterpret_cast(&full_barrier), + smem_b[s], k_idx, scheduler.get_global_n_idx(n_block_idx), kNumTMAMulticast); + + // Issue TMA scales_b (act scales) for B matrix + tma_copy(&tensor_map_scales_b, + reinterpret_cast(&full_barrier), smem_scales_b[s], + scheduler.get_global_scales_b_idx(n_block_idx), k_idx / BLOCK_K, kNumTMAMulticast); + + full_barrier.arrive_and_expect_tx( + SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE + SMEM_SCALES_B_SIZE_PER_STAGE); + } + +// Wait unaligned cases +#pragma unroll + for (uint32_t s = kNumInnerStages; s < kNumStages; ++s) + { + empty_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter + 1) & 1); + full_barriers[s]->arrive(); + } + }); + } + + // To safely deconstruct distributed shared barriers, we need another round of empty waits + if constexpr (kNumTMAMulticast > 1) + { +#pragma unroll + for (uint32_t s = 0; s < kNumStages; ++s) + empty_barriers[s]->wait((scheduler.current_iter * kNumIterations + 1) & 1); + } + } + } + else + { + // Math warp-groups for WGMMA + cutlass::arch::warpgroup_reg_alloc(); + + // NOTES: use `__shfl_sync` to encourage NVCC to use unified registers + auto const math_wg_idx = __shfl_sync(0xffffffff, threadIdx.x / kNumMathThreadsPerGroup, 0); + + // Each thread loads consecutive 2 scales + const uint32_t scale_offset = (lane_idx % 4) * 2; + + // Persistently schedule over blocks + while (scheduler.get_next_block(m_block_idx, n_block_idx)) + { + // Load weight scales (scales_a) - these are associated with tensor_map_a (weight) + // Decide the number of scales A to load + DG_STATIC_ASSERT(SHAPE_M % 8 == 0, "Invalid shape M"); + uint32_t num_scales_a = 
SHAPE_K_SCALES; + + // Load A scales with math warp-groups (weight scales) + if (threadIdx.x >= 32) + { + auto num_previous_lines + = scheduler.get_global_scales_a_idx(ceil_div(SHAPE_M, BLOCK_K), 0, 0, n_block_idx); + auto local_scales_a + = scales_a + (num_previous_lines + ((m_block_idx * BLOCK_M) / BLOCK_K)) * SHAPE_K_SCALES; +#pragma unroll + for (uint32_t i = threadIdx.x - 32; i < num_scales_a; i += kNumMathThreads - 32) + st_shared(smem_scales_a + i, __ldg(local_scales_a + i)); + } + cutlass::arch::NamedBarrier(kNumMathThreads).sync(); + + // Accumulation for WGMMA or CUDA promotion + float accum[WGMMA::kNumAccum], final_accum[WGMMA::kNumAccum] = {0}; + + // Empty barrier arrival + auto empty_barrier_arrive = [&](int s) + { + if constexpr (kNumTMAMulticast == 1) + { + lane_idx == 0 ? empty_barriers[s]->arrive() : void(); + } + else + { + lane_idx < kNumTMAMulticast ? empty_barriers[s]->arrive(lane_idx) : void(); + } + }; + + // Launch MMAs + launch_k_iterations( + [&](int k_iter, auto type) + { + constexpr bool kHasDivisibleStages = std::is_same_v; + constexpr int kNumInnerStages + = kHasDivisibleStages ? kNumStages : (SHAPE_K % kFullKOfAllStages) / BLOCK_K; + DG_STATIC_ASSERT(kNumInnerStages != 0, "Invalid number of inner stages"); + +#pragma unroll + for (int s = 0; s < kNumInnerStages; ++s) + { + // Read weight scales (A scales) + float scale_a_0 = ld_shared(smem_scales_a + k_iter * kNumStages + s); + + // Wait TMA arrivals + full_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter) & 1); + + // NOTES: all shared memory read must be prior to `warpgroup_arrive` to avoid next scheduled + // block polluting the results + // Each thread reads consecutive two b scales, each thread needs to read WGMMA::N / 4 * 2 b + // scales + float scale_0_0[WGMMA::kNumAccum / 4], scale_0_1[WGMMA::kNumAccum / 4]; +#pragma unroll + for (int i = 0; i < WGMMA::kNumAccum / 4; ++i) + { + float2 scale_b + = ld_shared(reinterpret_cast(smem_scales_b[s] + i * 8 + scale_offset)); + scale_0_0[i] = scale_a_0 * scale_b.x; + scale_0_1[i] = scale_a_0 * scale_b.y; + } + +// Commit WGMMA instructions +#pragma unroll + for (int i = 0; i < WGMMA::kNumAccum; ++i) + warpgroup_fence_operand(accum[i]); + warpgroup_arrive(); +#pragma unroll + for (int k = 0; k < BLOCK_K / WGMMA::K; ++k) + { + auto desc_a + = make_smem_desc(smem_a[s] + math_wg_idx * WGMMA::M * BLOCK_K + k * WGMMA::K, 1); + auto desc_b = make_smem_desc(smem_b[s] + k * WGMMA::K, 1); + WGMMA::wgmma(desc_a, desc_b, accum, k); + } + warpgroup_commit_batch(); +#pragma unroll + for (int i = 0; i < WGMMA::kNumAccum; ++i) + warpgroup_fence_operand(accum[i]); + warpgroup_wait<0>(); + + // Notify barrier arrival + empty_barrier_arrive(s); + +// Promote with scales +#pragma unroll + for (int i = 0; i < WGMMA::kNumAccum / 4; ++i) + { + final_accum[i * 4 + 0] += scale_0_0[i] * accum[i * 4 + 0]; + final_accum[i * 4 + 1] += scale_0_1[i] * accum[i * 4 + 1]; + final_accum[i * 4 + 2] += scale_0_0[i] * accum[i * 4 + 2]; + final_accum[i * 4 + 3] += scale_0_1[i] * accum[i * 4 + 3]; + } + } + +// Wait unaligned cases +#pragma unroll + for (uint32_t s = kNumInnerStages; s < kNumStages; ++s) + { + full_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter) & 1); + empty_barrier_arrive(s); + } + }); + + // Write back to shared memory using STSM + DG_STATIC_ASSERT(WGMMA::kNumAccum % 4 == 0, "Invalid STSM x2 vectorization"); + int tid = 0; + if (lane_idx < 8) + { + tid = lane_idx * BLOCK_M; + } + else if (lane_idx < 16) + { + tid = (lane_idx - 8) * BLOCK_M + 
8; + } + else if (lane_idx < 24) + { + tid = (lane_idx - 8) * BLOCK_M; + } + else + { + tid = (lane_idx - 16) * BLOCK_M + 8; + } +#pragma unroll + for (auto i = 0; i < WGMMA::kNumAccum / 8; ++i) + { + SM90_U32x4_STSM_T::copy( + __float22bfloat162_rn({final_accum[i * 8 + 0], final_accum[i * 8 + 1]}), + __float22bfloat162_rn({final_accum[i * 8 + 2], final_accum[i * 8 + 3]}), + __float22bfloat162_rn({final_accum[i * 8 + 4], final_accum[i * 8 + 5]}), + __float22bfloat162_rn({final_accum[i * 8 + 6], final_accum[i * 8 + 7]}), + smem_d + warp_idx * 16 + i * 16 * BLOCK_M + tid); + } + if constexpr (WGMMA::kNumAccum % 8 != 0) + { + SM90_U32x2_STSM_T::copy(__float22bfloat162_rn({final_accum[WGMMA::kNumAccum / 8 * 8 + 0], + final_accum[WGMMA::kNumAccum / 8 * 8 + 1]}), + __float22bfloat162_rn( + {final_accum[WGMMA::kNumAccum / 8 * 8 + 2], final_accum[WGMMA::kNumAccum / 8 * 8 + 3]}), + smem_d + warp_idx * 16 + WGMMA::kNumAccum / 8 * 16 * BLOCK_M + tid); + } + + auto n_global_idx = scheduler.get_global_n_idx(n_block_idx); + bool cross_boundary = (n_global_idx + BLOCK_N) > scheduler.n_boundary; + cute::tma_store_fence(); + cutlass::arch::NamedBarrier(kNumMathThreads).sync(); + if (!cross_boundary) + { + // Use TMA store to write back to global memory + if (threadIdx.x == 0) + { + cute::SM90_TMA_STORE_2D::copy(&tensor_map_d, smem_d, m_block_idx * BLOCK_M, n_global_idx); + cute::tma_store_arrive(); + cute::tma_store_wait<0>(); + } + } + else + { + __nv_bfloat16* gmem_d_this_block = gmem_d + n_global_idx * SHAPE_M; + constexpr int NUM_WARPS + = (get_num_threads_per_sm(BLOCK_M) - 128) / 32; + write_result_to_gmem(gmem_d_this_block, smem_d, n_global_idx, + scheduler.n_boundary, m_block_idx * BLOCK_M, SHAPE_M, SHAPE_M); + } + __syncwarp(); + } + } +#else + if (blockIdx.x == 0 and threadIdx.x == 0) + DG_DEVICE_ASSERT(false and "This kernel only support sm_90a"); +#endif +} }; // namespace deep_gemm -#pragma clang diagnostic pop \ No newline at end of file +#pragma clang diagnostic pop diff --git a/deep_gemm/include/deep_gemm/mma_utils.cuh b/deep_gemm/include/deep_gemm/mma_utils.cuh index 85b2ccc0..4fc7f4fa 100644 --- a/deep_gemm/include/deep_gemm/mma_utils.cuh +++ b/deep_gemm/include/deep_gemm/mma_utils.cuh @@ -32,6 +32,30 @@ struct SM90_U32x4_STSM_N { } }; +template +struct SM90_U32x2_STSM_T +{ + __device__ __forceinline__ static void copy(dtype_t src_0, dtype_t src_1, void* smem_dst) + { + const uint32_t src[2] = {*reinterpret_cast(&src_0), *reinterpret_cast(&src_1)}; + asm volatile("stmatrix.sync.aligned.x2.m8n8.shared.b16.trans [%0], {%1, %2};\n" ::"l"(smem_dst), "r"(src[0]), + "r"(src[1])); + } +}; + +template +struct SM90_U32x4_STSM_T +{ + __device__ __forceinline__ static void copy( + dtype_t src_0, dtype_t src_1, dtype_t src_2, dtype_t src_3, void* smem_dst) + { + const uint32_t src[4] = {*reinterpret_cast(&src_0), *reinterpret_cast(&src_1), + *reinterpret_cast(&src_2), *reinterpret_cast(&src_3)}; + asm volatile("stmatrix.sync.aligned.x4.m8n8.shared.b16.trans [%0], {%1, %2, %3, %4};\n" ::"l"(smem_dst), + "r"(src[0]), "r"(src[1]), "r"(src[2]), "r"(src[3])); + } +}; + __forceinline__ __device__ void warpgroup_arrive() { asm volatile("wgmma.fence.sync.aligned;\n" ::: "memory"); } diff --git a/deep_gemm/include/deep_gemm/scheduler.cuh b/deep_gemm/include/deep_gemm/scheduler.cuh index 69ea2160..dacf5f1a 100644 --- a/deep_gemm/include/deep_gemm/scheduler.cuh +++ b/deep_gemm/include/deep_gemm/scheduler.cuh @@ -7,7 +7,8 @@ namespace deep_gemm { enum class GemmType { Normal, GroupedContiguous, - GroupedMasked + 
GroupedMasked, + GroupedWithOffset }; #pragma clang diagnostic push @@ -158,6 +159,266 @@ struct Scheduler { } }; + +template +__device__ __forceinline__ void offset_get_swizzled_block_idx( + const uint32_t num_m_blocks, int block_idx, uint32_t& m_block_idx, uint32_t& n_block_idx) +{ + DG_STATIC_ASSERT(kNumNBlocksPerGroup % kNumTMAMulticast == 0, "Invalid group size"); + + // Swizzle for better L2 usages + auto num_blocks_per_group = num_m_blocks * kNumNBlocksPerGroup; + auto group_idx = block_idx / num_blocks_per_group; + auto first_n_block_idx = group_idx * kNumNBlocksPerGroup; + auto num_n_blocks_in_group = min(kNumNBlocksPerGroup, kNumNBlocks - first_n_block_idx); + auto in_group_idx = block_idx % num_blocks_per_group; + m_block_idx = in_group_idx / num_n_blocks_in_group; + n_block_idx = first_n_block_idx + in_group_idx % num_n_blocks_in_group; +} + + + +struct GroupedWithOffsetSchedulerInput +{ + uint32_t shape_m; + int64_t* problem_m_offsets; +}; + +struct GroupedWithOffsetSchedulerInputSwapAB +{ + uint32_t shape_m; + int64_t* problem_n_offsets; +}; + +struct StridedBatchedSchedulerInput +{ + uint32_t shape_m; + uint64_t ld_a; + uint64_t stride_a; + uint64_t ld_b; + uint64_t stride_b; + uint64_t ld_d; + uint64_t stride_d; +}; + +struct StridedBatchedSchedulerInputSwapAB +{ + uint32_t shape_n; + uint64_t ld_a; + uint64_t stride_a; + uint64_t ld_b; + uint64_t stride_b; + uint64_t ld_d; + uint64_t stride_d; +}; + + +// Need to keep the same as the one in tests/unittest/_torch/thop/deep_gemm_tests.py +template +__host__ __device__ __forceinline__ T_offset compute_padded_offset(T_offset offset, T_index problem_idx) +{ + // This formulation ensures that padded_offset[i + 1] - padded_offset[i] >= offset[i + 1] - offset[i]. + constexpr T_offset alignment = 32; + return (offset + problem_idx * (alignment - 1)) / alignment * alignment; +} + +template +struct GroupedWithOffsetScheduler +{ + static constexpr GemmType gemm_type = GemmType::GroupedWithOffset; + + int current_iter = -1; + uint32_t curr_group_idx; + uint32_t curr_cumsum; + int64_t m_offset; + int64_t m_padded_4_offset; + int64_t m_boundary; + int64_t* problem_m_offsets; + + using Input = GroupedWithOffsetSchedulerInput; + Input input; + + GroupedWithOffsetScheduler() {} + + __device__ __forceinline__ GroupedWithOffsetScheduler(Input& input) + { + this->problem_m_offsets = input.problem_m_offsets; + curr_group_idx = 0; + curr_cumsum = 0; + } + + __device__ __forceinline__ uint32_t get_global_m_idx(uint32_t const& block_idx) + { + return m_offset + block_idx * BLOCK_M; + } + + __device__ __forceinline__ uint32_t get_global_n_idx( + uint32_t const shape_dim, uint32_t const block_size, uint32_t const& block_idx, uint32_t const& m_block_idx = 0) + { + return curr_group_idx * shape_dim + block_idx * block_size; + } + + __device__ __forceinline__ uint32_t get_global_scales_a_idx(uint32_t const& block_idx) + { + return m_padded_4_offset + block_idx * BLOCK_M; + } + + __device__ __forceinline__ uint32_t get_global_scales_b_idx( + uint32_t const shape_dim, uint32_t const block_size, uint32_t const& block_idx, uint32_t const& m_block_idx = 0) + { + return curr_group_idx * shape_dim + block_idx * block_size; + } + + __device__ __forceinline__ bool get_next_block(uint32_t& m_block_idx, uint32_t& n_block_idx) + { + ++current_iter; + auto const next_block_idx = current_iter * gridDim.x + blockIdx.x; + uint32_t num_m_blocks; + while (true) + { + // End of the task + if (curr_group_idx == kNumGroups) + return false; + m_offset = 
__ldg(problem_m_offsets + curr_group_idx); + m_boundary = __ldg(problem_m_offsets + curr_group_idx + 1); + m_padded_4_offset = compute_padded_offset(m_offset, curr_group_idx); + auto m = m_boundary - m_offset; + // Within current group + num_m_blocks = ceil_div(m, static_cast(BLOCK_M)); + auto current_m_block_cumsum = curr_cumsum + num_m_blocks; + if (next_block_idx < current_m_block_cumsum * kNumNBlocks) + break; + + // Move to check the next group + curr_group_idx++; + curr_cumsum = current_m_block_cumsum; + } + + offset_get_swizzled_block_idx( + num_m_blocks, next_block_idx - curr_cumsum * kNumNBlocks, m_block_idx, n_block_idx); + return true; + } +}; + +template +struct GroupedWithOffsetSchedulerSwapAB +{ + static constexpr GemmType gemm_type = GemmType::GroupedWithOffset; + + int current_iter = -1; + uint32_t curr_group_idx; + uint32_t curr_cumsum; + int64_t n_offset; + int64_t n_padded_4_offset; + int64_t n_boundary; + int64_t* problem_n_offsets; + + using Input = GroupedWithOffsetSchedulerInputSwapAB; + Input input; + + GroupedWithOffsetSchedulerSwapAB() {} + + __device__ __forceinline__ GroupedWithOffsetSchedulerSwapAB(Input& input) + { + this->problem_n_offsets = input.problem_n_offsets; + curr_group_idx = 0; + curr_cumsum = 0; + } + + // weight + __device__ __forceinline__ uint32_t get_global_m_idx( + const uint32_t shape_dim, const uint32_t block_size, uint32_t const& block_idx, uint32_t const& n_block_idx = 0) + { + return curr_group_idx * shape_dim + block_idx * block_size; + } + + // act + __device__ __forceinline__ uint32_t get_global_n_idx(uint32_t const& block_idx) + { + return n_offset + block_idx * BLOCK_N; + } + + // act scales + __device__ __forceinline__ uint32_t get_global_scales_b_idx(uint32_t const& block_idx) + { + return n_padded_4_offset + block_idx * BLOCK_N; + } + + // weight scales + __device__ __forceinline__ uint32_t get_global_scales_a_idx( + const uint32_t shape_dim, const uint32_t block_size, uint32_t const& block_idx, uint32_t const& n_block_idx = 0) + { + return curr_group_idx * shape_dim + block_idx * block_size; + } + + __device__ __forceinline__ bool get_next_block(uint32_t& m_block_idx, uint32_t& n_block_idx) + { + ++current_iter; + auto const next_block_idx = current_iter * gridDim.x + blockIdx.x; + uint32_t num_n_blocks; + while (true) + { + // End of the task + if (curr_group_idx == kNumGroups) + return false; + n_offset = __ldg(problem_n_offsets + curr_group_idx); + n_boundary = __ldg(problem_n_offsets + curr_group_idx + 1); + n_padded_4_offset = compute_padded_offset(n_offset, curr_group_idx); + auto n = n_boundary - n_offset; + // Within current group + num_n_blocks = ceil_div(n, static_cast(BLOCK_N)); + auto current_n_block_cumsum = curr_cumsum + num_n_blocks; + if (next_block_idx < current_n_block_cumsum * kNumMBlocks) + break; + + // Move to check the next group + curr_group_idx++; + curr_cumsum = current_n_block_cumsum; + } + + offset_get_swizzled_block_idx( + num_n_blocks, next_block_idx - curr_cumsum * kNumMBlocks, n_block_idx, m_block_idx); + return true; + } +}; + +template +struct SchedulerSelector +{ + static constexpr auto select_type() + { + if constexpr (GT == GemmType::GroupedWithOffset) + return GroupedWithOffsetScheduler(); + } + + using type = decltype(select_type()); +}; + +template +struct SchedulerSelectorSwapAB +{ + static constexpr auto select_type() + { + static_assert(GT == GemmType::GroupedWithOffset || GT == GemmType::Normal, + "Only GroupedWithOffset and Normal are supported for SwapAB"); + if constexpr (GT == 
GemmType::Normal) + return NormalSchedulerSwapAB(); + if constexpr (GT == GemmType::GroupedWithOffset) + return GroupedWithOffsetSchedulerSwapAB(); + } + + using type = decltype(select_type()); +}; + #pragma clang diagnostic pop } // namespace deep_gemm diff --git a/deep_gemm/jit_kernels/__init__.py b/deep_gemm/jit_kernels/__init__.py index f1fa7bb2..839a3a19 100644 --- a/deep_gemm/jit_kernels/__init__.py +++ b/deep_gemm/jit_kernels/__init__.py @@ -1,7 +1,8 @@ from .gemm import gemm_fp8_fp8_bf16_nt from .m_grouped_gemm import ( m_grouped_gemm_fp8_fp8_bf16_nt_contiguous, - m_grouped_gemm_fp8_fp8_bf16_nt_masked + m_grouped_gemm_fp8_fp8_bf16_nt_masked, + m_grouped_gemm_fp8_fp8_bf16_nt_offset ) from .wgrad_gemm import ( wgrad_gemm_fp8_fp8_fp32_nt, diff --git a/deep_gemm/jit_kernels/gemm.py b/deep_gemm/jit_kernels/gemm.py index 574f821f..64bcc76a 100644 --- a/deep_gemm/jit_kernels/gemm.py +++ b/deep_gemm/jit_kernels/gemm.py @@ -34,49 +34,75 @@ def get_block_n_padding_for_smem_d(block_n: int) -> int: def get_smem_config(num_stages: int, k: int, block_m: int, block_n: int, block_k: int = 128, - is_fp32_out: bool = False, is_wgrad: bool = False) -> Tuple[int, int, int]: + is_fp32_out: bool = False, is_wgrad: bool = False, is_swap_ab: bool = False) -> Tuple[int, int, int]: assert block_k == 128 - # Try swizzle first, as it does not waste shared memory - swizzle_mode = get_swizzle_mode(block_n) - block_n_padding = get_block_n_padding_for_smem_d( - block_n) if swizzle_mode == 0 else 0 - - # NOTES: `scales_b` in a total manner or per-stage manner - smem_d = block_m * (block_n + block_n_padding) * (4 if is_fp32_out else 2) - smem_a_per_stage = block_m * block_k - smem_scales_a_per_stage = block_m * 4 - smem_b_per_stage = block_n * block_k - smem_scales_b_per_stage = ceil_div(block_n * 4, block_k) * block_k if is_wgrad else 0 - smem_scales_b = ceil_div(k, block_k) * 4 if not is_wgrad else 0 - smem_barrier = num_stages * 8 * 2 - - smem_size = 0 - smem_size += smem_d - smem_size += num_stages * smem_a_per_stage - smem_size += num_stages * smem_scales_a_per_stage - smem_size += num_stages * smem_b_per_stage - smem_size += num_stages * smem_scales_b_per_stage - smem_size += ceil_div(smem_scales_b * (1 if block_k % block_n == 0 else 2), 8) * 8 - smem_size += smem_barrier - - # Swizzle and padding are not compatible - assert int(swizzle_mode > 0) + int(block_n_padding > 0) <= 1 - - return smem_size, swizzle_mode, block_n_padding + if not is_swap_ab: + # Try swizzle first, as it does not waste shared memory + swizzle_mode = get_swizzle_mode(block_n) + block_n_padding = get_block_n_padding_for_smem_d( + block_n) if swizzle_mode == 0 else 0 + + # NOTES: `scales_b` in a total manner or per-stage manner + smem_d = block_m * (block_n + block_n_padding) * (4 if is_fp32_out else 2) + smem_a_per_stage = block_m * block_k + smem_scales_a_per_stage = block_m * 4 + smem_b_per_stage = block_n * block_k + smem_scales_b_per_stage = ceil_div(block_n * 4, block_k) * block_k if is_wgrad else 0 + smem_scales_b = ceil_div(k, block_k) * 4 if not is_wgrad else 0 + smem_barrier = num_stages * 8 * 2 + + smem_size = 0 + smem_size += smem_d + smem_size += num_stages * smem_a_per_stage + smem_size += num_stages * smem_scales_a_per_stage + smem_size += num_stages * smem_b_per_stage + smem_size += num_stages * smem_scales_b_per_stage + smem_size += ceil_div(smem_scales_b * (1 if block_k % block_n == 0 else 2), 8) * 8 + smem_size += smem_barrier + + # Swizzle and padding are not compatible + assert int(swizzle_mode > 0) + 
int(block_n_padding > 0) <= 1 + + return smem_size, swizzle_mode, block_n_padding + else: + # NOTES: `scales_b` in a total manner or per-stage manner + smem_d = block_m * block_n * (4 if is_fp32_out else 2) + smem_a_per_stage = block_m * block_k + smem_scales_a_per_stage = ceil_div(k, block_k) * 4; # weight scales + smem_b_per_stage = block_n * block_k + smem_scales_b_per_stage = 0 # swap_ab not support wgrad + smem_scales_b = ceil_div(block_n * 4, 128) * 128 # swap_ab not support wgrad + smem_barrier = num_stages * 8 * 2 + + smem_size = 0 + smem_size += smem_d + smem_size += num_stages * smem_a_per_stage + smem_size += num_stages * smem_scales_b + smem_size += num_stages * smem_b_per_stage + smem_size += num_stages * smem_scales_b_per_stage + smem_size += ceil_div(smem_scales_a_per_stage, 8) * 8 + smem_size += smem_barrier + + # no swizzle, no block_n_padding + swizzle_mode = 0 + block_n_padding = 0 + + return smem_size, swizzle_mode, block_n_padding @lru_cache(maxsize=None) def get_best_configs(m: int, n: int, k: int, num_groups: int, num_sms: int, is_grouped_contiguous: bool = False, is_grouped_masked: bool = False, - is_fp32_out: bool = False, is_wgrad: bool = False) -> \ + is_fp32_out: bool = False, is_wgrad: bool = False, is_swap_ab: bool = False) -> \ Tuple[int, int, int, int, Tuple[int, bool], Tuple[int, int, int]]: if not is_grouped_contiguous: block_ms = (64, 128, ) + ((256, ) if not is_fp32_out else ()) else: block_ms = (get_m_alignment_for_contiguous_layout(), ) - block_ns = tuple(range(16, 129, 8)) + ((136, 152, ) if is_wgrad else (144, 160, )) - + #block_ns = tuple(range(16, 129, 8)) + ((136, 152, ) if is_wgrad else (144, 160, )) + block_ns = tuple(range(16, 129, 8)) + # Avoid bank conflicts for FP32 output if is_fp32_out: block_ns = [x for x in block_ns if x % 16 == 8] @@ -119,7 +145,7 @@ def get_best_configs(m: int, n: int, k: int, num_groups: int, num_sms: int, # Unrolling both stages and `num_former_iters` will cause large code size stage_candidates = tuple(filter(lambda s: s <= max(k // 128, 1), (4, 3, 2, 1))) for num_stages in stage_candidates: - best_smem_config = get_smem_config(num_stages, k, best_block_m, best_block_n, is_fp32_out=is_fp32_out, is_wgrad=is_wgrad) + best_smem_config = get_smem_config(num_stages, k, best_block_m, best_block_n, is_fp32_out=is_fp32_out, is_wgrad=is_wgrad, is_swap_ab = is_swap_ab) if best_smem_config[0] <= sm90_capacity: best_num_stages = num_stages break @@ -131,21 +157,39 @@ def get_best_configs(m: int, n: int, k: int, num_groups: int, num_sms: int, # Try to multicast on the larger block side first # NOTES: currently, grouped masked GEMM only supports multicast on A and requires the number of blocks in the N-direction to be even - is_multicast_legal = { - 'A': is_tma_multicast_legal(n, best_block_n, 2, num_sms, is_grouped_masked), - 'B': is_tma_multicast_legal(m, best_block_m, 2, num_sms) and not is_grouped_masked, - } - for i in ('A', 'B') if best_block_m > best_block_n else ('B', 'A'): - if m >= 512 and is_multicast_legal[i]: - best_tma_multicast_config = (2, i == 'A') - break - # Recompute the minimal number of SMs required - # NOTES: less L2 cache usage and less GPU frequency drop - num_waves = get_num_waves(best_block_m, best_block_n) - num_min_sms = ceil_div(ceil_div(m, best_block_m) * ceil_div(n, best_block_n) * num_groups, num_waves) - num_min_sms = ceil_div(num_min_sms, best_tma_multicast_config[0]) * best_tma_multicast_config[0] - assert num_min_sms <= num_sms + if not is_swap_ab: + is_multicast_legal = { + 'A': 
is_tma_multicast_legal(n, best_block_n, 2, num_sms, is_grouped_masked), + 'B': is_tma_multicast_legal(m, best_block_m, 2, num_sms) and not is_grouped_masked, + } + for i in ('A', 'B') if best_block_m > best_block_n else ('B', 'A'): + if m >= 512 and is_multicast_legal[i]: + best_tma_multicast_config = (2, i == 'A') + break + + # Recompute the minimal number of SMs required + # NOTES: less L2 cache usage and less GPU frequency drop + num_waves = get_num_waves(best_block_m, best_block_n) + num_min_sms = ceil_div(ceil_div(m, best_block_m) * ceil_div(n, best_block_n) * num_groups, num_waves) + num_min_sms = ceil_div(num_min_sms, best_tma_multicast_config[0]) * best_tma_multicast_config[0] + assert num_min_sms <= num_sms + else: + is_multicast_legal = { + 'A': is_tma_multicast_legal(n, best_block_m, 2, num_sms), + 'B': is_tma_multicast_legal(m, best_block_n, 2, num_sms), + } + for i in ('A', 'B') if best_block_m > best_block_n else ('B', 'A'): + if n >= 512 and is_multicast_legal[i]: + best_tma_multicast_config = (2, i == 'B') + break + + # Recompute the minimal number of SMs required + # NOTES: less L2 cache usage and less GPU frequency drop + num_waves = get_num_waves(best_block_n, best_block_m) + num_min_sms = ceil_div(ceil_div(n, best_block_m) * ceil_div(m, best_block_n) * num_groups, num_waves) + num_min_sms = ceil_div(num_min_sms, best_tma_multicast_config[0]) * best_tma_multicast_config[0] + assert num_min_sms <= num_sms return num_min_sms, best_block_m, best_block_n, best_num_stages, best_tma_multicast_config, best_smem_config diff --git a/deep_gemm/jit_kernels/m_grouped_gemm.py b/deep_gemm/jit_kernels/m_grouped_gemm.py index ca2fc79a..94522db3 100644 --- a/deep_gemm/jit_kernels/m_grouped_gemm.py +++ b/deep_gemm/jit_kernels/m_grouped_gemm.py @@ -4,10 +4,12 @@ from ..jit import build from .gemm import get_best_configs from .runtime import ( - FP8GemmRuntime, GemmType, + FP8GemmRuntime, FP8GemmOffsetRuntime, GemmType, make_2d_tma_a_desc, make_2d_tma_b_desc, - make_2d_tma_d_desc, make_2d_tma_scales_desc) -from .utils import ceil_div, get_col_major_tma_aligned_tensor, get_num_sms + make_2d_tma_d_desc, make_2d_tma_scales_desc, + make_2d_tma_scales_a_offset_desc, + make_2d_tma_a_offset_desc_swapAB, make_2d_tma_b_offset_desc_swapAB, make_2d_tma_d_offset_desc_swapAB, make_2d_tma_scales_b_offset_desc_swapAB) +from .utils import ceil_div, get_col_major_tma_aligned_tensor, get_num_sms, compute_padded_offset def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs: Tuple[torch.Tensor, torch.Tensor], @@ -203,3 +205,148 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs: Tuple[torch.Tensor, torch.Tensor] code = FP8GemmRuntime.generate(kwargs) runtime = build('m_grouped_gemm_fp8_fp8_bf16_nt', code, FP8GemmRuntime, kwargs) runtime(**kwargs) + + +def m_grouped_gemm_fp8_fp8_bf16_nt_offset(lhs: Tuple[torch.Tensor, torch.Tensor], + rhs: Tuple[torch.Tensor, torch.Tensor], + offsets: torch.Tensor, + out: torch.Tensor, expected_m: int) -> None: + """ + GroupedWithOffset from TensorRT-LLM + """ + + lhs, lhs_scales = lhs + rhs, rhs_scales = rhs + m, k = lhs.shape + num_groups, n, k_ = rhs.shape + m_, n_ = out.shape + + + # Type and shape checks + assert m == m_ and n == n_ and k == k_ + + max_shape_m_4_align = ceil_div(m, 4) * 4 # align 4 + max_shape_m_32_align_padded = compute_padded_offset(max_shape_m_4_align, num_groups) + + assert expected_m > 0 and max_shape_m_4_align > 0 and n > 0 and k > 0 and num_groups > 0 + + + # if compute_padded_offset ? 
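+    # Illustrative note on the expected `lhs_scales` layout: the scales are assumed to be
+    # pre-padded as in `change_to_offset_layout` (tests/test_core.py), i.e. group i's scale
+    # rows start at compute_padded_offset(offsets[i], i) = (offsets[i] + i * 31) // 32 * 32,
+    # which guarantees padded[i + 1] - padded[i] >= offsets[i + 1] - offsets[i].
+    # Example with offsets [0, 40, 100]: the padded starts are [0, 64, 160].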
+ #assert lhs_scales.shape == (num_groups, m, ceil_div(k, 128)) + assert rhs_scales.shape == (num_groups, ceil_div(n, 128), ceil_div(k, 128)) + assert lhs.dtype == torch.float8_e4m3fn and lhs_scales.dtype == torch.float32 + assert rhs.dtype == torch.float8_e4m3fn and rhs_scales.dtype == torch.float32 + assert offsets.dtype == torch.int64 + assert out.dtype == torch.bfloat16 + assert lhs.is_contiguous() and rhs.is_contiguous() + assert out.is_contiguous() + + # LHS scales must be transposed for TMA load, but not for RHS scales + + #lhs_scales = get_col_major_tma_aligned_tensor(lhs_scales) + assert rhs_scales.is_contiguous() + + # Auto-tuning with compilation + num_sms = get_num_sms() + + if num_sms==78: + m_per_expert_threshold = 64 # H20 + else: + m_per_expert_threshold = 32 # H100 + + if expected_m> m_per_expert_threshold: + + num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_config = get_best_configs( + expected_m, n, k, num_groups, num_sms, is_grouped_contiguous = True, is_swap_ab=False) + + # Extra checks for TMA store + if num_groups > 1 and m > block_m: + assert m % block_m == 0, f'For GroupedWithOffset grouped GEMM, shape M should be multiple of the block M (current block M: {block_m})' + + block_k = 128 + num_tma_threads = 128 + num_math_threads_per_group = 128 + + tensor_map_a = make_2d_tma_a_desc(GemmType.GroupedWithOffset, lhs, max_shape_m_4_align, k, k, block_m, block_k, num_groups) + tensor_map_b = make_2d_tma_b_desc(GemmType.GroupedWithOffset, rhs, n, k, k, block_n, block_k, num_groups) + tensor_map_d = make_2d_tma_d_desc(GemmType.GroupedWithOffset, out, max_shape_m_4_align, n, n, block_m, block_n, num_groups, 0) # none swizzle + tensor_map_scales_a = make_2d_tma_scales_a_offset_desc(GemmType.GroupedWithOffset, lhs_scales, max_shape_m_32_align_padded, k, block_m, block_k) # none swizzle + + + kwargs = { + # Templated arguments + 'KERNEL_NAME': 'fp8_gemm_offset_kernel', + 'SCHEDULER_TYPE': 'SchedulerSelector', + 'INPUT_TYPE': 'GroupedWithOffsetSchedulerInput', + 'PROBLEM_OFFSETS': offsets, + 'NUM_TMA_THREADS': num_tma_threads, + 'NUM_MATH_THREADS_PER_GROUP': num_math_threads_per_group, + 'M': max_shape_m_4_align, 'N': n, 'K': k, + 'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, + 'NUM_GROUPS': num_groups, + 'NUM_STAGES': num_stages, + 'NUM_TMA_MULTICAST': tma_multicast_config[0], + 'IS_TMA_MULTICAST_ON_A': tma_multicast_config[1], + 'GEMM_TYPE': GemmType.GroupedWithOffset, + # Runtime arguments + 'SCALES': rhs_scales, + 'NUM_SMS': num_sms, + 'SMEM_SIZE': smem_config[0], + 'TENSOR_MAP_A': tensor_map_a, + 'TENSOR_MAP_B': tensor_map_b, + 'TENSOR_MAP_SCALES': tensor_map_scales_a, + 'TENSOR_MAP_D': tensor_map_d, + 'STREAM': torch.cuda.current_stream().cuda_stream, + 'DEVICE_INDEX': out.device.index, + 'OUT': out + } + + else: + num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_config = get_best_configs( + n, expected_m, k, num_groups, num_sms, is_grouped_contiguous = True, is_swap_ab=True) + # Extra checks for TMA store + if num_groups > 1 and n > block_m: + assert n % block_m == 0, f'For GroupedWithOffset grouped GEMM, shape M should be multiple of the block M (current block M: {block_m})' + + block_k = 128 + num_tma_threads = 128 + num_math_threads_per_group = 128 + + tensor_map_a = make_2d_tma_a_offset_desc_swapAB(GemmType.GroupedWithOffset, rhs, n, k, k, block_m, block_k, num_groups) + tensor_map_b = make_2d_tma_b_offset_desc_swapAB(GemmType.GroupedWithOffset, lhs, max_shape_m_4_align, k, k, block_n, block_k, num_groups) + 
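+        # In the swapped path the weight tensor (rhs) backs tensor_map_a while the activation
+        # tensor (lhs) backs tensor_map_b; the output and activation-scale descriptors below
+        # are built with the M/N roles exchanged in the same way.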
tensor_map_d = make_2d_tma_d_offset_desc_swapAB(GemmType.GroupedWithOffset, out, max_shape_m_4_align, n, n, block_m, block_n, num_groups, 0) # no swizzle + tensor_map_scales_b = make_2d_tma_scales_b_offset_desc_swapAB(GemmType.GroupedWithOffset, lhs_scales, max_shape_m_32_align_padded, k, block_n, block_k) # no swizzle + + kwargs = { + # Templated arguments + 'KERNEL_NAME': 'fp8_gemm_offset_kernel_swapAB', + 'SCHEDULER_TYPE': 'SchedulerSelectorSwapAB', + 'INPUT_TYPE': 'GroupedWithOffsetSchedulerInputSwapAB', + 'PROBLEM_OFFSETS': offsets, + 'NUM_TMA_THREADS': num_tma_threads, + 'NUM_MATH_THREADS_PER_GROUP': num_math_threads_per_group, + 'M': max_shape_m_4_align, 'N': n, 'K': k, + 'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, + 'NUM_GROUPS': num_groups, + 'NUM_STAGES': num_stages, + 'NUM_TMA_MULTICAST': tma_multicast_config[0], + 'IS_TMA_MULTICAST_ON_A': tma_multicast_config[1], + 'GEMM_TYPE': GemmType.GroupedWithOffset, + # Runtime arguments + 'SCALES': rhs_scales, + 'NUM_SMS': num_sms, + 'SMEM_SIZE': smem_config[0], + 'TENSOR_MAP_A': tensor_map_a, + 'TENSOR_MAP_B': tensor_map_b, + 'TENSOR_MAP_SCALES': tensor_map_scales_b, + 'TENSOR_MAP_D': tensor_map_d, + 'STREAM': torch.cuda.current_stream().cuda_stream, + 'DEVICE_INDEX': out.device.index, + 'OUT': out + } + + # Generate, build and run the kernel + code = FP8GemmOffsetRuntime.generate(kwargs) + runtime = build('m_grouped_gemm_fp8_fp8_bf16_nt_offset', code, FP8GemmOffsetRuntime, kwargs) + runtime(**kwargs) + diff --git a/deep_gemm/jit_kernels/runtime.py b/deep_gemm/jit_kernels/runtime.py index e65e85aa..bc37de2e 100644 --- a/deep_gemm/jit_kernels/runtime.py +++ b/deep_gemm/jit_kernels/runtime.py @@ -5,7 +5,7 @@ import cuda.bindings.driver as cbd from typing import Any, Dict, Tuple -from .utils import get_tma_aligned_size +from .utils import get_tma_aligned_size, ceil_div from ..jit.runtime import Runtime @@ -13,12 +13,15 @@ class GemmType(enum.Enum): Normal = 0 GroupedContiguous = 1 GroupedMasked = 2 + GroupedWithOffset = 3 + def __str__(self) -> str: return { 0: 'Normal', 1: 'GroupedContiguous', 2: 'GroupedMasked', + 3: 'GroupedWithOffset', }[self.value] @@ -133,6 +136,58 @@ def make_2d_tma_scales_desc(gemm_type: GemmType, t: torch.Tensor, cbd.CUtensorMapSwizzle.CU_TENSOR_MAP_SWIZZLE_NONE) +def make_2d_tma_scales_a_offset_desc(gemm_type: GemmType, t: torch.Tensor, + max_m_padded_total: int, shape_k: int, + block_m: int, block_k: int, + global_stride_in_bytes: int = 0) -> cbd.CUtensorMap: + return make_2d_tma_desc(t, + max_m_padded_total, ceil_div(shape_k, block_k), max_m_padded_total, + block_m, 1, + cbd.CUtensorMapSwizzle.CU_TENSOR_MAP_SWIZZLE_NONE) + + + +def make_2d_tma_a_offset_desc_swapAB(gemm_type: GemmType, t: torch.Tensor, + shape_m: int, shape_k: int, m_stride: int, + block_m: int, block_k: int, + num_groups: int) -> cbd.CUtensorMap: + return make_2d_tma_desc(t, + shape_k, shape_m * (num_groups if gemm_type != GemmType.Normal else 1), m_stride, + block_k, block_m) + + +def make_2d_tma_b_offset_desc_swapAB(gemm_type: GemmType, t: torch.Tensor, + shape_n: int, shape_k: int, n_stride: int, + block_n: int, block_k: int, + num_groups: int) -> cbd.CUtensorMap: + return make_2d_tma_desc(t, + shape_k, shape_n * (num_groups if gemm_type == GemmType.GroupedMasked else 1), n_stride, + block_k, block_n) + + +def make_2d_tma_d_offset_desc_swapAB(gemm_type: GemmType, t: torch.Tensor, + shape_m: int, shape_n: int, m_stride: int, + block_m: int, block_n: int, + num_groups: int, + swizzle_mode: int) -> cbd.CUtensorMap: + # Swizzling 
requires the inner box dim to be less or equal than `kSwizzleDMode` + # bytes, so `BLOCK_N * sizeof(T) / kSwizzleDMode` TMA stores are required + return make_2d_tma_desc(t, + shape_n, shape_m * (num_groups if gemm_type == GemmType.GroupedMasked else 1), m_stride, + min(block_m, shape_n), min(block_n, shape_m), + cbd.CUtensorMapSwizzle.CU_TENSOR_MAP_SWIZZLE_NONE) + + +def make_2d_tma_scales_b_offset_desc_swapAB(gemm_type: GemmType, t: torch.Tensor, + max_n_padded_total: int, shape_k: int, + block_n: int, block_k: int, + global_stride_in_bytes: int = 0) -> cbd.CUtensorMap: + return make_2d_tma_desc(t, + max_n_padded_total, ceil_div(shape_k, block_k), max_n_padded_total, + block_n, 1, + cbd.CUtensorMapSwizzle.CU_TENSOR_MAP_SWIZZLE_NONE) + + class FP8GemmRuntime(Runtime): def __init__(self, path: str) -> None: super().__init__(path) @@ -316,3 +371,101 @@ def launch(kernel: cbd.CUkernel, kwargs: Dict[str, Any]) -> cbd.CUresult: None, ) return cbd.cuLaunchKernelEx(config, kernel, (arg_values, arg_types), 0) + + +class FP8GemmOffsetRuntime(Runtime): + def __init__(self, path: str) -> None: + super().__init__(path) + + @staticmethod + def generate(kwargs: Dict[str, Any]) -> str: + code = f''' +#ifdef __CUDACC_RTC__ +#include +#else +#include +#include +#endif + +#include +#include + +#include + +using namespace deep_gemm; + +using SchedulerType = +typename {kwargs['SCHEDULER_TYPE']} ::type; + +static void __instantiate_kernel() {{ + auto ptr = reinterpret_cast(&{kwargs['KERNEL_NAME']}< + {kwargs['N']}, + {kwargs['K']}, + {kwargs['BLOCK_M']}, + {kwargs['BLOCK_N']}, + {kwargs['BLOCK_K']}, + {kwargs['NUM_GROUPS']}, + {kwargs['NUM_STAGES']}, + {kwargs['NUM_TMA_THREADS']}, + {kwargs['NUM_MATH_THREADS_PER_GROUP']}, + {kwargs['NUM_TMA_MULTICAST']}, + SchedulerType, + {kwargs['INPUT_TYPE']} + >); +}}; +''' + if int(os.getenv('DG_JIT_DEBUG', 0)): + print(f'Generated FP8 GEMM code:\n{code}') + return code + + # noinspection PyMethodOverriding + @staticmethod + def launch(kernel: cbd.CUkernel, kwargs: Dict[str, Any]) -> cbd.CUresult: + num_tma_threads = 128 + num_math_threads_per_group = 128 + + result = cbd.cuKernelSetAttribute(cbd.CUfunction_attribute.CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, + kwargs['SMEM_SIZE'], kernel, cbd.CUdevice(kwargs['DEVICE_INDEX']))[0] + assert result == cbd.CUresult.CUDA_SUCCESS, f'Failed to set max dynamic shared memory size: {result}' + + attr_val = cbd.CUlaunchAttributeValue() + attr_val.clusterDim.x = kwargs['NUM_TMA_MULTICAST'] + attr_val.clusterDim.y = 1 + attr_val.clusterDim.z = 1 + attr = cbd.CUlaunchAttribute() + attr.id = cbd.CUlaunchAttributeID.CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION + attr.value = attr_val + + config = cbd.CUlaunchConfig() + config.numAttrs = 1 + config.attrs = [attr] + config.gridDimX = kwargs['NUM_SMS'] + config.gridDimY = 1 + config.gridDimZ = 1 + config.blockDimX = get_num_threads_per_sm(num_tma_threads, num_math_threads_per_group, kwargs['BLOCK_M']) + config.blockDimY = 1 + config.blockDimZ = 1 + config.sharedMemBytes = kwargs['SMEM_SIZE'] + config.hStream = kwargs['STREAM'] + + arg_values = ( + kwargs['OUT'].data_ptr(), + kwargs['SCALES'].data_ptr(), + kwargs['PROBLEM_OFFSETS'].data_ptr(), + kwargs['TENSOR_MAP_A'], + kwargs['TENSOR_MAP_B'], + kwargs['TENSOR_MAP_SCALES'], + kwargs['TENSOR_MAP_D'], + ) + arg_types = ( + ctypes.c_void_p, + ctypes.c_void_p, + ctypes.c_void_p, + None, + None, + None, + None, + ) + return cbd.cuLaunchKernelEx(config, kernel, (arg_values, arg_types), 0) diff --git a/deep_gemm/jit_kernels/utils.py 
diff --git a/deep_gemm/jit_kernels/utils.py b/deep_gemm/jit_kernels/utils.py
index c6da56b0..11a42bdf 100644
--- a/deep_gemm/jit_kernels/utils.py
+++ b/deep_gemm/jit_kernels/utils.py
@@ -107,3 +107,6 @@ def get_col_major_tma_aligned_tensor(x: torch.Tensor) -> torch.Tensor:
     aligned_x[:, :m, :] = x
     aligned_x = aligned_x[:, :m, :]
     return aligned_x.squeeze(0) if remove_dim else aligned_x
+
+def compute_padded_offset(offset, idx_problem, alignment=32):
+    return (offset + idx_problem * (alignment - 1)) // alignment * alignment
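The helper mirrors the padded per-group scales layout that the grouped-with-offset path expects (the test below defines the same function inline and notes it must stay in sync with the device-side scheduler). A small worked example with hypothetical group sizes shows where each group's scale rows start:

group_ms = [5, 70, 33]                 # hypothetical rows contributed by each group
m_acc = [0, 5, 75, 108]                # unpadded prefix sums
starts = [compute_padded_offset(m_acc[i], i) for i in range(len(group_ms) + 1)]
assert starts == [0, 32, 128, 192]     # every group's scales start on a 32-row boundary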
diff --git a/tests/test_core.py b/tests/test_core.py
index 3b88539c..1a3565b1 100644
--- a/tests/test_core.py
+++ b/tests/test_core.py
@@ -6,6 +6,7 @@
 import random
 import torch
 from typing import List, Tuple
+import itertools
 
 import deep_gemm
 from deep_gemm import bench_kineto, calc_diff, ceil_div, get_col_major_tma_aligned_tensor
@@ -167,6 +168,81 @@ def construct_k_grouped_wgrad(m: int, n: int, k_sizes: List[int]) -> \
     return (x_fp8_flat, x_scales), (y_fp8_flat, y_scales), out, ref_out, k_sizes
 
 
+def change_to_offset_layout(
+    ms: List[int],
+    x_fp8: torch.Tensor,
+    x_scale: torch.Tensor,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    x_list = []
+    x_scale_list = []
+    shape_m_total = 0
+    num_problems = len(ms)
+    m_acc = [0] + list(itertools.accumulate(ms))
+
+    # Need to keep the same as the one in cpp/include/tensorrt_llm/deep_gemm/scheduler.cuh
+    def compute_padded_offset(offset, idx_problem, alignment=32):
+        return (offset + idx_problem * (alignment - 1)) // alignment * alignment
+
+    offset = 0
+    for i in range(num_problems):
+        x_list.append(x_fp8[m_acc[i]:m_acc[i + 1]])
+        offset_next = compute_padded_offset(m_acc[i + 1], i + 1)
+        size_padded = (offset_next - offset) - (m_acc[i + 1] - m_acc[i])
+        x_scale_padded = torch.cat([
+            x_scale[m_acc[i]:m_acc[i + 1]],
+            torch.zeros(
+                [size_padded, *x_scale.shape[1:]],
+                dtype=x_scale.dtype,
+                device=x_scale.device,
+            ),
+        ])
+        x_scale_list.append(x_scale_padded)
+        offset = offset_next
+
+    shape_m_total = m_acc[-1]
+    ret_x = torch.cat(x_list)
+    ret_x_scale = torch.cat(x_scale_list)
+    ret_x_scale = ret_x_scale.t().contiguous()
+    pad_target = compute_padded_offset(shape_m_total, num_problems)
+    pad_target -= ret_x_scale.shape[1]
+    ret_x_scale = torch.nn.functional.pad(ret_x_scale, (0, pad_target),
+                                          mode='constant',
+                                          value=0)
+    return ret_x, ret_x_scale
+
+
+def construct_offset_grouped(num_groups: int, expected_m_per_group: int, k: int, n: int) -> \
+        Tuple[List[int], int, Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor], torch.Tensor, torch.Tensor, torch.Tensor]:
+    alignment = 4
+    group_ms = [int(expected_m_per_group * random.uniform(0.7, 1.3)) for _ in range(num_groups)]
+    m = sum([ceil_div(x, alignment) * alignment for x in group_ms])
+
+    x = torch.randn((m, k), device='cuda', dtype=torch.bfloat16)
+    y = torch.randn((num_groups, n, k), device='cuda', dtype=torch.bfloat16)
+    offsets = torch.empty(num_groups + 1, device='cuda', dtype=torch.int64)
+    out = torch.empty((m, n), device='cuda', dtype=torch.bfloat16)
+    ref_out = torch.randn((m, n), device='cuda', dtype=torch.bfloat16)
+
+    start = 0
+    offsets[0] = 0
+    for i, group_m in enumerate(group_ms):
+        aligned_end = start + ceil_div(group_m, alignment) * alignment
+        offsets[i + 1] = aligned_end
+        ref_out[start:aligned_end] = x[start:aligned_end] @ y[i].t()
+        start = aligned_end
+        group_ms[i] = ceil_div(group_m, alignment) * alignment
+
+    assert m % 4 == 0, f'TMA alignment error: {m}'
+    x_fp8 = per_token_cast_to_fp8(x)
+    y_fp8 = (torch.empty_like(y, dtype=torch.float8_e4m3fn), torch.empty((num_groups, ceil_div(n, 128), k // 128), device='cuda', dtype=torch.float))
+    for i in range(num_groups):
+        y_fp8[0][i], y_fp8[1][i] = per_block_cast_to_fp8(y[i])
+
+    return group_ms, m, x_fp8, y_fp8, offsets.type(torch.int64), out, ref_out
+
+
 def test_gemm() -> None:
     print('Testing GEMM:')
     for m in (64, 128, 4096):
@@ -295,6 +371,32 @@ def test_func():
     print()
 
 
+def test_m_grouped_gemm_offset() -> None:
+    print('Testing grouped offset GEMM:')
+
+    for num_groups, expected_m_per_group in ((2, 16), (4, 16), (2, 32), (9, 32), (2, 32), (4, 32), (32, 64)):
+        for k, n in ((7168, 4096),):
+            # NOTES: we should mask the unfilled part before calculating difference
+            ms, m_offset, x_fp8_offset, y_fp8_offset, offset, out_offset, ref_out_offset = construct_offset_grouped(num_groups, expected_m_per_group, k, n)
+            pad_x_fp8 = change_to_offset_layout(ms, x_fp8_offset[0], x_fp8_offset[1])
+
+            deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_offset(pad_x_fp8, y_fp8_offset, offset, out_offset, expected_m_per_group)
+            diff = calc_diff(out_offset, ref_out_offset)
+            assert diff < 0.001, f'{m_offset=}, {k=}, {n=}, {diff:.5f}'
+
+            # noinspection PyShadowingNames
+            def test_func():
+                deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_offset(pad_x_fp8, y_fp8_offset, offset, out_offset, expected_m_per_group)
+
+            t = bench_kineto(test_func, 'fp8_gemm', suppress_kineto_output=True)
+            valid_m = m_offset
+
+            print(f' > Perf ({num_groups=:2}, {expected_m_per_group=:4}, n={n:4}, k={k:4}): {t * 1e6:4.0f} us | '
+                  f'throughput: {2 * valid_m * n * k / t / 1e12:4.0f} TFLOPS, '
+                  f'{(valid_m * k + num_groups * k * n + valid_m * n * 2) / 1e9 / t:4.0f} GB/s')
+    print()
+
+
 if __name__ == '__main__':
     torch.backends.cuda.matmul.allow_tf32 = True
     torch.backends.cudnn.allow_tf32 = True
@@ -307,6 +409,7 @@ def test_func():
     test_gemm()
     test_m_grouped_gemm_contiguous()
     test_m_grouped_gemm_masked()
+    test_m_grouped_gemm_offset()
     test_wgrad_gemm()
     test_k_grouped_wgrad_gemm()
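To make the row alignment used by construct_offset_grouped concrete, here is a short worked example with hypothetical raw group sizes, relying only on the ceil_div and itertools imports already present in this test file:

raw_ms = [13, 22, 30]                                   # hypothetical per-group row counts
aligned_ms = [ceil_div(x, 4) * 4 for x in raw_ms]       # -> [16, 24, 32]
offsets = [0] + list(itertools.accumulate(aligned_ms))  # -> [0, 16, 40, 72]
# `offsets` has num_groups + 1 entries and every boundary is a multiple of 4,
# matching the TMA alignment assert inside construct_offset_grouped.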