|
10 | 10 | #include <ATen/cuda/CUDAContext.h>
|
11 | 11 | #include <c10/cuda/CUDAGuard.h>
|
12 | 12 | // clang-format on
|
| 13 | +#include <fmt/core.h> |
13 | 14 |
|
14 | 15 | #include "f8f8bf16_rowwise/f8f8bf16_rowwise_manifest.cuh"
|
| 16 | +#include "fbgemm_gpu/quantize/common/tuning_cache.hpp" |
| 17 | +#include "fbgemm_gpu/quantize/common/utils.h" |
15 | 18 |
|
16 | 19 | namespace fbgemm_gpu {
|
17 | 20 |
|
18 | 21 | #if CUDART_VERSION >= 12000
|
19 | 22 |
|
20 | 23 | // FP8 Rowwise Cutlass kernel dispatch.
|
// Pick an FP8 rowwise kernel instantiation from static shape heuristics.
//
// `arch` is 10 for sm100-class devices and 9 otherwise; M/N/K are the GEMM
// dimensions. Thresholds below encode empirically tuned tile choices; the
// trailing suffixes of each kernel name encode tile/cluster config and
// pingpong/fast-accum variants. Returns a callable kernel, never runs it.
Kernel_f8f8bf16_rowwise
get_kernel_via_heuristic(int arch, int M, int N, int K, bool use_fast_accum) {
  // Use shape heuristics to dispatch to an optimized kernel configuration.
  if (arch == 10) {
    if (M <= 128) {
      if (N <= 1024) {
        return f8f8bf16_rowwise_128_32_128_1_1_1_10_f_f;
      } else {
        return f8f8bf16_rowwise_128_64_128_1_1_1_10_f_f;
      }
    } else if (M <= 1024) {
      if (N <= 1024) {
        return f8f8bf16_rowwise_128_256_128_2_1_1_10_f_f;
      } else {
        return f8f8bf16_rowwise_128_128_128_2_2_1_10_f_f;
      }
    } else if (M <= 2048) {
      return f8f8bf16_rowwise_128_256_128_2_1_1_10_f_f;
    } else {
      if (N <= 1024) {
        return f8f8bf16_rowwise_128_256_128_1_2_1_10_f_f;
      } else {
        return f8f8bf16_rowwise_128_256_128_2_1_1_10_f_f;
      }
    }
  } else {
    // sm90 (Hopper) heuristics.
    if (M <= 16) {
      return f8f8bf16_rowwise_64_16_128_1_1_1_9_f_f;
    } else if (M <= 32) {
      if (N <= 4096) {
        return f8f8bf16_rowwise_64_16_128_1_1_1_9_f_f;
      } else {
        return f8f8bf16_rowwise_64_32_128_2_1_1_9_f_f;
      }
    } else if (M <= 64) {
      if (N <= 2048) {
        return f8f8bf16_rowwise_64_16_128_1_1_1_9_f_f;
      } else if (N <= 4096) {
        return f8f8bf16_rowwise_64_32_128_2_1_1_9_f_f;
      } else {
        return f8f8bf16_rowwise_64_64_128_2_1_1_9_f_f;
      }
    } else if (M <= 128) {
      if (N <= 1024) {
        return f8f8bf16_rowwise_64_16_128_1_1_1_9_f_f;
      } else if (N <= 2048) {
        return f8f8bf16_rowwise_64_32_128_2_1_1_9_f_f;
      } else if (N <= 4096) {
        return f8f8bf16_rowwise_64_64_128_2_1_1_9_f_f;
      } else {
        return f8f8bf16_rowwise_64_128_128_1_1_1_9_f_f;
      }
    } else if (M <= 256) {
      if (N <= 1024) {
        return f8f8bf16_rowwise_64_32_128_2_1_1_9_f_f;
      } else if (N <= 2048) {
        return f8f8bf16_rowwise_64_64_128_2_1_1_9_f_f;
      } else if (N <= 4096) {
        return f8f8bf16_rowwise_64_128_128_1_1_1_9_f_f;
      } else {
        return f8f8bf16_rowwise_64_256_128_1_1_1_9_f_f;
      }
    } else if (M <= 512) {
      if (N <= 1024) {
        return f8f8bf16_rowwise_64_64_128_2_1_1_9_f_f;
      } else if (N <= 2048) {
        return f8f8bf16_rowwise_64_128_128_1_1_1_9_f_f;
      } else if (N <= 4096 || use_fast_accum == false) {
        // Slow-accumulation path never uses the fast-accum tile below.
        return f8f8bf16_rowwise_64_256_128_1_1_1_9_f_f;
      } else {
        return f8f8bf16_rowwise_128_256_128_2_1_1_9_f_t;
      }
    } else if (M <= 1024) {
      if (N <= 1024) {
        return f8f8bf16_rowwise_64_128_128_1_1_1_9_f_f;
      } else if (N <= 2048 || use_fast_accum == false) {
        return f8f8bf16_rowwise_64_256_128_1_1_1_9_f_f;
      } else {
        return f8f8bf16_rowwise_128_256_128_2_1_1_9_f_t;
      }
    } else {
      // Large-M regime: K and very-large-shape special cases first.
      if (M <= 2048 && N <= 1024) {
        return f8f8bf16_rowwise_64_256_128_2_1_1_9_f_f;
      } else if (K <= 4096 || use_fast_accum == false) {
        return f8f8bf16_rowwise_128_128_128_2_1_1_9_t_f;
      } else if (M > 8192 && N > 8192) {
        return f8f8bf16_rowwise_128_256_128_4_4_1_9_f_t;
      } else {
        return f8f8bf16_rowwise_128_256_128_2_1_1_9_f_t;
      }
    }
  }
}
|
172 | 117 |
|
// Look up (or measure) the fastest FP8 rowwise kernel for this problem shape.
//
// Results are memoized in a process-wide TuningCache keyed by "M_N_K", with M
// rounded up to the next power of two to bound the number of distinct tuning
// runs. On a cache miss, every candidate kernel for `arch` is benchmarked on
// the actual tensors by findBestKernelMaybeAutotune.
Kernel_f8f8bf16_rowwise get_kernel_via_tuning(
    int arch,
    int M,
    int N,
    int K,
    at::Tensor XQ,
    at::Tensor WQ,
    at::Tensor x_scale,
    at::Tensor w_scale,
    bool use_fast_accum,
    std::optional<at::Tensor> bias = std::nullopt,
    std::optional<at::Tensor> output = std::nullopt) {
  // One cache per kernel type.
  static TuningCache cache("f8f8bf16_rowwise");

  // Reduce the amount of auto tuning by rounding up M to the next power of 2.
  M = nextPowerOf2(M);
  // Use the (M, N, K) shape as the cache key.
  const std::string shape_key = fmt::format("{}_{}_{}", M, N, K);
  const auto& candidates = get_f8f8bf16_rowwise_kernels(arch);

  return cache.findBestKernelMaybeAutotune(
      shape_key,
      candidates,
      XQ,
      WQ,
      x_scale,
      w_scale,
      use_fast_accum,
      bias,
      output);
}
| 151 | + |
// FP8 Rowwise Cutlass kernel dispatch.
//
// Flattens XQ/WQ leading dims into an (M, N, K) GEMM, selects a kernel either
// via shape heuristics or — when the FBGEMM_AUTOTUNE_ENABLE env var is set —
// via runtime autotuning, then invokes it and returns the BF16 result.
at::Tensor dispatch_fp8_rowwise_kernel(
    at::Tensor XQ,
    at::Tensor WQ,
    at::Tensor x_scale,
    at::Tensor w_scale,
    bool use_fast_accum,
    std::optional<at::Tensor> bias = std::nullopt,
    std::optional<at::Tensor> output = std::nullopt) {
  // All leading dims of the activations/weights are folded into M/N.
  int M = size_to_dim_(XQ.dim() - 1, XQ.sizes());
  int N = size_to_dim_(WQ.dim() - 1, WQ.sizes());
  int K = XQ.size(-1);

  // Detect the SM architecture exactly once: function-local static
  // initialization is thread-safe, unlike the previous check-then-write on a
  // mutable static (a data race under concurrent dispatch). This also keeps
  // the expensive cudaGetDeviceProperties call off the steady-state path, and
  // re-raises the CUDA-version error on every call instead of caching a bad
  // arch after the first failure.
  // NOTE(review): queries device 0 unconditionally — assumes a homogeneous
  // multi-GPU setup; confirm if heterogeneous systems must be supported.
  static const int arch = []() -> int {
    cudaDeviceProp prop;
    C10_CUDA_CHECK(cudaGetDeviceProperties(&prop, 0));
    if (prop.major >= 10) {
      int runtimeVersion;
      C10_CUDA_CHECK(cudaRuntimeGetVersion(&runtimeVersion));
      TORCH_CHECK(
          runtimeVersion >= 12080,
          "FP8 GEMM on sm100a or above requires cuda >= 12.8");
      return 10;
    }
    return 9;
  }();

  // Select the kernel to run via heuristics or tuning.
  auto kernel = [&]() {
    if (std::getenv("FBGEMM_AUTOTUNE_ENABLE")) {
      return get_kernel_via_tuning(
          arch,
          M,
          N,
          K,
          XQ,
          WQ,
          x_scale,
          w_scale,
          use_fast_accum,
          bias,
          output);
    } else {
      return get_kernel_via_heuristic(arch, M, N, K, use_fast_accum);
    }
  }();

  // Invoke the selected kernel.
  return kernel(XQ, WQ, x_scale, w_scale, use_fast_accum, bias, output);
}
| 204 | + |
173 | 205 | void f8f8bf16_rowwise_out(
|
174 | 206 | at::Tensor XQ, // FP8
|
175 | 207 | at::Tensor WQ, // FP8
|
|
0 commit comments