Skip to content

Commit 553b40e

Browse files
cthi and facebook-github-bot
authored and committed
FP4 grouped refactor
Summary: X-link: facebookresearch/FBGEMM#1957 Split some clean-up/refactors from the core FP4 Torch API support to make the next diff more focused. - Removed `zero_start_index_M` as it's unused - Removed passing `G` into the kernel directly as it can be inferred - Renamed `ElementComputeEpilogue` -> `ElementScale` - Added `namespace fbgemm_gpu` in `f4f4bf16_grouped_common.cuh` - Removed `num_x_scale_per_group` and `num_w_scale_per_group` as they are both unused - Removed unnecessary cutlass headers in `f4f4bf16_grouped.cu` Differential Revision: D83166227
1 parent cf2dc81 commit 553b40e

13 files changed

+47
-162
lines changed

fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f4f4bf16_grouped.cu

Lines changed: 0 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -8,16 +8,6 @@
88

99
#include <ATen/ATen.h>
1010
#include <ATen/cuda/CUDAContext.h>
11-
#include <cutlass/util/device_memory.h>
12-
#include <cutlass/util/packed_stride.hpp>
13-
14-
// clang-format off
15-
// The fixed ordering of the headers is required for CUTLASS 3.2+
16-
#include <cute/tensor.hpp>
17-
#include <cutlass/gemm/collective/collective_builder.hpp> // @manual
18-
#include <cutlass/gemm/device/gemm_universal_adapter.h> // @manual
19-
#include <cutlass/epilogue/collective/collective_builder.hpp> // @manual
20-
// clang-format on
2111

2212
#if defined(CUDA_VERSION) && (CUDA_VERSION >= 12080)
2313
#include "f4f4bf16_grouped/f4f4bf16_grouped_manifest.cuh"
@@ -160,14 +150,10 @@ at::Tensor dispatch_fp4_grouped_kernel(
160150
at::Tensor x_scale,
161151
at::Tensor w_scale,
162152
at::Tensor output,
163-
std::optional<at::Tensor> zero_start_index_M = std::nullopt,
164153
std::optional<at::Tensor> M_sizes = std::nullopt,
165154
std::optional<at::Tensor> global_scale = std::nullopt,
166155
std::optional<at::Tensor> starting_row_after_padding = std::nullopt,
167156
bool use_mx = true) {
168-
TORCH_CHECK(
169-
zero_start_index_M.has_value() != M_sizes.has_value(),
170-
"One of zero_start_index_M or M_sizes must be provided.");
171157
TORCH_CHECK(M_sizes.has_value(), "M_sizes is assumed to be provided.");
172158
TORCH_CHECK(
173159
starting_row_after_padding.has_value(),
@@ -187,8 +173,6 @@ at::Tensor dispatch_fp4_grouped_kernel(
187173
x_scale,
188174
w_scale,
189175
output,
190-
G,
191-
zero_start_index_M,
192176
M_sizes,
193177
global_scale,
194178
starting_row_after_padding);
@@ -228,7 +212,6 @@ at::Tensor f4f4bf16_grouped_stacked(
228212
x_scale,
229213
w_scale,
230214
Y,
231-
std::nullopt,
232215
M_sizes,
233216
global_scale,
234217
starting_row_after_padding,

fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f4f4bf16_grouped/f4f4bf16_grouped_128_128_256_1_1_1_f.cu

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,6 @@ at::Tensor f4f4bf16_grouped_128_128_256_1_1_1_f(
1818
at::Tensor x_scale,
1919
at::Tensor w_scale,
2020
at::Tensor output,
21-
int64_t G,
22-
std::optional<at::Tensor> zero_start_index_M,
2321
std::optional<at::Tensor> M_sizes,
2422
std::optional<at::Tensor> global_scale,
2523
std::optional<at::Tensor> starting_row_after_padding) {
@@ -36,8 +34,6 @@ at::Tensor f4f4bf16_grouped_128_128_256_1_1_1_f(
3634
x_scale,
3735
w_scale,
3836
output,
39-
G,
40-
zero_start_index_M,
4137
M_sizes,
4238
global_scale,
4339
starting_row_after_padding);

fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f4f4bf16_grouped/f4f4bf16_grouped_128_128_256_1_1_1_t.cu

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,6 @@ at::Tensor f4f4bf16_grouped_128_128_256_1_1_1_t(
1818
at::Tensor x_scale,
1919
at::Tensor w_scale,
2020
at::Tensor output,
21-
int64_t G,
22-
std::optional<at::Tensor> zero_start_index_M,
2321
std::optional<at::Tensor> M_sizes,
2422
std::optional<at::Tensor> global_scale,
2523
std::optional<at::Tensor> starting_row_after_padding) {
@@ -36,8 +34,6 @@ at::Tensor f4f4bf16_grouped_128_128_256_1_1_1_t(
3634
x_scale,
3735
w_scale,
3836
output,
39-
G,
40-
zero_start_index_M,
4137
M_sizes,
4238
global_scale,
4339
starting_row_after_padding);

fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f4f4bf16_grouped/f4f4bf16_grouped_128_64_256_1_1_1_f.cu

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,6 @@ at::Tensor f4f4bf16_grouped_128_64_256_1_1_1_f(
1818
at::Tensor x_scale,
1919
at::Tensor w_scale,
2020
at::Tensor output,
21-
int64_t G,
22-
std::optional<at::Tensor> zero_start_index_M,
2321
std::optional<at::Tensor> M_sizes,
2422
std::optional<at::Tensor> global_scale,
2523
std::optional<at::Tensor> starting_row_after_padding) {
@@ -36,8 +34,6 @@ at::Tensor f4f4bf16_grouped_128_64_256_1_1_1_f(
3634
x_scale,
3735
w_scale,
3836
output,
39-
G,
40-
zero_start_index_M,
4137
M_sizes,
4238
global_scale,
4339
starting_row_after_padding);

fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f4f4bf16_grouped/f4f4bf16_grouped_128_64_256_1_1_1_t.cu

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,6 @@ at::Tensor f4f4bf16_grouped_128_64_256_1_1_1_t(
1818
at::Tensor x_scale,
1919
at::Tensor w_scale,
2020
at::Tensor output,
21-
int64_t G,
22-
std::optional<at::Tensor> zero_start_index_M,
2321
std::optional<at::Tensor> M_sizes,
2422
std::optional<at::Tensor> global_scale,
2523
std::optional<at::Tensor> starting_row_after_padding) {
@@ -36,8 +34,6 @@ at::Tensor f4f4bf16_grouped_128_64_256_1_1_1_t(
3634
x_scale,
3735
w_scale,
3836
output,
39-
G,
40-
zero_start_index_M,
4137
M_sizes,
4238
global_scale,
4339
starting_row_after_padding);

fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f4f4bf16_grouped/f4f4bf16_grouped_256_128_256_2_1_1_f.cu

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,6 @@ at::Tensor f4f4bf16_grouped_256_128_256_2_1_1_f(
1818
at::Tensor x_scale,
1919
at::Tensor w_scale,
2020
at::Tensor output,
21-
int64_t G,
22-
std::optional<at::Tensor> zero_start_index_M,
2321
std::optional<at::Tensor> M_sizes,
2422
std::optional<at::Tensor> global_scale,
2523
std::optional<at::Tensor> starting_row_after_padding) {
@@ -36,8 +34,6 @@ at::Tensor f4f4bf16_grouped_256_128_256_2_1_1_f(
3634
x_scale,
3735
w_scale,
3836
output,
39-
G,
40-
zero_start_index_M,
4137
M_sizes,
4238
global_scale,
4339
starting_row_after_padding);

fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f4f4bf16_grouped/f4f4bf16_grouped_256_128_256_2_1_1_t.cu

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,6 @@ at::Tensor f4f4bf16_grouped_256_128_256_2_1_1_t(
1818
at::Tensor x_scale,
1919
at::Tensor w_scale,
2020
at::Tensor output,
21-
int64_t G,
22-
std::optional<at::Tensor> zero_start_index_M,
2321
std::optional<at::Tensor> M_sizes,
2422
std::optional<at::Tensor> global_scale,
2523
std::optional<at::Tensor> starting_row_after_padding) {
@@ -36,8 +34,6 @@ at::Tensor f4f4bf16_grouped_256_128_256_2_1_1_t(
3634
x_scale,
3735
w_scale,
3836
output,
39-
G,
40-
zero_start_index_M,
4137
M_sizes,
4238
global_scale,
4339
starting_row_after_padding);

fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f4f4bf16_grouped/f4f4bf16_grouped_256_256_256_2_1_1_f.cu

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,6 @@ at::Tensor f4f4bf16_grouped_256_256_256_2_1_1_f(
1818
at::Tensor x_scale,
1919
at::Tensor w_scale,
2020
at::Tensor output,
21-
int64_t G,
22-
std::optional<at::Tensor> zero_start_index_M,
2321
std::optional<at::Tensor> M_sizes,
2422
std::optional<at::Tensor> global_scale,
2523
std::optional<at::Tensor> starting_row_after_padding) {
@@ -36,8 +34,6 @@ at::Tensor f4f4bf16_grouped_256_256_256_2_1_1_f(
3634
x_scale,
3735
w_scale,
3836
output,
39-
G,
40-
zero_start_index_M,
4137
M_sizes,
4238
global_scale,
4339
starting_row_after_padding);

fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f4f4bf16_grouped/f4f4bf16_grouped_256_256_256_2_1_1_t.cu

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,6 @@ at::Tensor f4f4bf16_grouped_256_256_256_2_1_1_t(
1818
at::Tensor x_scale,
1919
at::Tensor w_scale,
2020
at::Tensor output,
21-
int64_t G,
22-
std::optional<at::Tensor> zero_start_index_M,
2321
std::optional<at::Tensor> M_sizes,
2422
std::optional<at::Tensor> global_scale,
2523
std::optional<at::Tensor> starting_row_after_padding) {
@@ -36,8 +34,6 @@ at::Tensor f4f4bf16_grouped_256_256_256_2_1_1_t(
3634
x_scale,
3735
w_scale,
3836
output,
39-
G,
40-
zero_start_index_M,
4137
M_sizes,
4238
global_scale,
4339
starting_row_after_padding);

fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f4f4bf16_grouped/f4f4bf16_grouped_256_64_256_2_1_1_f.cu

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,6 @@ at::Tensor f4f4bf16_grouped_256_64_256_2_1_1_f(
1818
at::Tensor x_scale,
1919
at::Tensor w_scale,
2020
at::Tensor output,
21-
int64_t G,
22-
std::optional<at::Tensor> zero_start_index_M,
2321
std::optional<at::Tensor> M_sizes,
2422
std::optional<at::Tensor> global_scale,
2523
std::optional<at::Tensor> starting_row_after_padding) {
@@ -36,8 +34,6 @@ at::Tensor f4f4bf16_grouped_256_64_256_2_1_1_f(
3634
x_scale,
3735
w_scale,
3836
output,
39-
G,
40-
zero_start_index_M,
4137
M_sizes,
4238
global_scale,
4339
starting_row_after_padding);

0 commit comments

Comments (0)