
Commit 02ee7cc

cthi authored and facebook-github-bot committed
Support tuning cache for Cutlass BF16 grouped GEMM
Summary:

This diff adds support for the tuning cache to the kernel. There should be no performance changes to the existing heuristics.

- I refactored the kernel dispatch logic to return the kernel function instead of invoking it in place, as that removes some duplication of the kernel invoke.
- To keep this review smaller, the next diff in this stack, D75806957, adds the new kernels.

Reviewed By: q10

Differential Revision: D75541013
1 parent dc9b524 commit 02ee7cc
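
The dispatch refactor described in the summary follows the select-then-invoke pattern sketched below. This is a simplified, hypothetical illustration (names and signatures are stand-ins, not the FBGEMM code): selection returns a function pointer, and a single call site invokes it, so the heuristic and tuned paths share one kernel invoke.

// Simplified sketch of the select-then-invoke dispatch pattern (hypothetical names).
#include <cstdlib>

using KernelFn = int (*)(int m);          // stand-in for the grouped GEMM kernel signature
int small_kernel(int m) { return m; }     // hypothetical kernel variants
int large_kernel(int m) { return 2 * m; }

KernelFn pick_kernel(int total_M) {
  // The heuristic now returns the kernel function instead of calling it.
  return total_M <= 128 ? small_kernel : large_kernel;
}

int dispatch(int total_M) {
  KernelFn kernel = pick_kernel(total_M);  // a tuning path would also yield a KernelFn
  return kernel(total_M);                  // single invoke shared by both selection paths
}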

File tree

fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped.cu
fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_manifest.cuh

2 files changed: +107 -65 lines changed


fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped.cu

Lines changed: 90 additions & 65 deletions
@@ -8,61 +8,55 @@

 #include <ATen/ATen.h>
 #include <ATen/cuda/CUDAContext.h>
+#include <fmt/core.h>

 #include "bf16bf16bf16_grouped/bf16bf16bf16_grouped_manifest.cuh"
+#include "fbgemm_gpu/quantize/common/tuning_cache.hpp"
+#include "fbgemm_gpu/quantize/common/utils.h"

 namespace fbgemm_gpu {

 #if CUDART_VERSION >= 12000

-// BF16 grouped cutlass kernel dispatch.
+namespace {
+TuningCache& getTuningCache() {
+  // This kernel has multiple APIs templated based on InputType, so we use this
+  // to have a single cache instance across APIs.
+  static TuningCache cache("bf16bf16bf16_grouped");
+  return cache;
+}
+} // namespace
+
 template <typename InputType>
-at::Tensor dispatch_bf16_grouped_kernel(
-    int G,
-    int total_M,
-    int N,
-    int K,
-    InputType X, // BF16
-    InputType W, // BF16
-    at::Tensor output,
-    std::optional<at::Tensor> zero_start_index_M = std::nullopt,
-    std::optional<at::Tensor> M_sizes = std::nullopt) {
+Kernel_bf16bf16bf16_grouped<InputType>
+get_kernel_via_heuristic(int G, int total_M, int N, int K) {
   // Use heuristics to pick best kernel implementation.

   // Llama4 128E
   if (G == 128) {
     if (N == 5120 && K == 1024) {
       if (total_M <= 128) {
-        return bf16bf16bf16_grouped_128_16_128_2_1_1_f(
-            X, W, output, zero_start_index_M, M_sizes);
+        return bf16bf16bf16_grouped_128_16_128_2_1_1_f;
       } else if (total_M <= 256) {
-        return bf16bf16bf16_grouped_128_32_128_2_1_1_t(
-            X, W, output, zero_start_index_M, M_sizes);
+        return bf16bf16bf16_grouped_128_32_128_2_1_1_t;
       } else if (total_M <= 2048) {
-        return bf16bf16bf16_grouped_128_16_128_2_1_1_f(
-            X, W, output, zero_start_index_M, M_sizes);
+        return bf16bf16bf16_grouped_128_16_128_2_1_1_f;
       } else if (total_M <= 4096) {
-        return bf16bf16bf16_grouped_128_32_128_2_1_1_f(
-            X, W, output, zero_start_index_M, M_sizes);
+        return bf16bf16bf16_grouped_128_32_128_2_1_1_f;
       } else if (total_M <= 8192) {
-        return bf16bf16bf16_grouped_128_64_128_1_1_1_f(
-            X, W, output, zero_start_index_M, M_sizes);
+        return bf16bf16bf16_grouped_128_64_128_1_1_1_f;
       } else if (total_M <= 16384) {
-        return bf16bf16bf16_grouped_128_128_128_2_1_1_t(
-            X, W, output, zero_start_index_M, M_sizes);
+        return bf16bf16bf16_grouped_128_128_128_2_1_1_t;
       } else {
-        return bf16bf16bf16_grouped_128_256_128_2_1_1_f(
-            X, W, output, zero_start_index_M, M_sizes);
+        return bf16bf16bf16_grouped_128_256_128_2_1_1_f;
       }
     }

     if (N == 2048 && K == 5120) {
       if (total_M <= 2048) {
-        return bf16bf16bf16_grouped_128_16_128_2_1_1_f(
-            X, W, output, zero_start_index_M, M_sizes);
+        return bf16bf16bf16_grouped_128_16_128_2_1_1_f;
       } else {
-        return bf16bf16bf16_grouped_128_128_128_2_1_1_t(
-            X, W, output, zero_start_index_M, M_sizes);
+        return bf16bf16bf16_grouped_128_128_128_2_1_1_t;
       }
     }
   }
@@ -71,71 +65,102 @@ at::Tensor dispatch_bf16_grouped_kernel(
   if (G == 16) {
     if (N == 5120 && K == 1024) {
       if (total_M <= 32) {
-        return bf16bf16bf16_grouped_128_16_128_2_1_1_f(
-            X, W, output, zero_start_index_M, M_sizes);
+        return bf16bf16bf16_grouped_128_16_128_2_1_1_f;
       } else if (total_M <= 64) {
-        return bf16bf16bf16_grouped_128_32_128_2_1_1_t(
-            X, W, output, zero_start_index_M, M_sizes);
+        return bf16bf16bf16_grouped_128_32_128_2_1_1_t;
       } else if (total_M <= 256) {
-        return bf16bf16bf16_grouped_128_16_128_2_1_1_f(
-            X, W, output, zero_start_index_M, M_sizes);
+        return bf16bf16bf16_grouped_128_16_128_2_1_1_f;
       } else if (total_M <= 512) {
-        return bf16bf16bf16_grouped_128_32_128_2_1_1_t(
-            X, W, output, zero_start_index_M, M_sizes);
+        return bf16bf16bf16_grouped_128_32_128_2_1_1_t;
       } else if (total_M <= 1024) {
-        return bf16bf16bf16_grouped_128_64_128_2_1_1_t(
-            X, W, output, zero_start_index_M, M_sizes);
+        return bf16bf16bf16_grouped_128_64_128_2_1_1_t;
       } else {
-        return bf16bf16bf16_grouped_128_256_128_2_1_1_f(
-            X, W, output, zero_start_index_M, M_sizes);
+        return bf16bf16bf16_grouped_128_256_128_2_1_1_f;
       }
     }

     if (N == 2048 && K == 5120) {
       if (total_M <= 16) {
-        return bf16bf16bf16_grouped_128_16_128_2_1_1_f(
-            X, W, output, zero_start_index_M, M_sizes);
+        return bf16bf16bf16_grouped_128_16_128_2_1_1_f;
       } else if (total_M <= 64) {
-        return bf16bf16bf16_grouped_128_32_128_2_1_1_f(
-            X, W, output, zero_start_index_M, M_sizes);
+        return bf16bf16bf16_grouped_128_32_128_2_1_1_f;
       } else if (total_M <= 256) {
-        return bf16bf16bf16_grouped_128_16_128_2_1_1_f(
-            X, W, output, zero_start_index_M, M_sizes);
+        return bf16bf16bf16_grouped_128_16_128_2_1_1_f;
       } else if (total_M <= 512) {
-        return bf16bf16bf16_grouped_128_32_128_2_1_1_f(
-            X, W, output, zero_start_index_M, M_sizes);
+        return bf16bf16bf16_grouped_128_32_128_2_1_1_f;
       } else if (total_M <= 1024) {
-        return bf16bf16bf16_grouped_128_64_128_1_1_1_f(
-            X, W, output, zero_start_index_M, M_sizes);
+        return bf16bf16bf16_grouped_128_64_128_1_1_1_f;
       } else {
-        return bf16bf16bf16_grouped_128_128_128_2_1_1_t(
-            X, W, output, zero_start_index_M, M_sizes);
+        return bf16bf16bf16_grouped_128_128_128_2_1_1_t;
       }
     }
   }

   // Fallback to legacy heuristic for now.
   if (total_M <= 16) {
-    return bf16bf16bf16_grouped_128_16_128_1_1_1_f(
-        X, W, output, zero_start_index_M, M_sizes);
+    return bf16bf16bf16_grouped_128_16_128_1_1_1_f;
   } else if (total_M <= 32) {
-    return bf16bf16bf16_grouped_128_32_128_1_1_1_f(
-        X, W, output, zero_start_index_M, M_sizes);
+    return bf16bf16bf16_grouped_128_32_128_1_1_1_f;
   } else if (total_M <= 64) {
-    return bf16bf16bf16_grouped_128_64_128_1_1_1_f(
-        X, W, output, zero_start_index_M, M_sizes);
+    return bf16bf16bf16_grouped_128_64_128_1_1_1_f;
   } else if (total_M <= 128) {
-    return bf16bf16bf16_grouped_128_128_128_1_1_1_f(
-        X, W, output, zero_start_index_M, M_sizes);
+    return bf16bf16bf16_grouped_128_128_128_1_1_1_f;
  } else if (total_M <= 512) {
-    return bf16bf16bf16_grouped_256_128_128_2_1_1_f(
-        X, W, output, zero_start_index_M, M_sizes);
+    return bf16bf16bf16_grouped_256_128_128_2_1_1_f;
   } else {
-    return bf16bf16bf16_grouped_128_256_128_2_1_1_f(
-        X, W, output, zero_start_index_M, M_sizes);
+    return bf16bf16bf16_grouped_128_256_128_2_1_1_f;
   }
 }

+template <typename InputType>
+Kernel_bf16bf16bf16_grouped<InputType> get_kernel_via_tuning(
+    int G,
+    int total_M,
+    int N,
+    int K,
+    InputType X, // BF16
+    InputType W, // BF16
+    at::Tensor output,
+    std::optional<at::Tensor> zero_start_index_M = std::nullopt,
+    std::optional<at::Tensor> M_sizes = std::nullopt) {
+  auto& cache = getTuningCache();
+
+  // Reducing amount of auto tuning by rounding up total_m to next power of 2.
+  total_M = nextPowerOf2(total_M);
+  // Use (total_M, N, K, G) shape as the key.
+  const std::string shape_key = fmt::format("{}_{}_{}_{}", total_M, N, K, G);
+  const auto& kernels = get_bf16bf16bf16_grouped_kernels<InputType>();
+  auto kernel = cache.findBestKernelMaybeAutotune(
+      shape_key, kernels, X, W, output, zero_start_index_M, M_sizes);
+
+  return kernel;
+}
+
+// BF16 grouped cutlass kernel dispatch.
+template <typename InputType>
+at::Tensor dispatch_bf16_grouped_kernel(
+    int G,
+    int total_M,
+    int N,
+    int K,
+    InputType X, // BF16
+    InputType W, // BF16
+    at::Tensor output,
+    std::optional<at::Tensor> zero_start_index_M = std::nullopt,
+    std::optional<at::Tensor> M_sizes = std::nullopt) {
+  // Select kernel to run via heuristics or tuning.
+  auto kernel = [&]() {
+    if (std::getenv("FBGEMM_AUTOTUNE_ENABLE")) {
+      return get_kernel_via_tuning(
+          G, total_M, N, K, X, W, output, zero_start_index_M, M_sizes);
+    } else {
+      return get_kernel_via_heuristic<InputType>(G, total_M, N, K);
+    }
+  }();
+  // Invoke kernel
+  return kernel(X, W, output, zero_start_index_M, M_sizes);
+}
+
 template <typename OutputType>
 OutputType _bf16bf16bf16_grouped(at::TensorList X, at::TensorList W) {
   at::Tensor Y;
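
As a rough illustration of how the tuning key in get_kernel_via_tuning behaves, the sketch below reproduces the shape-key construction with an assumed nextPowerOf2 (the real helper lives in fbgemm_gpu/quantize/common/utils.h; the "smallest power of two >= total_M" semantics here are an assumption). Rounding total_M up means nearby M values share a key, so autotuning runs once per bucket rather than once per exact shape.

// Illustrative sketch only: assumed nextPowerOf2 semantics and the key format
// used above.
#include <cstdint>
#include <string>
#include <fmt/core.h>

namespace {

// Assumed behavior: smallest power of two that is >= x (and >= 1).
int64_t nextPowerOf2Sketch(int64_t x) {
  int64_t p = 1;
  while (p < x) {
    p <<= 1;
  }
  return p;
}

// Mirrors the (total_M, N, K, G) key construction in get_kernel_via_tuning.
std::string makeShapeKey(int64_t total_M, int64_t N, int64_t K, int64_t G) {
  return fmt::format("{}_{}_{}_{}", nextPowerOf2Sketch(total_M), N, K, G);
}

} // namespace

// Example: total_M = 3000, N = 2048, K = 5120, G = 128 -> "4096_2048_5120_128",
// the same key as any total_M in (2048, 4096].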

fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_manifest.cuh

Lines changed: 17 additions & 0 deletions
@@ -180,4 +180,21 @@ at::Tensor bf16bf16bf16_grouped_128_256_128_2_1_1_f(
     std::optional<at::Tensor> zero_start_index_M,
     std::optional<at::Tensor> M_sizes);

+template <typename InputType>
+using Kernel_bf16bf16bf16_grouped = at::Tensor (*)(
+    InputType,
+    InputType,
+    at::Tensor,
+    std::optional<at::Tensor>,
+    std::optional<at::Tensor>);
+
+template <typename InputType>
+const std::unordered_map<std::string, Kernel_bf16bf16bf16_grouped<InputType>>&
+get_bf16bf16bf16_grouped_kernels() {
+  static const std::
+      unordered_map<std::string, Kernel_bf16bf16bf16_grouped<InputType>>
+          kernels = {};
+  return kernels;
+}
+
 } // namespace fbgemm_gpu
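
The kernel map above is intentionally left empty in this diff; per the summary, the actual entries land in the follow-up diff, D75806957. Purely as a hypothetical illustration, populating it would amount to mapping kernel names to the function pointers already declared in this manifest, along these lines:

// Hypothetical example of how the map could be filled in the follow-up diff;
// the real entry list is not part of this commit, and the function name here
// is a stand-in to avoid clashing with the declaration above.
template <typename InputType>
const std::unordered_map<std::string, Kernel_bf16bf16bf16_grouped<InputType>>&
get_bf16bf16bf16_grouped_kernels_example() {
  static const std::unordered_map<
      std::string,
      Kernel_bf16bf16bf16_grouped<InputType>>
      kernels = {
          {"bf16bf16bf16_grouped_128_16_128_2_1_1_f",
           bf16bf16bf16_grouped_128_16_128_2_1_1_f},
          {"bf16bf16bf16_grouped_128_128_128_2_1_1_t",
           bf16bf16bf16_grouped_128_128_128_2_1_1_t},
      };
  return kernels;
}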
