/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#pragma once

#include <c10/macros/Macros.h> // CUDA_KERNEL_ASSERT
#include <cute/tensor.hpp>
#include <cutlass/numeric_types.h> // cutlass::sizeof_bits
#include <cutlass/util/packed_stride.hpp> // cutlass::make_cute_packed_stride

namespace fbgemm_gpu {

enum GroupedGemmInputType {
  // K dynamic
  _2D2D,
  // M dynamic (MoE forward style)
  _2D3D
};

template <
    typename ProblemShape,
    typename ElementA,
    typename ElementB,
    typename ElementC,
    typename ScaleDtype,
    typename StrideA,
    typename StrideB,
    typename StrideC,
    typename LayoutSFA,
    typename LayoutSFB,
    typename Sm1xxBlkScaledConfig>
__global__ void set_grouped_gemm_args_kernel(
    int64_t G,
    int64_t M,
    int64_t N,
    int64_t K,
    ProblemShape* problem_shape_ptr,
    ElementA* xq,
    const ElementA** xq_ptr,
    ElementB* wq,
    const ElementB** wq_ptr,
    ScaleDtype* x_scale,
    const ScaleDtype** x_scale_ptr,
    ScaleDtype* w_scale,
    const ScaleDtype** w_scale_ptr,
    ElementC* output,
    ElementC** output_ptr,
    StrideA* stride_a_ptr,
    StrideB* stride_b_ptr,
    StrideC* stride_c_ptr,
    int32_t* offsets, // Group end offsets
    LayoutSFA* layout_SFA,
    LayoutSFB* layout_SFB,
    GroupedGemmInputType gemm_type) {
  const uint32_t group_index = blockIdx.x * blockDim.x + threadIdx.x;

  // If this thread corresponds to a valid group, write kernel args to device
  // memory.
  if (group_index < G) {
    // Set problem shapes to empty by default.
    problem_shape_ptr[group_index] = ProblemShape(0, 0, 0);

    // Offsets for this group.
    int64_t xq_offset = 0;
    int64_t wq_offset = 0;
    int64_t output_offset = 0;
    int64_t x_scale_offset = 0;
    int64_t w_scale_offset = 0;

    auto round_up = [](int64_t x, int64_t y) { return ((x + y - 1) / y) * y; };

    // Pre-compute common rounded values to minimize round_up calls.
    const int64_t N_rounded = round_up(N, 128);
    const int64_t M_rounded = round_up(M, 128);

    const int64_t scale_factor_block_size = 32;
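    // scale_factor_block_size: each scale factor covers a 1x32 block of
    // elements along the contraction (K) dim, so group sizes along K are
    // divided by 32 when sizing the blocked scale tensors below.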

    // Handle offsets API (torch compliant API for 2D-2D and 2D-3D inputs).
    CUDA_KERNEL_ASSERT(
        offsets != nullptr &&
        "offsets must be set for 2d-2d and 2d-3d grouped GEMMs");
    switch (gemm_type) {
      // In the 2d-2d case, the contraction dim (total_K) has variable group
      // sizes:
      //   XQ = (M, total_K), WQ = (N, total_K)
      // The main loop is defined with WQ @ XQ^T = (N, M) for each group, and
      // out = (G, N, M).
      case GroupedGemmInputType::_2D2D: {
        // `offsets` contains the end index of each group.
        const int32_t prev_group_end_offset =
            (group_index == 0) ? 0 : offsets[group_index - 1];
        const int32_t curr_group_end_offset = offsets[group_index];
        const int32_t K_group_size =
            curr_group_end_offset - prev_group_end_offset;

        // Validate group offsets.
        const int align = 128 / cutlass::sizeof_bits<ElementA>::value;
        CUDA_KERNEL_ASSERT(
            K_group_size % align == 0 &&
            "for 2d-2d grouped gemm, group sizes along K dim must be non-negative multiple of 16\n");
        CUDA_KERNEL_ASSERT(
            curr_group_end_offset <= K &&
            "for 2d-2d grouped gemm, group end offsets must be non-negative and must be <= K\n");
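        // Note: `align` is the number of ElementA values per 128 bits (16 for
        // 8-bit element types), which keeps each group's K boundary 128-bit
        // aligned.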

        // Set starting input offsets for this group.
        // XQ is shape (M, K) with strides (K, 1) and group offsets are along
        // the K dim, so: xq_offset -> prev_group_end_offset * 1.
        xq_offset = prev_group_end_offset;

        // WQ is shape (N, K) with strides (K, 1) and group offsets are along
        // the K dim, so: wq_offset -> prev_group_end_offset * 1.
        wq_offset = prev_group_end_offset;

        // Output for 2d-2d grouped GEMM is shape (G, M, N).
        // output_offset -> group_index rows with stride of M * N.
        output_offset = group_index * M * N;

        // Group sizes are variable and converted to blocked/padded format, so
        // to calculate the starting offset of this group's scales, for each
        // previous group we:
        //   - calculate the expected size of its blocked formatted scales
        //   - increment the scale offsets by that size
        // x_scale has shape (M_rounded, total_K_padded_per_group) and
        // w_scale has shape (N_rounded, total_K_padded_per_group).
        for (int i = 0; i < group_index; i++) {
          int group_i_size = i == 0 ? offsets[i] : offsets[i] - offsets[i - 1];
          int scale_cols_for_group_i_padded =
              round_up(group_i_size / scale_factor_block_size, 4);
          x_scale_offset += M_rounded * scale_cols_for_group_i_padded;
          w_scale_offset += N_rounded * scale_cols_for_group_i_padded;
        }
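        // Illustrative example (hypothetical values): with offsets = {64, 192}
        // and M_rounded = 128, group 0 contributes round_up(64 / 32, 4) = 4
        // blocked scale columns, so for group_index == 1 we get
        // x_scale_offset = 128 * 4 and w_scale_offset = N_rounded * 4.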

        // Only write kernel args if this group is non-empty.
        if (K_group_size > 0) {
          // Get index automatically for this group.
          int total_K = K; // Name alias for clarity/readability.

          // Set problem shape.
          // The main loop passes inputs in B, A order, so we have:
          // (N, K_group) @ (M, K_group)^T = (N, M) for each group.
          problem_shape_ptr[group_index] = ProblemShape(N, M, K_group_size);

          // Set pointers for this group.
          xq_ptr[group_index] = xq + xq_offset;
          wq_ptr[group_index] = wq + wq_offset;
          x_scale_ptr[group_index] = x_scale + x_scale_offset;
          w_scale_ptr[group_index] = w_scale + w_scale_offset;
          output_ptr[group_index] = output + output_offset;

          // Set strides.
          // TODO: make strides configurable to handle all NT/TN/NN/TT layouts
          // that Blackwell supports.
          // For XQ, this group processes a slice (M, K_group_size) that is
          // part of a larger tensor (M, total_K). The stride needs to reflect
          // that rows are separated by total_K elements in the original
          // tensor.
          stride_a_ptr[group_index] = cutlass::make_cute_packed_stride(
              StrideA{}, cute::make_shape(int(M), int(total_K), 1));

          // For WQ, the group processes a slice (N, K_group_size) that is
          // part of a larger tensor (N, total_K). The stride needs to reflect
          // that rows are separated by total_K elements in the original
          // tensor.
          stride_b_ptr[group_index] = cutlass::make_cute_packed_stride(
              StrideB{}, cute::make_shape(int(N), int(total_K), 1));

          // For the output of this group,
          // (M, K_group_size) @ (N, K_group_size)^T = (M, N).
          stride_c_ptr[group_index] = cutlass::make_cute_packed_stride(
              StrideC{}, cute::make_shape(int(N), int(M), 1));

          // Set layouts for scale factors.
          // Groups of variable size are along the K dim, so we need to
          // calculate the size of the blocked group scale factor here.
          layout_SFA[group_index] =
              Sm1xxBlkScaledConfig::tile_atom_to_shape_SFA(
                  cute::make_shape(int(M), int(N), int(K_group_size), 1));
          layout_SFB[group_index] =
              Sm1xxBlkScaledConfig::tile_atom_to_shape_SFB(
                  cute::make_shape(int(M), int(N), int(K_group_size), 1));
        }
        break;
      }
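      // In the 2d-3d case, groups of variable size are along total_M
      // (MoE-forward style): XQ = (total_M, K), WQ = (G, N, K),
      // out = (total_M, N).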
      case GroupedGemmInputType::_2D3D: {
        // `offsets` contains the end index of each group.
        const int32_t prev_group_end_offset =
            (group_index == 0) ? 0 : offsets[group_index - 1];
        const int32_t curr_group_end_offset = offsets[group_index];
        const int32_t M_group_size =
            curr_group_end_offset - prev_group_end_offset;

        if (M_group_size > 0) {
          // Validate group offsets.
          CUDA_KERNEL_ASSERT(
              curr_group_end_offset <= M &&
              "for 2d-3d grouped gemm, group end offsets must be non-negative and must be <= M\n");

          // Compute starting offsets for this group when M_group_size > 0.
          // The group offset on XQ along the total_M dim is the sum of all
          // previous group sizes, i.e. the previous group's end offset.
          int64_t group_offset_M = prev_group_end_offset;
          // The scale group offset on x_scale is the sum of all previous
          // (padded) scale group sizes.
          int64_t scale_group_offset_M = 0;
          for (int i = 0; i < group_index; i++) {
            int group_i_size =
                i == 0 ? offsets[i] : offsets[i] - offsets[i - 1];
            // Each group's scale rows are padded to a multiple of 128.
            int scale_group_rows_padded = round_up(group_i_size, 128);
            scale_group_offset_M += scale_group_rows_padded;
          }
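          // Illustrative example (hypothetical values): with offsets =
          // {96, 256}, group 0 spans rows [0, 96) and is padded to
          // round_up(96, 128) = 128 scale rows, so for group_index == 1,
          // group_offset_M = 96 and scale_group_offset_M = 128.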

          // xq_offset -> group_offset_M rows with stride of K.
          xq_offset = group_offset_M * K;

          // wq_offset -> group_index rows with stride of N * K (3d tensor).
          wq_offset = group_index * N * K;

          // output_offset -> group_offset_M rows with stride of N.
          output_offset = group_offset_M * N;

          // x_scale_offset -> sum of all padded group sizes (rows) * rounded
          // scale group cols.
          const int64_t K_rounded = round_up(K / scale_factor_block_size, 4);
          x_scale_offset = scale_group_offset_M * K_rounded;

          // w_scale_offset -> group_index rows with stride of
          // (N rounded to the nearest multiple of 128) *
          // (K / scale_factor_block_size rounded to the nearest multiple of 4).
          w_scale_offset = group_index * N_rounded * K_rounded;

          // Set problem shape.
          problem_shape_ptr[group_index] = ProblemShape(N, M_group_size, K);

          // Set pointers.
          xq_ptr[group_index] = xq + xq_offset;
          wq_ptr[group_index] = wq + wq_offset;
          x_scale_ptr[group_index] = x_scale + x_scale_offset;
          w_scale_ptr[group_index] = w_scale + w_scale_offset;
          output_ptr[group_index] = output + output_offset;

          // Set strides.
          stride_a_ptr[group_index] = cutlass::make_cute_packed_stride(
              StrideA{}, cute::make_shape(int(M_group_size), int(K), 1));
          stride_b_ptr[group_index] = cutlass::make_cute_packed_stride(
              StrideB{}, cute::make_shape(int(N), int(K), 1));
          stride_c_ptr[group_index] = cutlass::make_cute_packed_stride(
              StrideC{}, cute::make_shape(int(N), int(M_group_size), 1));

          // Set layouts for scale factors.
          layout_SFA[group_index] =
              Sm1xxBlkScaledConfig::tile_atom_to_shape_SFA(
                  cute::make_shape(int(M_group_size), int(N), int(K), 1));
          layout_SFB[group_index] =
              Sm1xxBlkScaledConfig::tile_atom_to_shape_SFB(
                  cute::make_shape(int(M_group_size), int(N), int(K), 1));
        }
        break;
      }
    }
  }
}

} // namespace fbgemm_gpu
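
The per-group offset arithmetic above is easiest to sanity-check on the host. Below is a minimal, self-contained sketch (plain C++, no CUDA required; the sizes, `offsets` values, and file name are hypothetical and not part of fbgemm_gpu) that mirrors the 2D-3D branch of the kernel and prints the element offsets each group would use.

// host_offset_check.cpp -- hypothetical standalone sanity check, not part of
// the header above.
#include <cstdint>
#include <cstdio>
#include <vector>

// Same rounding helper as the kernel's round_up lambda.
static int64_t round_up(int64_t x, int64_t y) {
  return ((x + y - 1) / y) * y;
}

int main() {
  // Hypothetical 2D-3D problem: XQ = (M, K), WQ = (G, N, K), offsets along M.
  const int64_t G = 3, M = 256, N = 512, K = 1024;
  const int64_t scale_factor_block_size = 32;
  const std::vector<int32_t> offsets = {96, 160, 256}; // group end offsets

  const int64_t N_rounded = round_up(N, 128);
  const int64_t K_rounded = round_up(K / scale_factor_block_size, 4);

  for (int64_t g = 0; g < G; ++g) {
    const int32_t prev = (g == 0) ? 0 : offsets[g - 1];
    const int32_t M_group_size = offsets[g] - prev;

    // Mirror the kernel: row offset into XQ and padded row offset into
    // x_scale accumulated over all previous groups.
    const int64_t group_offset_M = prev;
    int64_t scale_group_offset_M = 0;
    for (int64_t i = 0; i < g; ++i) {
      const int32_t size_i =
          (i == 0) ? offsets[i] : offsets[i] - offsets[i - 1];
      scale_group_offset_M += round_up(size_i, 128);
    }

    std::printf(
        "group %lld: M=%d xq=%lld wq=%lld out=%lld x_scale=%lld w_scale=%lld\n",
        static_cast<long long>(g), M_group_size,
        static_cast<long long>(group_offset_M * K),
        static_cast<long long>(g * N * K),
        static_cast<long long>(group_offset_M * N),
        static_cast<long long>(scale_group_offset_M * K_rounded),
        static_cast<long long>(g * N_rounded * K_rounded));
  }
  return 0;
}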