
Commit 823ed36

jwfromm authored and facebook-github-bot committed
Simplify CK FP8 Kernel Launch and enable FP16 Outputs. (#4233)
Summary:
X-link: facebookresearch/FBGEMM#1311

This diff does some template cleanup of the FP8 rowwise AMD kernels, specifically moving as much specialization as possible from the individual kernels into the common header file. This makes it much easier to auto-generate kernels or add new features going forward, and it should be functionally identical to the previous code. I leverage this cleaner infrastructure to also allow FP16 outputs in a seamless way, introducing f8f8f16_rowwise for the AMD backend, which is needed for some recommendation-system use cases.

Pull Request resolved: #4233

Reviewed By: jianyuh, jiawenliu64

Differential Revision: D74770197

Pulled By: jwfromm
1 parent ba16adc commit 823ed36
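The core of the change is templating the shared wrapper on a compile-time output dtype (visible in the first diff below). As a rough sketch of how such a dtype selection can work in CK-based code, consider the trait below; DTypeToCK is a hypothetical name for illustration, not the literal FBGEMM header:

#include <ATen/ATen.h>
#include <ck/utility/data_type.hpp>

// Hypothetical trait mapping a compile-time at::ScalarType to the CK
// element type used for the GEMM output tensor.
template <at::ScalarType DTYPE>
struct DTypeToCK;

template <>
struct DTypeToCK<at::kBFloat16> {
  using type = ck::bhalf_t;
};

template <>
struct DTypeToCK<at::kHalf> {
  using type = ck::half_t;
};

// A wrapper templated on OUTPUT_DTYPE can then use
// typename DTypeToCK<OUTPUT_DTYPE>::type as the output element type,
// so adding FP16 outputs needs no per-kernel edits.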

92 files changed: +1211 −2702 lines changed


fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise/fp8_rowwise_gemm.hip

Lines changed: 20 additions & 6 deletions
@@ -496,7 +496,8 @@ RowwiseKernel rowwise_dispatch(int M, int N, int K) {
   return rowwise_heuristic_dispatch(M, N, K);
 }
 
-at::Tensor f8f8bf16_rowwise_wrapper(
+template <at::ScalarType OUTPUT_DTYPE>
+at::Tensor f8f8_rowwise_wrapper(
     at::Tensor XQ,
     at::Tensor WQ,
     at::Tensor x_scale,
@@ -513,6 +514,7 @@ at::Tensor f8f8bf16_rowwise_wrapper(
       (x_scale.dtype() == at::kFloat) && (w_scale.dtype() == at::kFloat),
       "Scales must be float32.");
   TORCH_CHECK(use_fast_accum, "AMD does not support disabling use_fast_accum.");
+  TORCH_CHECK(!bias.has_value(), "AMD does not support fused bias.");
 
   // Check inputs are in expected format.
   TORCH_CHECK(XQ.is_cuda() && XQ.is_contiguous());
@@ -530,7 +532,7 @@ at::Tensor f8f8bf16_rowwise_wrapper(
   // Handle case where an input dimension is zero.
   if (M == 0 || N == 0 || K == 0) {
     // Return a tensor of zeros to handle case where K is 0.
-    return at::zeros(out_sizes, XQ.options().dtype(at::kBFloat16));
+    return at::zeros(out_sizes, XQ.options().dtype(OUTPUT_DTYPE));
   }
 
   // Prepare output tensor if needed.
@@ -540,9 +542,9 @@ at::Tensor f8f8bf16_rowwise_wrapper(
     // Make sure the provided output has the proper shape and dtype.
     int Y_M = size_to_dim_(Y.dim() - 1, Y.sizes());
     TORCH_CHECK(Y_M == M && Y.sizes().vec().back() == N);
-    TORCH_CHECK(Y.dtype() == at::kBFloat16);
+    TORCH_CHECK(Y.dtype() == OUTPUT_DTYPE);
   } else {
-    Y = at::empty(out_sizes, XQ.options().dtype(at::kBFloat16));
+    Y = at::empty(out_sizes, XQ.options().dtype(OUTPUT_DTYPE));
   }
 
   RowwiseKernel rowwise_impl = rowwise_dispatch(M, N, K);
@@ -557,7 +559,19 @@ at::Tensor f8f8bf16_rowwise(
     std::optional<at::Tensor> bias,
     bool use_fast_accum) {
   // Invoke f8f8bf16 rowwise without preallocated output.
-  return f8f8bf16_rowwise_wrapper(
+  return f8f8_rowwise_wrapper<at::kBFloat16>(
+      XQ, WQ, x_scale, w_scale, bias, use_fast_accum);
+}
+
+at::Tensor f8f8f16_rowwise(
+    at::Tensor XQ,
+    at::Tensor WQ,
+    at::Tensor x_scale,
+    at::Tensor w_scale,
+    std::optional<at::Tensor> bias,
+    bool use_fast_accum) {
+  // Invoke f8f8f16 rowwise without preallocated output.
+  return f8f8_rowwise_wrapper<at::kHalf>(
       XQ, WQ, x_scale, w_scale, bias, use_fast_accum);
 }
 
@@ -570,7 +584,7 @@ void f8f8bf16_rowwise_out(
     std::optional<at::Tensor> bias,
     bool use_fast_accum) {
   // Invoke f8f8bf16 rowwise with preallocated output.
-  f8f8bf16_rowwise_wrapper(
+  f8f8_rowwise_wrapper<at::kBFloat16>(
      XQ, WQ, x_scale, w_scale, bias, use_fast_accum, output);
 }
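With the change above, the new entry point is called exactly like f8f8bf16_rowwise. A hedged usage sketch (the shapes, unit scales, and the FP8 fnuz dtype conversion are illustrative assumptions; real callers would quantize with FBGEMM's quantize ops):

// Illustrative only: assumes PyTorch with FP8 fnuz dtypes on an AMD GPU.
int M = 1024, N = 2048, K = 4096;
auto opts = at::TensorOptions().device(at::kCUDA);
at::Tensor XQ = at::randn({M, K}, opts).to(at::kFloat8_e4m3fnuz);
at::Tensor WQ = at::randn({N, K}, opts).to(at::kFloat8_e4m3fnuz);
at::Tensor x_scale = at::ones({M}, opts.dtype(at::kFloat));
at::Tensor w_scale = at::ones({N}, opts.dtype(at::kFloat));

// Per the checks above, bias must be empty and fast accum enabled on AMD.
at::Tensor Y = f8f8f16_rowwise(
    XQ, WQ, x_scale, w_scale, /*bias=*/std::nullopt, /*use_fast_accum=*/true);
TORCH_CHECK(Y.dtype() == at::kHalf); // FP16 output instead of BF16.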

fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise/kernels/fp8_rowwise_128x128x16x128_16x16_4x1_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_interwave_v2.hip

Lines changed: 2 additions & 4 deletions
@@ -16,7 +16,7 @@ fp8_rowwise_128x128x16x128_16x16_4x1_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_interwave_
     at::Tensor w_scale,
     at::Tensor Y) {
   // A kernel that works well on small but not super tiny shapes.
-  using DeviceGemmInstance = DeviceGemmHelper<
+  return f8f8bf16_rowwise_wrapper<
       128,
       128,
       16,
@@ -32,7 +32,5 @@ fp8_rowwise_128x128x16x128_16x16_4x1_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_interwave_
       1,
       1,
       ck::BlockGemmPipelineScheduler::Interwave,
-      ck::BlockGemmPipelineVersion::v2>;
-  // Run kernel instance.
-  return f8f8bf16_rowwise_impl<DeviceGemmInstance>(XQ, WQ, x_scale, w_scale, Y);
+      ck::BlockGemmPipelineVersion::v2>(XQ, WQ, x_scale, w_scale, Y, 1);
 }

fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise/kernels/fp8_rowwise_128x128x32x128_32x32_2x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2.hip

Lines changed: 17 additions & 51 deletions
@@ -15,55 +15,21 @@ fp8_rowwise_128x128x32x128_32x32_2x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_
     at::Tensor x_scale,
     at::Tensor w_scale,
     at::Tensor Y) {
-  // Check if this input needs to be padded.
-  int M = size_to_dim_(XQ.dim() - 1, XQ.sizes());
-  int N = WQ.size(0);
-  int K = WQ.size(1);
-  bool pad = (M % 128 != 0) || (N % 32 != 0) || (K % 128 != 0);
-
-  // This kernel seems optimal in the most purely compute bound tasks.
-  if (pad) {
-    using DeviceGemmInstance = DeviceGemmHelper<
-        128,
-        128,
-        32,
-        128,
-        32,
-        32,
-        2,
-        1,
-        S<8, 16, 1>,
-        S<8, 16, 1>,
-        S<1, 16, 1, 8>,
-        S<4, 4, 1>,
-        1,
-        1,
-        ck::BlockGemmPipelineScheduler::Intrawave,
-        ck::BlockGemmPipelineVersion::v2>;
-    // Run kernel instance.
-    return f8f8bf16_rowwise_impl<DeviceGemmInstance>(
-        XQ, WQ, x_scale, w_scale, Y);
-  } else {
-    using DeviceGemmInstance = DeviceGemmHelper<
-        128,
-        128,
-        32,
-        128,
-        32,
-        32,
-        2,
-        1,
-        S<8, 16, 1>,
-        S<8, 16, 1>,
-        S<1, 16, 1, 8>,
-        S<4, 4, 1>,
-        1,
-        1,
-        ck::BlockGemmPipelineScheduler::Intrawave,
-        ck::BlockGemmPipelineVersion::v2,
-        ck::tensor_operation::device::GemmSpecialization::Default>;
-    // Run kernel instance.
-    return f8f8bf16_rowwise_impl<DeviceGemmInstance>(
-        XQ, WQ, x_scale, w_scale, Y);
-  }
+  return f8f8bf16_rowwise_wrapper<
+      128,
+      128,
+      32,
+      128,
+      32,
+      32,
+      2,
+      1,
+      S<8, 16, 1>,
+      S<8, 16, 1>,
+      S<1, 16, 1, 8>,
+      S<4, 4, 1>,
+      1,
+      1,
+      ck::BlockGemmPipelineScheduler::Intrawave,
+      ck::BlockGemmPipelineVersion::v2>(XQ, WQ, x_scale, w_scale, Y, 1);
 }
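The deleted pad/no-pad branches above show the pattern this commit removes from every kernel file. A sketch of how the shared per-kernel wrapper in the common header can absorb that logic, assuming DeviceGemmHelper and f8f8bf16_rowwise_impl keep their pre-commit meanings (parameter names here are illustrative; the actual header may differ, and this wrapper is distinct from the top-level f8f8_rowwise_wrapper above):

template <
    int BLOCK_SIZE,
    int MBLOCK,
    int NBLOCK,
    int KBLOCK,
    int WAVE_TILE_M,
    int WAVE_TILE_N,
    int WAVE_MAP_M,
    int WAVE_MAP_N,
    typename ABLOCK_TRANSFER,
    typename BBLOCK_TRANSFER,
    typename CBLOCK_TRANSFER,
    typename CBLOCK_SPV,
    int CSHUFFLE_MX,
    int CSHUFFLE_NX,
    ck::BlockGemmPipelineScheduler LOOP_SCHED,
    ck::BlockGemmPipelineVersion PIPELINE_VERSION>
at::Tensor f8f8bf16_rowwise_wrapper(
    at::Tensor XQ,
    at::Tensor WQ,
    at::Tensor x_scale,
    at::Tensor w_scale,
    at::Tensor Y,
    int num_splits) {
  // Compute the padding predicate once, from the tile sizes that each
  // kernel file already passes as template arguments.
  int M = size_to_dim_(XQ.dim() - 1, XQ.sizes());
  int N = WQ.size(0);
  int K = WQ.size(1);
  bool pad = (M % MBLOCK != 0) || (N % NBLOCK != 0) || (K % KBLOCK != 0);
  if (pad) {
    // Padded shapes: select the MNKPadding specialization.
    using DeviceGemmInstance = DeviceGemmHelper<
        BLOCK_SIZE, MBLOCK, NBLOCK, KBLOCK,
        WAVE_TILE_M, WAVE_TILE_N, WAVE_MAP_M, WAVE_MAP_N,
        ABLOCK_TRANSFER, BBLOCK_TRANSFER, CBLOCK_TRANSFER, CBLOCK_SPV,
        CSHUFFLE_MX, CSHUFFLE_NX, LOOP_SCHED, PIPELINE_VERSION,
        ck::tensor_operation::device::GemmSpecialization::MNKPadding>;
    return f8f8bf16_rowwise_impl<DeviceGemmInstance>(
        XQ, WQ, x_scale, w_scale, Y, num_splits);
  } else {
    // Evenly divisible shapes: the Default specialization avoids
    // padding overhead.
    using DeviceGemmInstance = DeviceGemmHelper<
        BLOCK_SIZE, MBLOCK, NBLOCK, KBLOCK,
        WAVE_TILE_M, WAVE_TILE_N, WAVE_MAP_M, WAVE_MAP_N,
        ABLOCK_TRANSFER, BBLOCK_TRANSFER, CBLOCK_TRANSFER, CBLOCK_SPV,
        CSHUFFLE_MX, CSHUFFLE_NX, LOOP_SCHED, PIPELINE_VERSION,
        ck::tensor_operation::device::GemmSpecialization::Default>;
    return f8f8bf16_rowwise_impl<DeviceGemmInstance>(
        XQ, WQ, x_scale, w_scale, Y, num_splits);
  }
}

Centralizing the branch this way is what lets each kernel file shrink to a single wrapper call, as the remaining diffs below show.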

fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise/kernels/fp8_rowwise_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2.hip

Lines changed: 2 additions & 6 deletions
@@ -15,7 +15,7 @@ fp8_rowwise_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v
     at::Tensor x_scale,
     at::Tensor w_scale,
     at::Tensor Y) {
-  using DeviceGemmInstance = DeviceGemmHelper<
+  return f8f8bf16_rowwise_wrapper<
       128,
       16,
       32,
@@ -31,9 +31,5 @@ fp8_rowwise_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v
       1,
       1,
       ck::BlockGemmPipelineScheduler::Interwave,
-      ck::BlockGemmPipelineVersion::v2,
-      ck::tensor_operation::device::GemmSpecialization::Default>;
-  // Run kernel instance.
-  return f8f8bf16_rowwise_impl<DeviceGemmInstance>(XQ, WQ, x_scale, w_scale, Y);
+      ck::BlockGemmPipelineVersion::v2>(XQ, WQ, x_scale, w_scale, Y, 1);
 }
-

fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise/kernels/fp8_rowwise_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2_4_split_k.hip

Lines changed: 2 additions & 5 deletions
@@ -15,7 +15,7 @@ fp8_rowwise_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v
     at::Tensor x_scale,
     at::Tensor w_scale,
     at::Tensor Y) {
-  using DeviceGemmInstance = DeviceGemmHelper<
+  return f8f8bf16_rowwise_wrapper<
       128,
       16,
       32,
@@ -31,8 +31,5 @@ fp8_rowwise_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v
       1,
       1,
       ck::BlockGemmPipelineScheduler::Interwave,
-      ck::BlockGemmPipelineVersion::v2,
-      ck::tensor_operation::device::GemmSpecialization::Default>;
-  // Run kernel instance.
-  return f8f8bf16_rowwise_impl<DeviceGemmInstance>(XQ, WQ, x_scale, w_scale, Y, 4);
+      ck::BlockGemmPipelineVersion::v2>(XQ, WQ, x_scale, w_scale, Y, 4);
 }

fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise/kernels/fp8_rowwise_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2_8_split_k.hip

Lines changed: 2 additions & 5 deletions
@@ -15,7 +15,7 @@ fp8_rowwise_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v
     at::Tensor x_scale,
     at::Tensor w_scale,
     at::Tensor Y) {
-  using DeviceGemmInstance = DeviceGemmHelper<
+  return f8f8bf16_rowwise_wrapper<
       128,
       16,
       32,
@@ -31,8 +31,5 @@ fp8_rowwise_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v
       1,
       1,
       ck::BlockGemmPipelineScheduler::Interwave,
-      ck::BlockGemmPipelineVersion::v2,
-      ck::tensor_operation::device::GemmSpecialization::Default>;
-  // Run kernel instance.
-  return f8f8bf16_rowwise_impl<DeviceGemmInstance>(XQ, WQ, x_scale, w_scale, Y, 8);
+      ck::BlockGemmPipelineVersion::v2>(XQ, WQ, x_scale, w_scale, Y, 8);
 }

fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise/kernels/fp8_rowwise_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v1.hip

Lines changed: 17 additions & 49 deletions
@@ -16,53 +16,21 @@ fp8_rowwise_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v
     at::Tensor w_scale,
     at::Tensor Y) {
   // The smallest kernel we have available. Works well for memory bound shapes.
-
-  // Check if this input needs to be padded.
-  int M = size_to_dim_(XQ.dim() - 1, XQ.sizes());
-  int N = WQ.size(0);
-  int K = WQ.size(1);
-  bool pad = (M % 16 != 0) || (N % 32 != 0) || (K % 128 != 0);
-  if (pad) {
-    using DeviceGemmInstance = DeviceGemmHelper<
-        128,
-        16,
-        32,
-        128,
-        16,
-        16,
-        1,
-        1,
-        S<8, 16, 1>,
-        S<8, 16, 1>,
-        S<1, 16, 1, 8>,
-        S<4, 4, 1>,
-        1,
-        1,
-        ck::BlockGemmPipelineScheduler::Intrawave,
-        ck::BlockGemmPipelineVersion::v1,
-        ck::tensor_operation::device::GemmSpecialization::MNKPadding>;
-    // Run kernel instance.
-    return f8f8bf16_rowwise_impl<DeviceGemmInstance>(XQ, WQ, x_scale, w_scale, Y);
-  } else {
-    using DeviceGemmInstance = DeviceGemmHelper<
-        128,
-        16,
-        32,
-        128,
-        16,
-        16,
-        1,
-        1,
-        S<8, 16, 1>,
-        S<8, 16, 1>,
-        S<1, 16, 1, 8>,
-        S<4, 4, 1>,
-        1,
-        1,
-        ck::BlockGemmPipelineScheduler::Intrawave,
-        ck::BlockGemmPipelineVersion::v1,
-        ck::tensor_operation::device::GemmSpecialization::Default>;
-    // Run kernel instance.
-    return f8f8bf16_rowwise_impl<DeviceGemmInstance>(XQ, WQ, x_scale, w_scale, Y);
-  }
+  return f8f8bf16_rowwise_wrapper<
+      128,
+      16,
+      32,
+      128,
+      16,
+      16,
+      1,
+      1,
+      S<8, 16, 1>,
+      S<8, 16, 1>,
+      S<1, 16, 1, 8>,
+      S<4, 4, 1>,
+      1,
+      1,
+      ck::BlockGemmPipelineScheduler::Intrawave,
+      ck::BlockGemmPipelineVersion::v1>(XQ, WQ, x_scale, w_scale, Y, 1);
 }

fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise/kernels/fp8_rowwise_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2.hip

Lines changed: 17 additions & 48 deletions
@@ -16,52 +16,21 @@ fp8_rowwise_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v
     at::Tensor w_scale,
     at::Tensor Y) {
   // The smallest kernel we have available. Works well for memory bound shapes.
-
-  // Check if this input needs to be padded.
-  int M = size_to_dim_(XQ.dim() - 1, XQ.sizes());
-  int N = WQ.size(0);
-  int K = WQ.size(1);
-  bool pad = (M % 16 != 0) || (N % 32 != 0) || (K % 128 != 0);
-  if (pad) {
-    using DeviceGemmInstance = DeviceGemmHelper<
-        128,
-        16,
-        32,
-        128,
-        16,
-        16,
-        1,
-        1,
-        S<8, 16, 1>,
-        S<8, 16, 1>,
-        S<1, 16, 1, 8>,
-        S<4, 4, 1>,
-        1,
-        1,
-        ck::BlockGemmPipelineScheduler::Intrawave,
-        ck::BlockGemmPipelineVersion::v2>;
-    // Run kernel instance.
-    return f8f8bf16_rowwise_impl<DeviceGemmInstance>(XQ, WQ, x_scale, w_scale, Y);
-  } else {
-    using DeviceGemmInstance = DeviceGemmHelper<
-        128,
-        16,
-        32,
-        128,
-        16,
-        16,
-        1,
-        1,
-        S<8, 16, 1>,
-        S<8, 16, 1>,
-        S<1, 16, 1, 8>,
-        S<4, 4, 1>,
-        1,
-        1,
-        ck::BlockGemmPipelineScheduler::Intrawave,
-        ck::BlockGemmPipelineVersion::v2,
-        ck::tensor_operation::device::GemmSpecialization::Default>;
-    // Run kernel instance.
-    return f8f8bf16_rowwise_impl<DeviceGemmInstance>(XQ, WQ, x_scale, w_scale, Y);
-  }
+  return f8f8bf16_rowwise_wrapper<
+      128,
+      16,
+      32,
+      128,
+      16,
+      16,
+      1,
+      1,
+      S<8, 16, 1>,
+      S<8, 16, 1>,
+      S<1, 16, 1, 8>,
+      S<4, 4, 1>,
+      1,
+      1,
+      ck::BlockGemmPipelineScheduler::Intrawave,
+      ck::BlockGemmPipelineVersion::v2>(XQ, WQ, x_scale, w_scale, Y, 1);
 }

fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise/kernels/fp8_rowwise_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2_8_split_k.hip

Lines changed: 2 additions & 5 deletions
@@ -15,7 +15,7 @@ fp8_rowwise_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v
     at::Tensor x_scale,
     at::Tensor w_scale,
     at::Tensor Y) {
-  using DeviceGemmInstance = DeviceGemmHelper<
+  return f8f8bf16_rowwise_wrapper<
       128,
       16,
       32,
@@ -31,8 +31,5 @@ fp8_rowwise_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v
       1,
       1,
       ck::BlockGemmPipelineScheduler::Intrawave,
-      ck::BlockGemmPipelineVersion::v2,
-      ck::tensor_operation::device::GemmSpecialization::Default>;
-  // Run kernel instance.
-  return f8f8bf16_rowwise_impl<DeviceGemmInstance>(XQ, WQ, x_scale, w_scale, Y, 8);
+      ck::BlockGemmPipelineVersion::v2>(XQ, WQ, x_scale, w_scale, Y, 8);
 }

fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise/kernels/fp8_rowwise_128x16x32x256_16x16_1x1_16x8x1_16x8x1_1x16x1x8_4x4x1_1x1_intrawave_v1.hip

Lines changed: 2 additions & 6 deletions
@@ -15,7 +15,7 @@ fp8_rowwise_128x16x32x256_16x16_1x1_16x8x1_16x8x1_1x16x1x8_4x4x1_1x1_intrawave_v
     at::Tensor x_scale,
     at::Tensor w_scale,
     at::Tensor Y) {
-  using DeviceGemmInstance = DeviceGemmHelper<
+  return f8f8bf16_rowwise_wrapper<
       128,
       16,
       32,
@@ -31,9 +31,5 @@ fp8_rowwise_128x16x32x256_16x16_1x1_16x8x1_16x8x1_1x16x1x8_4x4x1_1x1_intrawave_v
       1,
       1,
       ck::BlockGemmPipelineScheduler::Intrawave,
-      ck::BlockGemmPipelineVersion::v1,
-      ck::tensor_operation::device::GemmSpecialization::Default>;
-  // Run kernel instance.
-  return f8f8bf16_rowwise_impl<DeviceGemmInstance>(XQ, WQ, x_scale, w_scale, Y);
+      ck::BlockGemmPipelineVersion::v1>(XQ, WQ, x_scale, w_scale, Y, 1);
 }
-
