
Commit e355bec

cthi authored and facebook-github-bot committed
Remove e5m2 from f8f8bf16_rowwise (#4907)
Summary:
Pull Request resolved: #4907
X-link: facebookresearch/FBGEMM#1931

- e5m2 is not used; remove it to save code size and avoid further build issues.
- e5m2 may have been broken to begin with: `using ElementB = cutlass::float_e4m3_t` was hardcoded.

Reviewed By: jiawenliu64, jwfromm

Differential Revision: D82965586

fbshipit-source-id: 61e61bd316bc0664cdbca8aa761e44f1232d0ee2
1 parent 8ec3635 commit e355bec
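
For callers, the practical effect is that the rowwise path now only accepts float8_e4m3fn activations: an e5m2 tensor fails the dtype check added in dispatch_fp8_rowwise_kernel instead of selecting a (now removed) e5m2 template instantiation. A minimal sketch of that contract follows; the helper name and error message here are illustrative, while the actual check added in the diff is a bare TORCH_CHECK.

#include <ATen/ATen.h>
#include <c10/util/Exception.h>

// Illustrative guard mirroring the check added to dispatch_fp8_rowwise_kernel:
// only e4m3 FP8 activations are accepted after this change.
void check_rowwise_fp8_input(const at::Tensor& XQ) {
  TORCH_CHECK(
      XQ.dtype() == at::kFloat8_e4m3fn,
      "f8f8bf16_rowwise: e5m2 inputs are no longer supported; quantize to e4m3");
}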

File tree

3 files changed: +59 −154 lines changed

fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f8f8bf16_rowwise.cu

Lines changed: 9 additions & 31 deletions
@@ -14,15 +14,18 @@
 #include "f8f8bf16_rowwise/f8f8bf16_rowwise_manifest.cuh"
 #include "fbgemm_gpu/quantize/tuning_cache.hpp"
 #include "fbgemm_gpu/quantize/utils.h"
+#include "fbgemm_gpu/quantize/utils_gpu.h"

 namespace fbgemm_gpu {

 #if CUDART_VERSION >= 12000

 // FP8 Rowwise Cutlass kernel dispatch.
 Kernel_f8f8bf16_rowwise
-get_kernel_via_heuristic(int arch, int M, int N, int K, bool use_fast_accum) {
+get_kernel_via_heuristic(int M, int N, int K, bool use_fast_accum) {
   // Use shape heuristics to dispatch to optimized kernel configuration.
+  const int arch = getDeviceArch();
+
   if (arch == 10) {
     if (M <= 128) {
       if (N <= 1024) {
@@ -115,7 +118,6 @@ get_kernel_via_heuristic(int arch, int M, int N, int K, bool use_fast_accum) {
 }

 Kernel_f8f8bf16_rowwise get_kernel_via_tuning(
-    int arch,
     int M,
     int N,
     int K,
@@ -134,6 +136,7 @@ Kernel_f8f8bf16_rowwise get_kernel_via_tuning(
   // Use (M, N, K) shape as the key.
   const std::string shape_key =
       std::to_string(M) + "_" + std::to_string(N) + "_" + std::to_string(K);
+  const int arch = getDeviceArch();
   const auto& kernels = get_f8f8bf16_rowwise_kernels(arch);
   auto kernel = cache.findBestKernelMaybeAutotune(
       shape_key,
@@ -158,44 +161,19 @@ at::Tensor dispatch_fp8_rowwise_kernel(
     bool use_fast_accum,
     std::optional<at::Tensor> bias = std::nullopt,
     std::optional<at::Tensor> output = std::nullopt) {
+  TORCH_CHECK(XQ.dtype() == at::kFloat8_e4m3fn);
+
   int M = size_to_dim_(XQ.dim() - 1, XQ.sizes());
   int N = size_to_dim_(WQ.dim() - 1, WQ.sizes());
   int K = XQ.size(-1);

-  static int arch = -1;
-  // Avoid expensive cudaGetDeviceProperties call.
-  if (arch < 0) {
-    cudaDeviceProp prop;
-    cudaGetDeviceProperties(&prop, 0);
-    if (prop.major >= 10) {
-      arch = 10;
-      int runtimeVersion;
-      C10_CUDA_CHECK(cudaRuntimeGetVersion(&runtimeVersion));
-      TORCH_CHECK(
-          runtimeVersion >= 12080,
-          "FP8 GEMM on sm100a or above requires cuda >= 12.8");
-    } else {
-      arch = 9;
-    }
-  }
-
   // Select kernel to run via heuristics or tuning.
   auto kernel = [&]() {
     if (std::getenv("FBGEMM_AUTOTUNE_ENABLE")) {
       return get_kernel_via_tuning(
-          arch,
-          M,
-          N,
-          K,
-          XQ,
-          WQ,
-          x_scale,
-          w_scale,
-          use_fast_accum,
-          bias,
-          output);
+          M, N, K, XQ, WQ, x_scale, w_scale, use_fast_accum, bias, output);
     } else {
-      return get_kernel_via_heuristic(arch, M, N, K, use_fast_accum);
+      return get_kernel_via_heuristic(M, N, K, use_fast_accum);
     }
   }();
   // Invoke kernel
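
The inline arch detection removed above now sits behind getDeviceArch() from fbgemm_gpu/quantize/utils_gpu.h, which get_kernel_via_heuristic and get_kernel_via_tuning now call directly. The helper's real implementation is not part of this diff; the sketch below is an assumption reconstructed from the deleted inline logic (cached detection of sm_90- vs. sm_100-class devices plus the CUDA >= 12.8 requirement), so the actual helper may differ.

#include <cuda_runtime.h>

#include <c10/cuda/CUDAException.h>
#include <c10/util/Exception.h>

// Hypothetical reconstruction of what getDeviceArch() presumably encapsulates,
// based on the inline code deleted in this hunk; not the actual helper.
inline int get_device_arch_sketch() {
  static int arch = -1;
  // Cache the result to avoid the expensive cudaGetDeviceProperties call
  // on every dispatch.
  if (arch < 0) {
    cudaDeviceProp prop;
    cudaGetDeviceProperties(&prop, 0);
    if (prop.major >= 10) {
      arch = 10;
      int runtimeVersion;
      C10_CUDA_CHECK(cudaRuntimeGetVersion(&runtimeVersion));
      TORCH_CHECK(
          runtimeVersion >= 12080,
          "FP8 GEMM on sm100a or above requires cuda >= 12.8");
    } else {
      arch = 9;
    }
  }
  return arch;
}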

fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f8f8bf16_rowwise/f8f8bf16_rowwise_common.cuh

Lines changed: 49 additions & 120 deletions
@@ -36,7 +36,6 @@ template <
     bool PONG,
     bool COOP,
     bool FAST_ACCUM,
-    typename INPUT_DTYPE,
     typename BIAS_DTYPE>
 at::Tensor f8f8bf16_rowwise_impl(
     at::Tensor XQ, // FP8
@@ -76,7 +75,7 @@ at::Tensor f8f8bf16_rowwise_impl(
     Y = at::empty(out_sizes, XQ.options().dtype(at::kBFloat16));
   }

-  using ElementA = INPUT_DTYPE;
+  using ElementA = cutlass::float_e4m3_t;
   using LayoutA = cutlass::layout::RowMajor;
   constexpr int AlignmentInputA = 16 / sizeof(ElementA);

@@ -351,131 +350,61 @@ at::Tensor f8f8bf16_rowwise_wrapper(
   bool bf16_bias = bias.has_value() && bias.value().dtype() == at::kBFloat16;

   // Templatize based on input dtype.
-  bool use_e5m2 = XQ.dtype() == at::kFloat8_e5m2;
-
   if (bf16_bias) {
     if (use_fast_accum) {
-      if (use_e5m2) {
-        return f8f8bf16_rowwise_impl<
-            TB_M,
-            TB_N,
-            TB_K,
-            TBS_M,
-            TBS_N,
-            TBS_K,
-            ARCH,
-            PONG,
-            COOP,
-            true,
-            cutlass::float_e5m2_t,
-            cutlass::bfloat16_t>(XQ, WQ, x_scale, w_scale, bias, output);
-      } else {
-        return f8f8bf16_rowwise_impl<
-            TB_M,
-            TB_N,
-            TB_K,
-            TBS_M,
-            TBS_N,
-            TBS_K,
-            ARCH,
-            PONG,
-            COOP,
-            true,
-            cutlass::float_e4m3_t,
-            cutlass::bfloat16_t>(XQ, WQ, x_scale, w_scale, bias, output);
-      }
+      return f8f8bf16_rowwise_impl<
+          TB_M,
+          TB_N,
+          TB_K,
+          TBS_M,
+          TBS_N,
+          TBS_K,
+          ARCH,
+          PONG,
+          COOP,
+          true,
+          cutlass::bfloat16_t>(XQ, WQ, x_scale, w_scale, bias, output);
     } else {
-      if (use_e5m2) {
-        return f8f8bf16_rowwise_impl<
-            TB_M,
-            TB_N,
-            TB_K,
-            TBS_M,
-            TBS_N,
-            TBS_K,
-            ARCH,
-            PONG,
-            COOP,
-            false,
-            cutlass::float_e5m2_t,
-            cutlass::bfloat16_t>(XQ, WQ, x_scale, w_scale, bias, output);
-      } else {
-        return f8f8bf16_rowwise_impl<
-            TB_M,
-            TB_N,
-            TB_K,
-            TBS_M,
-            TBS_N,
-            TBS_K,
-            ARCH,
-            PONG,
-            COOP,
-            false,
-            cutlass::float_e4m3_t,
-            cutlass::bfloat16_t>(XQ, WQ, x_scale, w_scale, bias, output);
-      }
+      return f8f8bf16_rowwise_impl<
+          TB_M,
+          TB_N,
+          TB_K,
+          TBS_M,
+          TBS_N,
+          TBS_K,
+          ARCH,
+          PONG,
+          COOP,
+          false,
+          cutlass::bfloat16_t>(XQ, WQ, x_scale, w_scale, bias, output);
     }
   } else {
     if (use_fast_accum) {
-      if (use_e5m2) {
-        return f8f8bf16_rowwise_impl<
-            TB_M,
-            TB_N,
-            TB_K,
-            TBS_M,
-            TBS_N,
-            TBS_K,
-            ARCH,
-            PONG,
-            COOP,
-            true,
-            cutlass::float_e5m2_t,
-            float>(XQ, WQ, x_scale, w_scale, bias, output);
-      } else {
-        return f8f8bf16_rowwise_impl<
-            TB_M,
-            TB_N,
-            TB_K,
-            TBS_M,
-            TBS_N,
-            TBS_K,
-            ARCH,
-            PONG,
-            COOP,
-            true,
-            cutlass::float_e4m3_t,
-            float>(XQ, WQ, x_scale, w_scale, bias, output);
-      }
+      return f8f8bf16_rowwise_impl<
+          TB_M,
+          TB_N,
+          TB_K,
+          TBS_M,
+          TBS_N,
+          TBS_K,
+          ARCH,
+          PONG,
+          COOP,
+          true,
+          float>(XQ, WQ, x_scale, w_scale, bias, output);
     } else {
-      if (use_e5m2) {
-        return f8f8bf16_rowwise_impl<
-            TB_M,
-            TB_N,
-            TB_K,
-            TBS_M,
-            TBS_N,
-            TBS_K,
-            ARCH,
-            PONG,
-            COOP,
-            false,
-            cutlass::float_e5m2_t,
-            float>(XQ, WQ, x_scale, w_scale, bias, output);
-      } else {
-        return f8f8bf16_rowwise_impl<
-            TB_M,
-            TB_N,
-            TB_K,
-            TBS_M,
-            TBS_N,
-            TBS_K,
-            ARCH,
-            PONG,
-            COOP,
-            false,
-            cutlass::float_e4m3_t,
-            float>(XQ, WQ, x_scale, w_scale, bias, output);
-      }
+      return f8f8bf16_rowwise_impl<
+          TB_M,
+          TB_N,
+          TB_K,
+          TBS_M,
+          TBS_N,
+          TBS_K,
+          ARCH,
+          PONG,
+          COOP,
+          false,
+          float>(XQ, WQ, x_scale, w_scale, bias, output);
     }
   }
 }

fbgemm_gpu/experimental/gen_ai/test/quantize/quantize_test.py

Lines changed: 1 addition & 3 deletions
@@ -349,9 +349,7 @@ def test_f8f8bf16(self, kernel: str, use_fast_accum: bool) -> None:
             ["rowwise", "blockwise"]
             + (["tensorwise_broadcast", "tensorwise"] if torch.version.cuda else [])
         ),
-        QType=(
-            st.sampled_from([fp8_e4m3, fp8_e5m2] if torch.version.cuda else [fp8_e4m3])
-        ),
+        QType=(st.sampled_from([fp8_e4m3])),
         Bias=st.sampled_from([True, False]),
         CudaGraph=st.sampled_from([True, False]),
         UseTriton=st.sampled_from([False] + ([True] if torch.version.cuda else [])),
