#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
+ #include <fmt/core.h>

#include "bf16bf16bf16_grouped/bf16bf16bf16_grouped_manifest.cuh"
+ #include "fbgemm_gpu/quantize/common/tuning_cache.hpp"
+ #include "fbgemm_gpu/quantize/common/utils.h"

namespace fbgemm_gpu {

#if CUDART_VERSION >= 12000

- // BF16 grouped cutlass kernel dispatch.
+ namespace {
+ TuningCache& getTuningCache() {
+   // This kernel has multiple APIs templated based on InputType, so we use this
+   // to have a single cache instance across APIs.
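+   // The function-local static below is initialized once, on first use, and that
+   // initialization is thread-safe under C++11, so concurrent callers share one cache.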
+   static TuningCache cache("bf16bf16bf16_grouped");
+   return cache;
+ }
+ } // namespace
+
template <typename InputType>
- at::Tensor dispatch_bf16_grouped_kernel(
-     int G,
-     int total_M,
-     int N,
-     int K,
-     InputType X, // BF16
-     InputType W, // BF16
-     at::Tensor output,
-     std::optional<at::Tensor> zero_start_index_M = std::nullopt,
-     std::optional<at::Tensor> M_sizes = std::nullopt) {
+ Kernel_bf16bf16bf16_grouped<InputType>
+ get_kernel_via_heuristic(int G, int total_M, int N, int K) {
  // Use heuristics to pick best kernel implementation.

  // Llama4 128E
  if (G == 128) {
    if (N == 5120 && K == 1024) {
      if (total_M <= 128) {
-       return bf16bf16bf16_grouped_128_16_128_2_1_1_f(
-           X, W, output, zero_start_index_M, M_sizes);
+       return bf16bf16bf16_grouped_128_16_128_2_1_1_f;
      } else if (total_M <= 256) {
-       return bf16bf16bf16_grouped_128_32_128_2_1_1_t(
-           X, W, output, zero_start_index_M, M_sizes);
+       return bf16bf16bf16_grouped_128_32_128_2_1_1_t;
      } else if (total_M <= 2048) {
-       return bf16bf16bf16_grouped_128_16_128_2_1_1_f(
-           X, W, output, zero_start_index_M, M_sizes);
+       return bf16bf16bf16_grouped_128_16_128_2_1_1_f;
      } else if (total_M <= 4096) {
-       return bf16bf16bf16_grouped_128_32_128_2_1_1_f(
-           X, W, output, zero_start_index_M, M_sizes);
+       return bf16bf16bf16_grouped_128_32_128_2_1_1_f;
      } else if (total_M <= 8192) {
-       return bf16bf16bf16_grouped_128_64_128_1_1_1_f(
-           X, W, output, zero_start_index_M, M_sizes);
+       return bf16bf16bf16_grouped_128_64_128_1_1_1_f;
      } else if (total_M <= 16384) {
-       return bf16bf16bf16_grouped_128_128_128_2_1_1_t(
-           X, W, output, zero_start_index_M, M_sizes);
+       return bf16bf16bf16_grouped_128_128_128_2_1_1_t;
      } else {
-       return bf16bf16bf16_grouped_128_256_128_2_1_1_f(
-           X, W, output, zero_start_index_M, M_sizes);
+       return bf16bf16bf16_grouped_128_256_128_2_1_1_f;
      }
    }

    if (N == 2048 && K == 5120) {
      if (total_M <= 2048) {
-       return bf16bf16bf16_grouped_128_16_128_2_1_1_f(
-           X, W, output, zero_start_index_M, M_sizes);
+       return bf16bf16bf16_grouped_128_16_128_2_1_1_f;
      } else {
-       return bf16bf16bf16_grouped_128_128_128_2_1_1_t(
-           X, W, output, zero_start_index_M, M_sizes);
+       return bf16bf16bf16_grouped_128_128_128_2_1_1_t;
      }
    }
  }
@@ -71,71 +65,102 @@ at::Tensor dispatch_bf16_grouped_kernel(
  if (G == 16) {
    if (N == 5120 && K == 1024) {
      if (total_M <= 32) {
-       return bf16bf16bf16_grouped_128_16_128_2_1_1_f(
-           X, W, output, zero_start_index_M, M_sizes);
+       return bf16bf16bf16_grouped_128_16_128_2_1_1_f;
      } else if (total_M <= 64) {
-       return bf16bf16bf16_grouped_128_32_128_2_1_1_t(
-           X, W, output, zero_start_index_M, M_sizes);
+       return bf16bf16bf16_grouped_128_32_128_2_1_1_t;
      } else if (total_M <= 256) {
-       return bf16bf16bf16_grouped_128_16_128_2_1_1_f(
-           X, W, output, zero_start_index_M, M_sizes);
+       return bf16bf16bf16_grouped_128_16_128_2_1_1_f;
      } else if (total_M <= 512) {
-       return bf16bf16bf16_grouped_128_32_128_2_1_1_t(
-           X, W, output, zero_start_index_M, M_sizes);
+       return bf16bf16bf16_grouped_128_32_128_2_1_1_t;
      } else if (total_M <= 1024) {
-       return bf16bf16bf16_grouped_128_64_128_2_1_1_t(
-           X, W, output, zero_start_index_M, M_sizes);
+       return bf16bf16bf16_grouped_128_64_128_2_1_1_t;
      } else {
-       return bf16bf16bf16_grouped_128_256_128_2_1_1_f(
-           X, W, output, zero_start_index_M, M_sizes);
+       return bf16bf16bf16_grouped_128_256_128_2_1_1_f;
      }
    }

    if (N == 2048 && K == 5120) {
      if (total_M <= 16) {
-       return bf16bf16bf16_grouped_128_16_128_2_1_1_f(
-           X, W, output, zero_start_index_M, M_sizes);
+       return bf16bf16bf16_grouped_128_16_128_2_1_1_f;
      } else if (total_M <= 64) {
-       return bf16bf16bf16_grouped_128_32_128_2_1_1_f(
-           X, W, output, zero_start_index_M, M_sizes);
+       return bf16bf16bf16_grouped_128_32_128_2_1_1_f;
      } else if (total_M <= 256) {
-       return bf16bf16bf16_grouped_128_16_128_2_1_1_f(
-           X, W, output, zero_start_index_M, M_sizes);
+       return bf16bf16bf16_grouped_128_16_128_2_1_1_f;
      } else if (total_M <= 512) {
-       return bf16bf16bf16_grouped_128_32_128_2_1_1_f(
-           X, W, output, zero_start_index_M, M_sizes);
+       return bf16bf16bf16_grouped_128_32_128_2_1_1_f;
      } else if (total_M <= 1024) {
-       return bf16bf16bf16_grouped_128_64_128_1_1_1_f(
-           X, W, output, zero_start_index_M, M_sizes);
+       return bf16bf16bf16_grouped_128_64_128_1_1_1_f;
      } else {
-       return bf16bf16bf16_grouped_128_128_128_2_1_1_t(
-           X, W, output, zero_start_index_M, M_sizes);
+       return bf16bf16bf16_grouped_128_128_128_2_1_1_t;
      }
    }
  }

  // Fallback to legacy heuristic for now.
  if (total_M <= 16) {
-   return bf16bf16bf16_grouped_128_16_128_1_1_1_f(
-       X, W, output, zero_start_index_M, M_sizes);
+   return bf16bf16bf16_grouped_128_16_128_1_1_1_f;
  } else if (total_M <= 32) {
-   return bf16bf16bf16_grouped_128_32_128_1_1_1_f(
-       X, W, output, zero_start_index_M, M_sizes);
+   return bf16bf16bf16_grouped_128_32_128_1_1_1_f;
  } else if (total_M <= 64) {
-   return bf16bf16bf16_grouped_128_64_128_1_1_1_f(
-       X, W, output, zero_start_index_M, M_sizes);
+   return bf16bf16bf16_grouped_128_64_128_1_1_1_f;
  } else if (total_M <= 128) {
-   return bf16bf16bf16_grouped_128_128_128_1_1_1_f(
-       X, W, output, zero_start_index_M, M_sizes);
+   return bf16bf16bf16_grouped_128_128_128_1_1_1_f;
  } else if (total_M <= 512) {
-   return bf16bf16bf16_grouped_256_128_128_2_1_1_f(
-       X, W, output, zero_start_index_M, M_sizes);
+   return bf16bf16bf16_grouped_256_128_128_2_1_1_f;
  } else {
-   return bf16bf16bf16_grouped_128_256_128_2_1_1_f(
-       X, W, output, zero_start_index_M, M_sizes);
+   return bf16bf16bf16_grouped_128_256_128_2_1_1_f;
  }
}

+ template <typename InputType>
+ Kernel_bf16bf16bf16_grouped<InputType> get_kernel_via_tuning(
+     int G,
+     int total_M,
+     int N,
+     int K,
+     InputType X, // BF16
+     InputType W, // BF16
+     at::Tensor output,
+     std::optional<at::Tensor> zero_start_index_M = std::nullopt,
+     std::optional<at::Tensor> M_sizes = std::nullopt) {
+   auto& cache = getTuningCache();
+
+   // Reduce the amount of auto tuning by rounding total_M up to the next power of 2.
+   total_M = nextPowerOf2(total_M);
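+   // e.g. a total_M of 1000 is bucketed to 1024, so nearby shapes share one
+   // cache entry and are only tuned once.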
+   // Use the (total_M, N, K, G) shape as the key.
+   const std::string shape_key = fmt::format("{}_{}_{}_{}", total_M, N, K, G);
+   const auto& kernels = get_bf16bf16bf16_grouped_kernels<InputType>();
+   auto kernel = cache.findBestKernelMaybeAutotune(
+       shape_key, kernels, X, W, output, zero_start_index_M, M_sizes);
+
+   return kernel;
+ }
+
+ // BF16 grouped cutlass kernel dispatch.
+ template <typename InputType>
+ at::Tensor dispatch_bf16_grouped_kernel(
+     int G,
+     int total_M,
+     int N,
+     int K,
+     InputType X, // BF16
+     InputType W, // BF16
+     at::Tensor output,
+     std::optional<at::Tensor> zero_start_index_M = std::nullopt,
+     std::optional<at::Tensor> M_sizes = std::nullopt) {
+   // Select kernel to run via heuristics or tuning.
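+   // Autotuning is opt-in: when FBGEMM_AUTOTUNE_ENABLE is set in the environment,
+   // the tuning cache is consulted (and may benchmark the candidate kernels for a
+   // shape it has not seen yet); otherwise the static heuristic table above is used.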
+   auto kernel = [&]() {
+     if (std::getenv("FBGEMM_AUTOTUNE_ENABLE")) {
+       return get_kernel_via_tuning(
+           G, total_M, N, K, X, W, output, zero_start_index_M, M_sizes);
+     } else {
+       return get_kernel_via_heuristic<InputType>(G, total_M, N, K);
+     }
+   }();
+   // Invoke kernel
+   return kernel(X, W, output, zero_start_index_M, M_sizes);
+ }
+
template <typename OutputType>
OutputType _bf16bf16bf16_grouped(at::TensorList X, at::TensorList W) {
  at::Tensor Y;