
Commit ccbdff4

cthi authored and facebook-github-bot committed
Support tuning cache for Cutlass FP8 GEMM (#4301)
Summary:
Pull Request resolved: #4301
X-link: facebookresearch/FBGEMM#1377

This diff adds support for the tuning cache to the kernel. There should be no performance changes to the existing heuristics.

- I refactored the kernel dispatch logic to return the kernel function instead, as that removes some duplication around the kernel invocation.
- The next diff in this stack (D75820688) will add the new kernels, to keep this review easier.
- Note that we are having some issues with adding the new kernels: this kernel actually compiles 12 variants for each configuration; see D75820688 for more context. So for now we won't add the new kernels in D75820688, but we can still onboard the op to auto tuning in case someone wants to compile them locally. Will revisit D75820688 later.

Reviewed By: q10, jiawenliu64

Differential Revision: D75541025
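The tuning path added below keys a per-op TuningCache on an "M_N_K" string, with M first rounded up to the next power of two so that nearby shapes reuse one tuned entry, and it is only taken when the FBGEMM_AUTOTUNE_ENABLE environment variable is set. A minimal standalone sketch of that keying idea follows; it is not fbgemm code: nextPowerOf2Demo stands in for the nextPowerOf2 helper from fbgemm_gpu/quantize/common/utils.h, and plain string concatenation replaces the fmt::format call used in the diff.

#include <iostream>
#include <string>

// Illustrative stand-in for nextPowerOf2 from fbgemm_gpu/quantize/common/utils.h.
static int nextPowerOf2Demo(int x) {
  int p = 1;
  while (p < x) {
    p <<= 1;
  }
  return p;
}

int main() {
  int M = 3000, N = 4096, K = 4096;
  // Round M up so that, e.g., all M in 2049..4096 map to the same cache entry.
  M = nextPowerOf2Demo(M);
  // The diff uses fmt::format("{}_{}_{}", M, N, K); concatenation gives the same key.
  const std::string shape_key =
      std::to_string(M) + "_" + std::to_string(N) + "_" + std::to_string(K);
  std::cout << shape_key << std::endl; // prints "4096_4096_4096"
  return 0;
}

With this rounding, every M between 2049 and 4096 for a given (N, K) shares one autotuned kernel choice, which keeps the number of tuning runs bounded.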
1 parent 9086c6e commit ccbdff4

File tree

2 files changed: +147 −92 lines changed

fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f8f8bf16_rowwise.cu

Lines changed: 124 additions & 92 deletions
@@ -10,166 +10,198 @@
 #include <ATen/cuda/CUDAContext.h>
 #include <c10/cuda/CUDAGuard.h>
 // clang-format on
+#include <fmt/core.h>

 #include "f8f8bf16_rowwise/f8f8bf16_rowwise_manifest.cuh"
+#include "fbgemm_gpu/quantize/common/tuning_cache.hpp"
+#include "fbgemm_gpu/quantize/common/utils.h"

 namespace fbgemm_gpu {

 #if CUDART_VERSION >= 12000

 // FP8 Rowwise Cutlass kernel dispatch.
-at::Tensor dispatch_fp8_rowwise_kernel(
-    at::Tensor XQ,
-    at::Tensor WQ,
-    at::Tensor x_scale,
-    at::Tensor w_scale,
-    bool use_fast_accum,
-    std::optional<at::Tensor> bias = std::nullopt,
-    std::optional<at::Tensor> output = std::nullopt) {
-  int M = size_to_dim_(XQ.dim() - 1, XQ.sizes());
-  int N = size_to_dim_(WQ.dim() - 1, WQ.sizes());
-  int K = XQ.size(-1);
-  static int arch = -1;
-  // Avoid expensive cudaGetDeviceProperties call.
-  if (arch < 0) {
-    cudaDeviceProp prop;
-    cudaGetDeviceProperties(&prop, 0);
-    if (prop.major >= 10) {
-      arch = 10;
-      int runtimeVersion;
-      C10_CUDA_CHECK(cudaRuntimeGetVersion(&runtimeVersion));
-      TORCH_CHECK(
-          runtimeVersion >= 12080,
-          "FP8 GEMM on sm100a or above requires cuda >= 12.8");
-    } else {
-      arch = 9;
-    }
-  }
-
+Kernel_f8f8bf16_rowwise
+get_kernel_via_heuristic(int arch, int M, int N, int K, bool use_fast_accum) {
   // Use shape heuristics to dispatch to optimized kernel configuration.
   if (arch == 10) {
     if (M <= 128) {
       if (N <= 1024) {
-        return f8f8bf16_rowwise_128_32_128_1_1_1_10_f_f(
-            XQ, WQ, x_scale, w_scale, use_fast_accum, bias, output);
+        return f8f8bf16_rowwise_128_32_128_1_1_1_10_f_f;
       } else {
-        return f8f8bf16_rowwise_128_64_128_1_1_1_10_f_f(
-            XQ, WQ, x_scale, w_scale, use_fast_accum, bias, output);
+        return f8f8bf16_rowwise_128_64_128_1_1_1_10_f_f;
       }
     } else if (M <= 1024) {
       if (N <= 1024) {
-        return f8f8bf16_rowwise_128_256_128_2_1_1_10_f_f(
-            XQ, WQ, x_scale, w_scale, use_fast_accum, bias, output);
+        return f8f8bf16_rowwise_128_256_128_2_1_1_10_f_f;
       } else {
-        return f8f8bf16_rowwise_128_128_128_2_2_1_10_f_f(
-            XQ, WQ, x_scale, w_scale, use_fast_accum, bias, output);
+        return f8f8bf16_rowwise_128_128_128_2_2_1_10_f_f;
       }
     } else if (M <= 2048) {
-      return f8f8bf16_rowwise_128_256_128_2_1_1_10_f_f(
-          XQ, WQ, x_scale, w_scale, use_fast_accum, bias, output);
+      return f8f8bf16_rowwise_128_256_128_2_1_1_10_f_f;
     } else {
       if (N <= 1024) {
-        return f8f8bf16_rowwise_128_256_128_1_2_1_10_f_f(
-            XQ, WQ, x_scale, w_scale, use_fast_accum, bias, output);
+        return f8f8bf16_rowwise_128_256_128_1_2_1_10_f_f;
       } else {
-        return f8f8bf16_rowwise_128_256_128_2_1_1_10_f_f(
-            XQ, WQ, x_scale, w_scale, use_fast_accum, bias, output);
+        return f8f8bf16_rowwise_128_256_128_2_1_1_10_f_f;
       }
     }
   } else {
     if (M <= 16) {
-      return f8f8bf16_rowwise_64_16_128_1_1_1_9_f_f(
-          XQ, WQ, x_scale, w_scale, use_fast_accum, bias, output);
+      return f8f8bf16_rowwise_64_16_128_1_1_1_9_f_f;
     } else if (M <= 32) {
       if (N <= 4096) {
-        return f8f8bf16_rowwise_64_16_128_1_1_1_9_f_f(
-            XQ, WQ, x_scale, w_scale, use_fast_accum, bias, output);
+        return f8f8bf16_rowwise_64_16_128_1_1_1_9_f_f;
       } else {
-        return f8f8bf16_rowwise_64_32_128_2_1_1_9_f_f(
-            XQ, WQ, x_scale, w_scale, use_fast_accum, bias, output);
+        return f8f8bf16_rowwise_64_32_128_2_1_1_9_f_f;
       }
     } else if (M <= 64) {
       if (N <= 2048) {
-        return f8f8bf16_rowwise_64_16_128_1_1_1_9_f_f(
-            XQ, WQ, x_scale, w_scale, use_fast_accum, bias, output);
+        return f8f8bf16_rowwise_64_16_128_1_1_1_9_f_f;
       } else if (N <= 4096) {
-        return f8f8bf16_rowwise_64_32_128_2_1_1_9_f_f(
-            XQ, WQ, x_scale, w_scale, use_fast_accum, bias, output);
+        return f8f8bf16_rowwise_64_32_128_2_1_1_9_f_f;
       } else {
-        return f8f8bf16_rowwise_64_64_128_2_1_1_9_f_f(
-            XQ, WQ, x_scale, w_scale, use_fast_accum, bias, output);
+        return f8f8bf16_rowwise_64_64_128_2_1_1_9_f_f;
       }
     } else if (M <= 128) {
       if (N <= 1024) {
-        return f8f8bf16_rowwise_64_16_128_1_1_1_9_f_f(
-            XQ, WQ, x_scale, w_scale, use_fast_accum, bias, output);
+        return f8f8bf16_rowwise_64_16_128_1_1_1_9_f_f;
       } else if (N <= 2048) {
-        return f8f8bf16_rowwise_64_32_128_2_1_1_9_f_f(
-            XQ, WQ, x_scale, w_scale, use_fast_accum, bias, output);
+        return f8f8bf16_rowwise_64_32_128_2_1_1_9_f_f;
       } else if (N <= 4096) {
-        return f8f8bf16_rowwise_64_64_128_2_1_1_9_f_f(
-            XQ, WQ, x_scale, w_scale, use_fast_accum, bias, output);
+        return f8f8bf16_rowwise_64_64_128_2_1_1_9_f_f;
       } else {
-        return f8f8bf16_rowwise_64_128_128_1_1_1_9_f_f(
-            XQ, WQ, x_scale, w_scale, use_fast_accum, bias, output);
+        return f8f8bf16_rowwise_64_128_128_1_1_1_9_f_f;
       }
     } else if (M <= 256) {
       if (N <= 1024) {
-        return f8f8bf16_rowwise_64_32_128_2_1_1_9_f_f(
-            XQ, WQ, x_scale, w_scale, use_fast_accum, bias, output);
+        return f8f8bf16_rowwise_64_32_128_2_1_1_9_f_f;
       } else if (N <= 2048) {
-        return f8f8bf16_rowwise_64_64_128_2_1_1_9_f_f(
-            XQ, WQ, x_scale, w_scale, use_fast_accum, bias, output);
+        return f8f8bf16_rowwise_64_64_128_2_1_1_9_f_f;
       } else if (N <= 4096) {
-        return f8f8bf16_rowwise_64_128_128_1_1_1_9_f_f(
-            XQ, WQ, x_scale, w_scale, use_fast_accum, bias, output);
+        return f8f8bf16_rowwise_64_128_128_1_1_1_9_f_f;
       } else {
-        return f8f8bf16_rowwise_64_256_128_1_1_1_9_f_f(
-            XQ, WQ, x_scale, w_scale, use_fast_accum, bias, output);
+        return f8f8bf16_rowwise_64_256_128_1_1_1_9_f_f;
       }
     } else if (M <= 512) {
       if (N <= 1024) {
-        return f8f8bf16_rowwise_64_64_128_2_1_1_9_f_f(
-            XQ, WQ, x_scale, w_scale, use_fast_accum, bias, output);
+        return f8f8bf16_rowwise_64_64_128_2_1_1_9_f_f;
       } else if (N <= 2048) {
-        return f8f8bf16_rowwise_64_128_128_1_1_1_9_f_f(
-            XQ, WQ, x_scale, w_scale, use_fast_accum, bias, output);
+        return f8f8bf16_rowwise_64_128_128_1_1_1_9_f_f;
       } else if (N <= 4096 || use_fast_accum == false) {
-        return f8f8bf16_rowwise_64_256_128_1_1_1_9_f_f(
-            XQ, WQ, x_scale, w_scale, use_fast_accum, bias, output);
+        return f8f8bf16_rowwise_64_256_128_1_1_1_9_f_f;
       } else {
-        return f8f8bf16_rowwise_128_256_128_2_1_1_9_f_t(
-            XQ, WQ, x_scale, w_scale, use_fast_accum, bias, output);
+        return f8f8bf16_rowwise_128_256_128_2_1_1_9_f_t;
       }
     } else if (M <= 1024) {
       if (N <= 1024) {
-        return f8f8bf16_rowwise_64_128_128_1_1_1_9_f_f(
-            XQ, WQ, x_scale, w_scale, use_fast_accum, bias, output);
+        return f8f8bf16_rowwise_64_128_128_1_1_1_9_f_f;
       } else if (N <= 2048 || use_fast_accum == false) {
-        return f8f8bf16_rowwise_64_256_128_1_1_1_9_f_f(
-            XQ, WQ, x_scale, w_scale, use_fast_accum, bias, output);
+        return f8f8bf16_rowwise_64_256_128_1_1_1_9_f_f;
       } else {
-        return f8f8bf16_rowwise_128_256_128_2_1_1_9_f_t(
-            XQ, WQ, x_scale, w_scale, use_fast_accum, bias, output);
+        return f8f8bf16_rowwise_128_256_128_2_1_1_9_f_t;
       }
     } else {
       if (M <= 2048 && N <= 1024) {
-        return f8f8bf16_rowwise_64_256_128_2_1_1_9_f_f(
-            XQ, WQ, x_scale, w_scale, use_fast_accum, bias, output);
+        return f8f8bf16_rowwise_64_256_128_2_1_1_9_f_f;
       } else if (K <= 4096 || use_fast_accum == false) {
-        return f8f8bf16_rowwise_128_128_128_2_1_1_9_t_f(
-            XQ, WQ, x_scale, w_scale, use_fast_accum, bias, output);
+        return f8f8bf16_rowwise_128_128_128_2_1_1_9_t_f;
       } else if (M > 8192 && N > 8192) {
-        return f8f8bf16_rowwise_128_256_128_4_4_1_9_f_t(
-            XQ, WQ, x_scale, w_scale, use_fast_accum, bias, output);
+        return f8f8bf16_rowwise_128_256_128_4_4_1_9_f_t;
       } else {
-        return f8f8bf16_rowwise_128_256_128_2_1_1_9_f_t(
-            XQ, WQ, x_scale, w_scale, use_fast_accum, bias, output);
+        return f8f8bf16_rowwise_128_256_128_2_1_1_9_f_t;
       }
     }
   }
 }

+Kernel_f8f8bf16_rowwise get_kernel_via_tuning(
+    int arch,
+    int M,
+    int N,
+    int K,
+    at::Tensor XQ,
+    at::Tensor WQ,
+    at::Tensor x_scale,
+    at::Tensor w_scale,
+    bool use_fast_accum,
+    std::optional<at::Tensor> bias = std::nullopt,
+    std::optional<at::Tensor> output = std::nullopt) {
+  // One cache per kernel type
+  static TuningCache cache("f8f8bf16_rowwise");
+
+  // Reducing amount of auto tuning by rounding up M to next power of 2.
+  M = nextPowerOf2(M);
+  // Use (M, N, K) shape as the key.
+  const std::string shape_key = fmt::format("{}_{}_{}", M, N, K);
+  const auto& kernels = get_f8f8bf16_rowwise_kernels(arch);
+  auto kernel = cache.findBestKernelMaybeAutotune(
+      shape_key,
+      kernels,
+      XQ,
+      WQ,
+      x_scale,
+      w_scale,
+      use_fast_accum,
+      bias,
+      output);
+
+  return kernel;
+}
+
+// FP8 Rowwise Cutlass kernel dispatch.
+at::Tensor dispatch_fp8_rowwise_kernel(
+    at::Tensor XQ,
+    at::Tensor WQ,
+    at::Tensor x_scale,
+    at::Tensor w_scale,
+    bool use_fast_accum,
+    std::optional<at::Tensor> bias = std::nullopt,
+    std::optional<at::Tensor> output = std::nullopt) {
+  int M = size_to_dim_(XQ.dim() - 1, XQ.sizes());
+  int N = size_to_dim_(WQ.dim() - 1, WQ.sizes());
+  int K = XQ.size(-1);
+
+  static int arch = -1;
+  // Avoid expensive cudaGetDeviceProperties call.
+  if (arch < 0) {
+    cudaDeviceProp prop;
+    cudaGetDeviceProperties(&prop, 0);
+    if (prop.major >= 10) {
+      arch = 10;
+      int runtimeVersion;
+      C10_CUDA_CHECK(cudaRuntimeGetVersion(&runtimeVersion));
+      TORCH_CHECK(
+          runtimeVersion >= 12080,
+          "FP8 GEMM on sm100a or above requires cuda >= 12.8");
+    } else {
+      arch = 9;
+    }
+  }
+
+  // Select kernel to run via heuristics or tuning.
+  auto kernel = [&]() {
+    if (std::getenv("FBGEMM_AUTOTUNE_ENABLE")) {
+      return get_kernel_via_tuning(
+          arch,
+          M,
+          N,
+          K,
+          XQ,
+          WQ,
+          x_scale,
+          w_scale,
+          use_fast_accum,
+          bias,
+          output);
+    } else {
+      return get_kernel_via_heuristic(arch, M, N, K, use_fast_accum);
+    }
+  }();
+  // Invoke kernel
+  return kernel(XQ, WQ, x_scale, w_scale, use_fast_accum, bias, output);
+}
+
 void f8f8bf16_rowwise_out(
     at::Tensor XQ, // FP8
     at::Tensor WQ, // FP8
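Callers of dispatch_fp8_rowwise_kernel are unchanged by this refactor; the only new switch is the FBGEMM_AUTOTUNE_ENABLE environment variable checked with std::getenv above. A hedged sketch of flipping that switch from host code before the first dispatch (setenv is the POSIX call; exporting the variable in the launching shell works just as well, and enable_fbgemm_fp8_autotune is a hypothetical helper, not part of this diff):

#include <cstdlib>

// Sketch only: the dispatch above checks std::getenv("FBGEMM_AUTOTUNE_ENABLE")
// for a non-null result, so any value works once the variable is set.
void enable_fbgemm_fp8_autotune() {
  setenv("FBGEMM_AUTOTUNE_ENABLE", "1", /*overwrite=*/1);
}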

fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f8f8bf16_rowwise/f8f8bf16_rowwise_manifest.cuh

Lines changed: 23 additions & 0 deletions
@@ -135,4 +135,27 @@ at::Tensor f8f8bf16_rowwise_128_256_128_2_1_1_10_f_f(
     bool use_fast_accum = true,
     std::optional<at::Tensor> bias = std::nullopt,
     std::optional<at::Tensor> output = std::nullopt);
+
+using Kernel_f8f8bf16_rowwise = at::Tensor (*)(
+    at::Tensor,
+    at::Tensor,
+    at::Tensor,
+    at::Tensor,
+    bool,
+    std::optional<at::Tensor>,
+    std::optional<at::Tensor>);
+
+inline const std::unordered_map<std::string, Kernel_f8f8bf16_rowwise>&
+get_f8f8bf16_rowwise_kernels(int arch) {
+  static const std::unordered_map<std::string, Kernel_f8f8bf16_rowwise>
+      kernelsSM90 = {};
+  static const std::unordered_map<std::string, Kernel_f8f8bf16_rowwise>
+      kernelsSM100 = {};
+  if (arch == 10) {
+    return kernelsSM100;
+  } else {
+    return kernelsSM90;
+  }
+}
+
 } // namespace fbgemm_gpu
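The kernel maps above ship empty in this commit; populating them is deferred to D75820688 because of the compile-time blowup mentioned in the summary. For anyone compiling the extra kernels locally, onboarding amounts to adding name-to-function-pointer entries to these maps. A hypothetical sketch using two SM90 instantiations already declared in this manifest (the string keys simply mirroring the kernel names is an assumption here, not a convention fixed by this diff):

static const std::unordered_map<std::string, Kernel_f8f8bf16_rowwise>
    kernelsSM90 = {
        // Keyed by kernel name; the function pointers are existing SM90
        // instantiations declared earlier in this manifest.
        {"f8f8bf16_rowwise_64_128_128_1_1_1_9_f_f",
         f8f8bf16_rowwise_64_128_128_1_1_1_9_f_f},
        {"f8f8bf16_rowwise_128_256_128_2_1_1_9_f_t",
         f8f8bf16_rowwise_128_256_128_2_1_1_9_f_t},
};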
