 // clang-format on

 #include "f8f8bf16_rowwise/f8f8bf16_rowwise_manifest.cuh"
+#include "fbgemm_gpu/quantize/tuning_cache.hpp"
+#include "fbgemm_gpu/quantize/utils.h"

 namespace fbgemm_gpu {

 #if CUDART_VERSION >= 12000

 // FP8 Rowwise Cutlass kernel dispatch.
-at::Tensor dispatch_fp8_rowwise_kernel(
-    at::Tensor XQ,
-    at::Tensor WQ,
-    at::Tensor x_scale,
-    at::Tensor w_scale,
-    bool use_fast_accum,
-    std::optional<at::Tensor> bias = std::nullopt,
-    std::optional<at::Tensor> output = std::nullopt) {
-  int M = size_to_dim_(XQ.dim() - 1, XQ.sizes());
-  int N = size_to_dim_(WQ.dim() - 1, WQ.sizes());
-  int K = XQ.size(-1);
-  static int arch = -1;
-  // Avoid expensive cudaGetDeviceProperties call.
-  if (arch < 0) {
-    cudaDeviceProp prop;
-    cudaGetDeviceProperties(&prop, 0);
-    if (prop.major >= 10) {
-      arch = 10;
-      int runtimeVersion;
-      C10_CUDA_CHECK(cudaRuntimeGetVersion(&runtimeVersion));
-      TORCH_CHECK(
-          runtimeVersion >= 12080,
-          "FP8 GEMM on sm100a or above requires cuda >= 12.8");
-    } else {
-      arch = 9;
-    }
-  }
-
+Kernel_f8f8bf16_rowwise
+get_kernel_via_heuristic(int arch, int M, int N, int K, bool use_fast_accum) {
   // Use shape heuristics to dispatch to optimized kernel configuration.
   if (arch == 10) {
     if (M <= 128) {
       if (N <= 1024) {
-        return f8f8bf16_rowwise_128_32_128_1_1_1_10_f_f(
-            XQ, WQ, x_scale, w_scale, use_fast_accum, bias, output);
+        return f8f8bf16_rowwise_128_32_128_1_1_1_10_f_f;
       } else {
-        return f8f8bf16_rowwise_128_64_128_1_1_1_10_f_f(
-            XQ, WQ, x_scale, w_scale, use_fast_accum, bias, output);
+        return f8f8bf16_rowwise_128_64_128_1_1_1_10_f_f;
       }
     } else if (M <= 1024) {
       if (N <= 1024) {
-        return f8f8bf16_rowwise_128_256_128_2_1_1_10_f_f(
-            XQ, WQ, x_scale, w_scale, use_fast_accum, bias, output);
+        return f8f8bf16_rowwise_128_256_128_2_1_1_10_f_f;
       } else {
-        return f8f8bf16_rowwise_128_128_128_2_2_1_10_f_f(
-            XQ, WQ, x_scale, w_scale, use_fast_accum, bias, output);
+        return f8f8bf16_rowwise_128_128_128_2_2_1_10_f_f;
       }
     } else if (M <= 2048) {
-      return f8f8bf16_rowwise_128_256_128_2_1_1_10_f_f(
-          XQ, WQ, x_scale, w_scale, use_fast_accum, bias, output);
+      return f8f8bf16_rowwise_128_256_128_2_1_1_10_f_f;
     } else {
       if (N <= 1024) {
-        return f8f8bf16_rowwise_128_256_128_1_2_1_10_f_f(
-            XQ, WQ, x_scale, w_scale, use_fast_accum, bias, output);
+        return f8f8bf16_rowwise_128_256_128_1_2_1_10_f_f;
       } else {
-        return f8f8bf16_rowwise_128_256_128_2_1_1_10_f_f(
-            XQ, WQ, x_scale, w_scale, use_fast_accum, bias, output);
+        return f8f8bf16_rowwise_128_256_128_2_1_1_10_f_f;
       }
     }
   } else {
     if (M <= 16) {
-      return f8f8bf16_rowwise_64_16_128_1_1_1_9_f_f(
-          XQ, WQ, x_scale, w_scale, use_fast_accum, bias, output);
+      return f8f8bf16_rowwise_64_16_128_1_1_1_9_f_f;
     } else if (M <= 32) {
       if (N <= 4096) {
-        return f8f8bf16_rowwise_64_16_128_1_1_1_9_f_f(
-            XQ, WQ, x_scale, w_scale, use_fast_accum, bias, output);
+        return f8f8bf16_rowwise_64_16_128_1_1_1_9_f_f;
       } else {
-        return f8f8bf16_rowwise_64_32_128_2_1_1_9_f_f(
-            XQ, WQ, x_scale, w_scale, use_fast_accum, bias, output);
+        return f8f8bf16_rowwise_64_32_128_2_1_1_9_f_f;
       }
     } else if (M <= 64) {
       if (N <= 2048) {
-        return f8f8bf16_rowwise_64_16_128_1_1_1_9_f_f(
-            XQ, WQ, x_scale, w_scale, use_fast_accum, bias, output);
+        return f8f8bf16_rowwise_64_16_128_1_1_1_9_f_f;
       } else if (N <= 4096) {
-        return f8f8bf16_rowwise_64_32_128_2_1_1_9_f_f(
-            XQ, WQ, x_scale, w_scale, use_fast_accum, bias, output);
+        return f8f8bf16_rowwise_64_32_128_2_1_1_9_f_f;
       } else {
-        return f8f8bf16_rowwise_64_64_128_2_1_1_9_f_f(
-            XQ, WQ, x_scale, w_scale, use_fast_accum, bias, output);
+        return f8f8bf16_rowwise_64_64_128_2_1_1_9_f_f;
       }
     } else if (M <= 128) {
       if (N <= 1024) {
-        return f8f8bf16_rowwise_64_16_128_1_1_1_9_f_f(
-            XQ, WQ, x_scale, w_scale, use_fast_accum, bias, output);
+        return f8f8bf16_rowwise_64_16_128_1_1_1_9_f_f;
       } else if (N <= 2048) {
-        return f8f8bf16_rowwise_64_32_128_2_1_1_9_f_f(
-            XQ, WQ, x_scale, w_scale, use_fast_accum, bias, output);
+        return f8f8bf16_rowwise_64_32_128_2_1_1_9_f_f;
       } else if (N <= 4096) {
-        return f8f8bf16_rowwise_64_64_128_2_1_1_9_f_f(
-            XQ, WQ, x_scale, w_scale, use_fast_accum, bias, output);
+        return f8f8bf16_rowwise_64_64_128_2_1_1_9_f_f;
       } else {
-        return f8f8bf16_rowwise_64_128_128_1_1_1_9_f_f(
-            XQ, WQ, x_scale, w_scale, use_fast_accum, bias, output);
+        return f8f8bf16_rowwise_64_128_128_1_1_1_9_f_f;
       }
     } else if (M <= 256) {
       if (N <= 1024) {
-        return f8f8bf16_rowwise_64_32_128_2_1_1_9_f_f(
-            XQ, WQ, x_scale, w_scale, use_fast_accum, bias, output);
+        return f8f8bf16_rowwise_64_32_128_2_1_1_9_f_f;
       } else if (N <= 2048) {
-        return f8f8bf16_rowwise_64_64_128_2_1_1_9_f_f(
-            XQ, WQ, x_scale, w_scale, use_fast_accum, bias, output);
+        return f8f8bf16_rowwise_64_64_128_2_1_1_9_f_f;
       } else if (N <= 4096) {
-        return f8f8bf16_rowwise_64_128_128_1_1_1_9_f_f(
-            XQ, WQ, x_scale, w_scale, use_fast_accum, bias, output);
+        return f8f8bf16_rowwise_64_128_128_1_1_1_9_f_f;
       } else {
-        return f8f8bf16_rowwise_64_256_128_1_1_1_9_f_f(
-            XQ, WQ, x_scale, w_scale, use_fast_accum, bias, output);
+        return f8f8bf16_rowwise_64_256_128_1_1_1_9_f_f;
       }
     } else if (M <= 512) {
       if (N <= 1024) {
-        return f8f8bf16_rowwise_64_64_128_2_1_1_9_f_f(
-            XQ, WQ, x_scale, w_scale, use_fast_accum, bias, output);
+        return f8f8bf16_rowwise_64_64_128_2_1_1_9_f_f;
       } else if (N <= 2048) {
-        return f8f8bf16_rowwise_64_128_128_1_1_1_9_f_f(
-            XQ, WQ, x_scale, w_scale, use_fast_accum, bias, output);
+        return f8f8bf16_rowwise_64_128_128_1_1_1_9_f_f;
       } else if (N <= 4096 || use_fast_accum == false) {
-        return f8f8bf16_rowwise_64_256_128_1_1_1_9_f_f(
-            XQ, WQ, x_scale, w_scale, use_fast_accum, bias, output);
+        return f8f8bf16_rowwise_64_256_128_1_1_1_9_f_f;
       } else {
-        return f8f8bf16_rowwise_128_256_128_2_1_1_9_f_t(
-            XQ, WQ, x_scale, w_scale, use_fast_accum, bias, output);
+        return f8f8bf16_rowwise_128_256_128_2_1_1_9_f_t;
       }
     } else if (M <= 1024) {
       if (N <= 1024) {
-        return f8f8bf16_rowwise_64_128_128_1_1_1_9_f_f(
-            XQ, WQ, x_scale, w_scale, use_fast_accum, bias, output);
+        return f8f8bf16_rowwise_64_128_128_1_1_1_9_f_f;
       } else if (N <= 2048 || use_fast_accum == false) {
-        return f8f8bf16_rowwise_64_256_128_1_1_1_9_f_f(
-            XQ, WQ, x_scale, w_scale, use_fast_accum, bias, output);
+        return f8f8bf16_rowwise_64_256_128_1_1_1_9_f_f;
       } else {
-        return f8f8bf16_rowwise_128_256_128_2_1_1_9_f_t(
-            XQ, WQ, x_scale, w_scale, use_fast_accum, bias, output);
+        return f8f8bf16_rowwise_128_256_128_2_1_1_9_f_t;
       }
     } else {
       if (M <= 2048 && N <= 1024) {
-        return f8f8bf16_rowwise_64_256_128_2_1_1_9_f_f(
-            XQ, WQ, x_scale, w_scale, use_fast_accum, bias, output);
+        return f8f8bf16_rowwise_64_256_128_2_1_1_9_f_f;
       } else if (K <= 4096 || use_fast_accum == false) {
-        return f8f8bf16_rowwise_128_128_128_2_1_1_9_t_f(
-            XQ, WQ, x_scale, w_scale, use_fast_accum, bias, output);
+        return f8f8bf16_rowwise_128_128_128_2_1_1_9_t_f;
       } else if (M > 8192 && N > 8192) {
-        return f8f8bf16_rowwise_128_256_128_4_4_1_9_f_t(
-            XQ, WQ, x_scale, w_scale, use_fast_accum, bias, output);
+        return f8f8bf16_rowwise_128_256_128_4_4_1_9_f_t;
       } else {
-        return f8f8bf16_rowwise_128_256_128_2_1_1_9_f_t(
-            XQ, WQ, x_scale, w_scale, use_fast_accum, bias, output);
+        return f8f8bf16_rowwise_128_256_128_2_1_1_9_f_t;
       }
     }
   }
 }

+Kernel_f8f8bf16_rowwise get_kernel_via_tuning(
+    int arch,
+    int M,
+    int N,
+    int K,
+    at::Tensor XQ,
+    at::Tensor WQ,
+    at::Tensor x_scale,
+    at::Tensor w_scale,
+    bool use_fast_accum,
+    std::optional<at::Tensor> bias = std::nullopt,
+    std::optional<at::Tensor> output = std::nullopt) {
+  // One cache per kernel type.
+  static TuningCache cache("f8f8bf16_rowwise");
+
+  // Reduce the amount of auto-tuning by rounding up M to the next power of 2.
+  M = nextPowerOf2(M);
+  // Use the (M, N, K) shape as the key.
+  const std::string shape_key =
+      std::to_string(M) + "_" + std::to_string(N) + "_" + std::to_string(K);
+  const auto& kernels = get_f8f8bf16_rowwise_kernels(arch);
+  auto kernel = cache.findBestKernelMaybeAutotune(
+      shape_key,
+      kernels,
+      XQ,
+      WQ,
+      x_scale,
+      w_scale,
+      use_fast_accum,
+      bias,
+      output);
+
+  return kernel;
+}
+
+// FP8 Rowwise Cutlass kernel dispatch.
+at::Tensor dispatch_fp8_rowwise_kernel(
+    at::Tensor XQ,
+    at::Tensor WQ,
+    at::Tensor x_scale,
+    at::Tensor w_scale,
+    bool use_fast_accum,
+    std::optional<at::Tensor> bias = std::nullopt,
+    std::optional<at::Tensor> output = std::nullopt) {
+  int M = size_to_dim_(XQ.dim() - 1, XQ.sizes());
+  int N = size_to_dim_(WQ.dim() - 1, WQ.sizes());
+  int K = XQ.size(-1);
+
+  static int arch = -1;
+  // Avoid the expensive cudaGetDeviceProperties call after the first query.
+  if (arch < 0) {
+    cudaDeviceProp prop;
+    cudaGetDeviceProperties(&prop, 0);
+    if (prop.major >= 10) {
+      arch = 10;
+      int runtimeVersion;
+      C10_CUDA_CHECK(cudaRuntimeGetVersion(&runtimeVersion));
+      TORCH_CHECK(
+          runtimeVersion >= 12080,
+          "FP8 GEMM on sm100a or above requires cuda >= 12.8");
+    } else {
+      arch = 9;
+    }
+  }
+
+  // Select the kernel to run via heuristics or tuning.
+  auto kernel = [&]() {
+    if (std::getenv("FBGEMM_AUTOTUNE_ENABLE")) {
+      return get_kernel_via_tuning(
+          arch,
+          M,
+          N,
+          K,
+          XQ,
+          WQ,
+          x_scale,
+          w_scale,
+          use_fast_accum,
+          bias,
+          output);
+    } else {
+      return get_kernel_via_heuristic(arch, M, N, K, use_fast_accum);
+    }
+  }();
+  // Invoke the selected kernel.
+  return kernel(XQ, WQ, x_scale, w_scale, use_fast_accum, bias, output);
+}
+
 void f8f8bf16_rowwise_out(
     at::Tensor XQ, // FP8
     at::Tensor WQ, // FP8