
Commit 8e6bff8

cthi authored and facebook-github-bot committed
Split grouped gemm metadata kernel into grouped_common.cuh (#4932)
Summary:
X-link: facebookresearch/FBGEMM#1955

We will reuse this kernel as a base to add support for NV/MX FP4. As a first step, shuffle it into its own file.

Differential Revision: D83151150
1 parent bb87e43 commit 8e6bff8

File tree

3 files changed (+274, -260 lines)

grouped_common.cuh (new file)

Lines changed: 263 additions & 0 deletions

@@ -0,0 +1,263 @@
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#pragma once

#include <cute/tensor.hpp>

namespace fbgemm_gpu {

enum GroupedGemmInputType {
  // K dynamic
  _2D2D,
  // M dynamic (MoE forward style)
  _2D3D
};

template <
    typename ProblemShape,
    typename ElementA,
    typename ElementB,
    typename ElementC,
    typename ScaleDtype,
    typename StrideA,
    typename StrideB,
    typename StrideC,
    typename LayoutSFA,
    typename LayoutSFB,
    typename Sm1xxBlkScaledConfig>
__global__ void set_grouped_gemm_args_kernel(
    int64_t G,
    int64_t M,
    int64_t N,
    int64_t K,
    ProblemShape* problem_shape_ptr,
    ElementA* xq,
    const ElementA** xq_ptr,
    ElementB* wq,
    const ElementB** wq_ptr,
    ScaleDtype* x_scale,
    const ScaleDtype** x_scale_ptr,
    ScaleDtype* w_scale,
    const ScaleDtype** w_scale_ptr,
    ElementC* output,
    ElementC** output_ptr,
    StrideA* stride_a_ptr,
    StrideB* stride_b_ptr,
    StrideC* stride_c_ptr,
    int32_t* offsets, // Group end offsets
    LayoutSFA* layout_SFA,
    LayoutSFB* layout_SFB,
    GroupedGemmInputType gemm_type) {
  const uint32_t group_index = blockIdx.x * blockDim.x + threadIdx.x;

  // If this thread corresponds to a valid group, write kernel args to device
  // memory.
  if (group_index < G) {
    // Set problem shapes to empty by default.
    problem_shape_ptr[group_index] = ProblemShape(0, 0, 0);

    // Offsets for this group.
    int64_t xq_offset = 0;
    int64_t wq_offset = 0;
    int64_t output_offset = 0;
    int64_t x_scale_offset = 0;
    int64_t w_scale_offset = 0;

    auto round_up = [](int64_t x, int64_t y) { return ((x + y - 1) / y) * y; };

    // Pre-compute common rounded values to minimize round_up calls
    const int64_t N_rounded = round_up(N, 128);
    const int64_t M_rounded = round_up(M, 128);

    const int64_t scale_factor_block_size = 32;

    // Handle offsets API (torch compliant API for 2D-2D and 2D-3D inputs)
    CUDA_KERNEL_ASSERT(
        offsets != nullptr &&
        "offsets must be set for 2d-2d and 2d-3d grouped GEMMs");
    switch (gemm_type) {
      // In the 2d-2d case, contraction dim (total_K) has variable group
      // sizes. XQ = (M, total_K) WQ = (N, total_K) Main loop defined with WQ
      // @ XQ^T = (N, M) for each group. out = (G, N, M)
      case GroupedGemmInputType::_2D2D: {
        // `offsets` contains end index of each group.
        const int32_t prev_group_end_offset =
            (group_index == 0) ? 0 : offsets[group_index - 1];
        const int32_t curr_group_end_offset = offsets[group_index];
        const int32_t K_group_size =
            curr_group_end_offset - prev_group_end_offset;

        // Validate group offsets.
        const int align = 128 / cutlass::sizeof_bits<ElementA>::value;
        CUDA_KERNEL_ASSERT(
            K_group_size % align == 0 &&
            "for 2d-2d grouped gemm, group sizes along K dim must be non-negative multiple of 16\n");
        CUDA_KERNEL_ASSERT(
            curr_group_end_offset <= K &&
            "for 2d-2d grouped gemm, group end offsets must be non-negative and must be <= K\n");

        // Set starting input offsets for this group.
        // XQ is shape (M,K) with strides (K, 1) and group offsets are along
        // the K dim, so: xq_offset -> prev_group_end_offset * 1
        xq_offset = prev_group_end_offset;

        // WQ is shape (N,K) with strides (K, 1) and group offsets are along
        // the K dim, so: wq_offset -> prev_group_end_offset * 1
        wq_offset = prev_group_end_offset;

        // Output for 2d-2d grouped GEMM is shape (G, M, N)
        // output_offset -> group_index rows with stride of M * N
        output_offset = group_index * M * N;

        // Group sizes are variable and converted to blocked/padded format, so
        // to calculate the starting offset of this group's scales, we do the
        // following: For each previous group
        // - Calculate the expected size of its blocked formatted scales
        // - Increment the scale offsets by that size
        // x_scale shape (M_rounded, total_K_padded_per_group).
        // w_scale has shape (N_rounded, total_K_padded_per_group).
        for (int i = 0; i < group_index; i++) {
          int group_i_size = i == 0 ? offsets[i] : offsets[i] - offsets[i - 1];
          int scale_cols_for_group_i_padded =
              round_up(group_i_size / scale_factor_block_size, 4);
          x_scale_offset += M_rounded * scale_cols_for_group_i_padded;
          w_scale_offset += N_rounded * scale_cols_for_group_i_padded;
        }
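        // Example (hypothetical values): with offsets = {96, 224} and
        // group_index == 1, group 0 has K_group_size = 96, so its blocked
        // scales occupy round_up(96 / 32, 4) = 4 padded columns; group 1's
        // scales therefore start at x_scale_offset = M_rounded * 4 and
        // w_scale_offset = N_rounded * 4.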

        // Only write kernel args if this group is non-empty
        if (K_group_size > 0) {
          // Get index automatically for this group
          int total_K = K; // Name alias for clarity/readability.

          // Set problem shape.
          // Main loop passes inputs in B,A order, so we have: (N, K_group) @
          // (M, K_group)^T = (N, M) for each group.
          problem_shape_ptr[group_index] = ProblemShape(N, M, K_group_size);

          // Set pointers for this group.
          xq_ptr[group_index] = xq + xq_offset;
          wq_ptr[group_index] = wq + wq_offset;
          x_scale_ptr[group_index] = x_scale + x_scale_offset;
          w_scale_ptr[group_index] = w_scale + w_scale_offset;
          output_ptr[group_index] = output + output_offset;

          // Set strides.
          // TODO: make strides configurable to handle all NT/TN/NN/NT layouts
          // that Blackwell supports. For XQ, the group processes a slice (M,
          // K_group_size) but it's part of a larger tensor (M, total_K). The
          // stride needs to reflect that rows are separated by total_K
          // elements in the original tensor.
          stride_a_ptr[group_index] = cutlass::make_cute_packed_stride(
              StrideA{}, cute::make_shape(int(M), int(total_K), 1));

          // For WQ, the group processes a slice (N, K_group_size) but it's
          // part of a larger tensor (N, total_K). The stride needs to reflect
          // that rows are separated by total_K elements in the original
          // tensor.
          stride_b_ptr[group_index] = cutlass::make_cute_packed_stride(
              StrideB{}, cute::make_shape(int(N), int(total_K), 1));

          // For output of this group, (M, K_group_size) @ (N, K_group_size)^T
          // = (M, N)
          stride_c_ptr[group_index] = cutlass::make_cute_packed_stride(
              StrideC{}, cute::make_shape(int(N), int(M), 1));

          // Set layouts for scale factors.
          // Groups of variable size are along the K dim, so we need to
          // calculate the size of the blocked group scale factor here.
          layout_SFA[group_index] =
              Sm1xxBlkScaledConfig::tile_atom_to_shape_SFA(
                  cute::make_shape(int(M), int(N), int(K_group_size), 1));
          layout_SFB[group_index] =
              Sm1xxBlkScaledConfig::tile_atom_to_shape_SFB(
                  cute::make_shape(int(M), int(N), int(K_group_size), 1));
        }
        break;
      }
      case GroupedGemmInputType::_2D3D: {
        // `offsets` contains end index of each group.
        const int32_t prev_group_end_offset =
            (group_index == 0) ? 0 : offsets[group_index - 1];
        const int32_t curr_group_end_offset = offsets[group_index];
        const int32_t M_group_size =
            curr_group_end_offset - prev_group_end_offset;

        if (M_group_size > 0) {
          // Validate group offsets.
          CUDA_KERNEL_ASSERT(
              curr_group_end_offset <= M &&
              "for 2d-3d grouped gemm, group end offsets must be non-negative and must be <= M\n");

          // Compute starting offset for this group when M_group size > 0
          int64_t group_offset_M =
              group_index == 0 ? 0 : offsets[group_index - 1];
          int64_t scale_group_offset_M = 0;
          for (int i = 0; i < group_index; i++) {
            // Group offset on XQ along total_M dim is the sum of all previous
            // group sizes.
            int group_i_size =
                i == 0 ? offsets[i] : offsets[i] - offsets[i - 1];

            // Scale group offset on x_scale is sum of all previous scale
            // group sizes.
            int scale_group_rows_padded = round_up(group_i_size, 128);
            scale_group_offset_M += scale_group_rows_padded;
          }
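          // Example (hypothetical values): with offsets = {100, 300} and
          // group_index == 1, group 0 has M_group_size = 100, so its padded
          // scale rows are round_up(100, 128) = 128 and this group's scales
          // start at scale_group_offset_M = 128.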

          // xq_offset -> group_offset_M rows with stride of K
          xq_offset = group_offset_M * K;

          // wq_offset -> group_index rows with stride of N * K (3d tensor)
          wq_offset = group_index * N * K;

          // output_offset -> group_offset_M rows with stride of N
          output_offset = group_offset_M * N;

          // x_scale offset -> sum of all padded group sizes (rows) * rounded
          // scale group cols
          const int64_t K_rounded = round_up(K / scale_factor_block_size, 4);
          x_scale_offset = scale_group_offset_M * K_rounded;

          // w_scale_offset -> group_index rows with stride of (N rounded to
          // nearest multiple of 128 * K rounded to nearest multiple of 4)
          w_scale_offset = group_index * N_rounded * K_rounded;

          // Set problem shape
          problem_shape_ptr[group_index] = ProblemShape(N, M_group_size, K);

          // Set pointers
          xq_ptr[group_index] = xq + xq_offset;
          wq_ptr[group_index] = wq + wq_offset;
          x_scale_ptr[group_index] = x_scale + x_scale_offset;
          w_scale_ptr[group_index] = w_scale + w_scale_offset;
          output_ptr[group_index] = output + output_offset;

          // Set strides
          stride_a_ptr[group_index] = cutlass::make_cute_packed_stride(
              StrideA{}, cute::make_shape(int(M_group_size), int(K), 1));
          stride_b_ptr[group_index] = cutlass::make_cute_packed_stride(
              StrideB{}, cute::make_shape(int(N), int(K), 1));
          stride_c_ptr[group_index] = cutlass::make_cute_packed_stride(
              StrideC{}, cute::make_shape(int(N), int(M_group_size), 1));

          // Set layouts for scale factors
          layout_SFA[group_index] =
              Sm1xxBlkScaledConfig::tile_atom_to_shape_SFA(
                  cute::make_shape(int(M_group_size), int(N), int(K), 1));
          layout_SFB[group_index] =
              Sm1xxBlkScaledConfig::tile_atom_to_shape_SFB(
                  cute::make_shape(int(M_group_size), int(N), int(K), 1));
        }
        break;
      }
    }
  }
}

} // namespace fbgemm_gpu
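
The kernel above writes one group's GEMM arguments per thread, so a caller only needs enough threads to cover the G groups. The host-side launch below is a minimal sketch for context rather than code from this commit: the template arguments are assumed to match the grouped GEMM being instantiated, and `stream` plus the `d_*` device buffers are hypothetical placeholders prepared by the caller.

// Minimal sketch, assuming all d_* buffers are pre-allocated device memory
// and the template parameters match the instantiating grouped GEMM.
constexpr int kThreadsPerBlock = 64; // hypothetical block size
const int num_blocks =
    static_cast<int>((G + kThreadsPerBlock - 1) / kThreadsPerBlock);

fbgemm_gpu::set_grouped_gemm_args_kernel<
    ProblemShape, ElementA, ElementB, ElementC, ScaleDtype,
    StrideA, StrideB, StrideC, LayoutSFA, LayoutSFB, Sm1xxBlkScaledConfig>
    <<<num_blocks, kThreadsPerBlock, 0, stream>>>(
        G, M, N, K,
        d_problem_shapes,
        d_xq, d_xq_ptrs,
        d_wq, d_wq_ptrs,
        d_x_scale, d_x_scale_ptrs,
        d_w_scale, d_w_scale_ptrs,
        d_output, d_output_ptrs,
        d_stride_a, d_stride_b, d_stride_c,
        d_offsets, // group end offsets, length G
        d_layout_sfa, d_layout_sfb,
        fbgemm_gpu::GroupedGemmInputType::_2D3D); // MoE-forward style grouping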

fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/mx8mx8bf16_grouped.cu

Lines changed: 0 additions & 10 deletions
@@ -8,16 +8,6 @@
 
 #include <ATen/ATen.h>
 #include <ATen/cuda/CUDAContext.h>
-#include <cutlass/util/device_memory.h>
-#include <cutlass/util/packed_stride.hpp>
-
-// clang-format off
-// The fixed ordering of the headers is required for CUTLASS 3.2+
-#include <cute/tensor.hpp>
-#include <cutlass/gemm/collective/collective_builder.hpp> // @manual
-#include <cutlass/gemm/device/gemm_universal_adapter.h> // @manual
-#include <cutlass/epilogue/collective/collective_builder.hpp> // @manual
-// clang-format on
 
 #include "fbgemm_gpu/quantize/tuning_cache.hpp"
 #include "fbgemm_gpu/quantize/utils.h"
