
Commit 15fa951

jiawenliu64 authored and facebook-github-bot committed
Optimize wgrad CUTLASS grouped gemm (#4891)
Summary:
Pull Request resolved: #4891
X-link: facebookresearch/FBGEMM#1916

- Make wgrad CUTLASS grouped gemm return float32 output when wgrad is provided, respecting e2e (see the sketch after the change stats)
- Optimize general heuristic
- Make tests cover wgrad accum with float32 output

Reviewed By: q10

Differential Revision: D82700455
1 parent cc0fd3c · commit 15fa951

13 files changed: +1,078 −75 lines
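For orientation before the per-file diffs: every per-configuration kernel file in this commit wraps one shared templated implementation, and each file/function name encodes its configuration as bf16bf16bf16_grouped_wgrad_{tileM}_{tileN}_{tileK}_{clusterM}_{clusterN}_{clusterK}_{arch}_{t|f}. Judging from the instantiations below, the seventh template argument carries the runtime output_accum flag and the eighth (matching the trailing t/f) selects the new float32 output path. The sketch below spells that reading out; the parameter names are assumptions, not FBGEMM's actual declarations (those live in bf16bf16bf16_grouped_wgrad_common.cuh).

// Sketch only: template parameter names are inferred from the
// instantiations in this commit, not copied from FBGEMM headers.
template <
    int TB_M, int TB_N, int TB_K,    // CTA tile shape (e.g. 128, 128, 128)
    int TBS_M, int TBS_N, int TBS_K, // thread-block cluster shape
    bool ACCUM,    // compile-time image of the runtime `output_accum` flag:
                   // add into `output` instead of overwriting it
    bool FP32_OUT> // assumed meaning of the trailing _t/_f in each name:
                   // emit float32 output rather than BF16
at::Tensor bf16bf16bf16_grouped_wgrad_impl(
    at::Tensor X, // BF16
    at::Tensor W, // BF16
    at::Tensor M_sizes,
    at::Tensor output);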

fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad.cu

Lines changed: 785 additions & 29 deletions
Large diffs are not rendered by default.
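This unrendered 785-line diff is where the "optimize general heuristic" bullet lands: a dispatcher that picks one of the per-configuration wrappers from the problem shape. The actual selection logic is not shown on this page, so what follows is only an illustrative sketch of a shape-bucketed dispatch; the thresholds and particular picks are invented, while the wrapper names are real kernels from this commit.

// Illustrative only: the buckets and thresholds below are invented.
// The wrapper names are kernels added or retuned in this commit.
using WGradKernel =
    at::Tensor (*)(at::Tensor, at::Tensor, at::Tensor, at::Tensor, bool);

WGradKernel select_wgrad_kernel(int64_t M, int64_t N, bool fp32_out) {
  if (fp32_out) {
    // _t variants instantiate the impl with the float32-output flag set.
    return M <= 1024 ? bf16bf16bf16_grouped_wgrad_128_128_128_1_1_1_9_t
                     : bf16bf16bf16_grouped_wgrad_128_128_128_4_1_1_9_t;
  }
  if (N <= 32) {
    return bf16bf16bf16_grouped_wgrad_128_32_128_1_2_1_9_f;
  }
  if (N <= 64) {
    return bf16bf16bf16_grouped_wgrad_256_64_128_1_1_1_9_f;
  }
  return bf16bf16bf16_grouped_wgrad_128_128_128_1_4_1_9_f;
}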
Lines changed: 4 additions & 11 deletions
@@ -10,25 +10,18 @@
 
 namespace fbgemm_gpu {
 
-at::Tensor bf16bf16bf16_grouped_wgrad_256_128_128_1_2_1_9_f(
+at::Tensor bf16bf16bf16_grouped_wgrad_128_128_128_1_1_1_9_t(
     at::Tensor X, // BF16
     at::Tensor W, // BF16
     at::Tensor M_sizes,
     at::Tensor output,
     bool output_accum) {
   if (output_accum) {
-    return bf16bf16bf16_grouped_wgrad_impl<256, 128, 128, 1, 2, 1, true, false>(
+    return bf16bf16bf16_grouped_wgrad_impl<128, 128, 128, 1, 1, 1, true, true>(
         X, W, M_sizes, output);
   } else {
-    return bf16bf16bf16_grouped_wgrad_impl<
-        256,
-        128,
-        128,
-        1,
-        2,
-        1,
-        false,
-        false>(X, W, M_sizes, output);
+    return bf16bf16bf16_grouped_wgrad_impl<128, 128, 128, 1, 1, 1, false, true>(
+        X, W, M_sizes, output);
   }
 }
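Net effect of this change: the default 128×128×128 wrapper now passes true as the eighth template argument in both branches, so this entry point produces float32 output whether or not it accumulates. A hedged sketch of the caller-side pattern this enables, with container and variable names invented for illustration:

// Sketch: accumulate grouped weight gradients across micro-batches into a
// single float32 buffer, avoiding compounding BF16 rounding per step.
// `MicroBatch`, `micro_batches`, and the output shape are assumptions.
at::Tensor wgrad =
    at::zeros({total_N, K}, at::dtype(at::kFloat).device(at::kCUDA));
for (const MicroBatch& mb : micro_batches) {
  // output_accum = true: the kernel adds this step's dW into `wgrad`
  // in place rather than overwriting it.
  bf16bf16bf16_grouped_wgrad_128_128_128_1_1_1_9_t(
      mb.X, mb.W, mb.M_sizes, wgrad, /*output_accum=*/true);
}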

Lines changed: 4 additions & 4 deletions
@@ -10,22 +10,22 @@
 
 namespace fbgemm_gpu {
 
-at::Tensor bf16bf16bf16_grouped_wgrad_128_256_128_1_1_1_9_f(
+at::Tensor bf16bf16bf16_grouped_wgrad_128_128_128_1_4_1_9_f(
     at::Tensor X, // BF16
     at::Tensor W, // BF16
     at::Tensor M_sizes,
     at::Tensor output,
     bool output_accum) {
   if (output_accum) {
-    return bf16bf16bf16_grouped_wgrad_impl<128, 256, 128, 1, 1, 1, true, false>(
+    return bf16bf16bf16_grouped_wgrad_impl<128, 128, 128, 1, 4, 1, true, false>(
         X, W, M_sizes, output);
   } else {
     return bf16bf16bf16_grouped_wgrad_impl<
         128,
-        256,
+        128,
         128,
         1,
-        1,
+        4,
         1,
         false,
         false>(X, W, M_sizes, output);
Lines changed: 4 additions & 4 deletions
@@ -10,23 +10,23 @@
 
 namespace fbgemm_gpu {
 
-at::Tensor bf16bf16bf16_grouped_wgrad_128_256_128_1_2_1_9_f(
+at::Tensor bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f(
     at::Tensor X, // BF16
     at::Tensor W, // BF16
     at::Tensor M_sizes,
     at::Tensor output,
     bool output_accum) {
   if (output_accum) {
-    return bf16bf16bf16_grouped_wgrad_impl<128, 256, 128, 1, 2, 1, true, false>(
+    return bf16bf16bf16_grouped_wgrad_impl<128, 128, 128, 2, 1, 1, true, false>(
         X, W, M_sizes, output);
   } else {
     return bf16bf16bf16_grouped_wgrad_impl<
         128,
-        256,
         128,
-        1,
+        128,
         2,
         1,
+        1,
         false,
         false>(X, W, M_sizes, output);
   }
Lines changed: 28 additions & 0 deletions
@@ -0,0 +1,28 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include "bf16bf16bf16_grouped_wgrad_common.cuh"
+
+namespace fbgemm_gpu {
+
+at::Tensor bf16bf16bf16_grouped_wgrad_128_128_128_4_1_1_9_t(
+    at::Tensor X, // BF16
+    at::Tensor W, // BF16
+    at::Tensor M_sizes,
+    at::Tensor output,
+    bool output_accum) {
+  if (output_accum) {
+    return bf16bf16bf16_grouped_wgrad_impl<128, 128, 128, 4, 1, 1, true, true>(
+        X, W, M_sizes, output);
+  } else {
+    return bf16bf16bf16_grouped_wgrad_impl<128, 128, 128, 4, 1, 1, false, true>(
+        X, W, M_sizes, output);
+  }
+}
+
+} // namespace fbgemm_gpu
Lines changed: 28 additions & 0 deletions
@@ -0,0 +1,28 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include "bf16bf16bf16_grouped_wgrad_common.cuh"
+
+namespace fbgemm_gpu {
+
+at::Tensor bf16bf16bf16_grouped_wgrad_128_32_128_1_2_1_9_f(
+    at::Tensor X, // BF16
+    at::Tensor W, // BF16
+    at::Tensor M_sizes,
+    at::Tensor output,
+    bool output_accum) {
+  if (output_accum) {
+    return bf16bf16bf16_grouped_wgrad_impl<128, 32, 128, 1, 2, 1, true, false>(
+        X, W, M_sizes, output);
+  } else {
+    return bf16bf16bf16_grouped_wgrad_impl<128, 32, 128, 1, 2, 1, false, false>(
+        X, W, M_sizes, output);
+  }
+}
+
+} // namespace fbgemm_gpu
Lines changed: 28 additions & 0 deletions
@@ -0,0 +1,28 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include "bf16bf16bf16_grouped_wgrad_common.cuh"
+
+namespace fbgemm_gpu {
+
+at::Tensor bf16bf16bf16_grouped_wgrad_128_64_128_1_1_1_9_f(
+    at::Tensor X, // BF16
+    at::Tensor W, // BF16
+    at::Tensor M_sizes,
+    at::Tensor output,
+    bool output_accum) {
+  if (output_accum) {
+    return bf16bf16bf16_grouped_wgrad_impl<128, 64, 128, 1, 1, 1, true, false>(
+        X, W, M_sizes, output);
+  } else {
+    return bf16bf16bf16_grouped_wgrad_impl<128, 64, 128, 1, 1, 1, false, false>(
+        X, W, M_sizes, output);
+  }
+}
+
+} // namespace fbgemm_gpu
Lines changed: 28 additions & 0 deletions
@@ -0,0 +1,28 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include "bf16bf16bf16_grouped_wgrad_common.cuh"
+
+namespace fbgemm_gpu {
+
+at::Tensor bf16bf16bf16_grouped_wgrad_256_64_128_1_1_1_9_f(
+    at::Tensor X, // BF16
+    at::Tensor W, // BF16
+    at::Tensor M_sizes,
+    at::Tensor output,
+    bool output_accum) {
+  if (output_accum) {
+    return bf16bf16bf16_grouped_wgrad_impl<256, 64, 128, 1, 1, 1, true, false>(
+        X, W, M_sizes, output);
+  } else {
+    return bf16bf16bf16_grouped_wgrad_impl<256, 64, 128, 1, 1, 1, false, false>(
+        X, W, M_sizes, output);
+  }
+}
+
+} // namespace fbgemm_gpu
Lines changed: 28 additions & 0 deletions
@@ -0,0 +1,28 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include "bf16bf16bf16_grouped_wgrad_common.cuh"
+
+namespace fbgemm_gpu {
+
+at::Tensor bf16bf16bf16_grouped_wgrad_256_64_128_1_2_1_9_f(
+    at::Tensor X, // BF16
+    at::Tensor W, // BF16
+    at::Tensor M_sizes,
+    at::Tensor output,
+    bool output_accum) {
+  if (output_accum) {
+    return bf16bf16bf16_grouped_wgrad_impl<256, 64, 128, 1, 2, 1, true, false>(
+        X, W, M_sizes, output);
+  } else {
+    return bf16bf16bf16_grouped_wgrad_impl<256, 64, 128, 1, 2, 1, false, false>(
+        X, W, M_sizes, output);
+  }
+}
+
+} // namespace fbgemm_gpu
Lines changed: 28 additions & 0 deletions
@@ -0,0 +1,28 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include "bf16bf16bf16_grouped_wgrad_common.cuh"
+
+namespace fbgemm_gpu {
+
+at::Tensor bf16bf16bf16_grouped_wgrad_256_64_128_1_4_1_9_f(
+    at::Tensor X, // BF16
+    at::Tensor W, // BF16
+    at::Tensor M_sizes,
+    at::Tensor output,
+    bool output_accum) {
+  if (output_accum) {
+    return bf16bf16bf16_grouped_wgrad_impl<256, 64, 128, 1, 4, 1, true, false>(
+        X, W, M_sizes, output);
+  } else {
+    return bf16bf16bf16_grouped_wgrad_impl<256, 64, 128, 1, 4, 1, false, false>(
+        X, W, M_sizes, output);
+  }
+}
+
+} // namespace fbgemm_gpu
