
Commit 7937cb8

renganxu authored and facebook-github-bot committed
Support general K in group gemm (#4636)
Summary:
Pull Request resolved: #4636

X-link: facebookresearch/FBGEMM#1668

The existing group gemm does not work when K is not a multiple of 8, because TMA load cannot be used in that case. This diff adds support for this general case by applying a mask when loading the data.

With this change, when K is not a multiple of 8, using fuse_scatter_add results in larger numerical discrepancies that cause the unit test to fail. This is most likely due to the atomic add operations in scatter_add or increased rounding errors with negative numbers, since the unit test uses randn; if the unit test uses rand to generate only positive numbers, the test passes. For now, in this case we disable fused scatter_add in the unit test and do not allow it in the group gemm implementation.

Reviewed By: sgrigory

Differential Revision: D79393881

fbshipit-source-id: a04ad8ed1be64ac31b5db325524b44e6eabf7de3
1 parent 5d384e8 commit 7937cb8
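The approach described in the summary, masking the K-dimension loads instead of relying on TMA, can be sketched in isolation. The following is a minimal, standalone Triton example, not the FBGEMM kernel; the kernel name, data layout, and block sizes are illustrative, but the K mask combined with other=0.0 on tl.load is the same pattern the diff adds to the non-TMA path.

# Minimal sketch (assumed names and shapes): a GEMM whose K loop tolerates a
# ragged final tile because out-of-range K elements are masked and loaded as 0.
import torch
import triton
import triton.language as tl


@triton.jit
def masked_k_gemm_kernel(
    a_ptr, b_ptr, c_ptr,
    M, N, K,
    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
):
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)
    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    for k_offset in range(0, K, BLOCK_K):
        offs_k = k_offset + tl.arange(0, BLOCK_K)
        k_mask = offs_k[None, :] < K  # masks the ragged final K tile
        a = tl.load(
            a_ptr + offs_m[:, None] * K + offs_k[None, :],
            mask=(offs_m[:, None] < M) & k_mask,
            other=0.0,
        )
        b = tl.load(
            b_ptr + offs_n[:, None] * K + offs_k[None, :],
            mask=(offs_n[:, None] < N) & k_mask,
            other=0.0,
        )
        acc += tl.dot(a, b.T)  # b is stored as (N, K), as in the grouped kernel
    c_mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
    tl.store(c_ptr + offs_m[:, None] * N + offs_n[None, :], acc, mask=c_mask)


def masked_k_gemm(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # a: (M, K), b: (N, K); returns a @ b.T even when K % BLOCK_K != 0.
    M, K = a.shape
    N = b.shape[0]
    c = torch.empty((M, N), device=a.device, dtype=torch.float32)
    grid = (triton.cdiv(M, 32), triton.cdiv(N, 32))
    masked_k_gemm_kernel[grid](a, b, c, M, N, K, BLOCK_M=32, BLOCK_N=32, BLOCK_K=32)
    return c

For example, masked_k_gemm(torch.randn(64, 100, device="cuda"), torch.randn(32, 100, device="cuda")) handles K=100 even though 100 is not a multiple of the 32-element K block.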

File tree

2 files changed: +51 -16 lines changed

fbgemm_gpu/experimental/gemm/test/grouped_gemm_test.py

Lines changed: 7 additions & 2 deletions
@@ -156,7 +156,7 @@ def msg(s: str) -> str:
         G=st.sampled_from([1, 4, 16, 128]),
         M=st.sampled_from([0, 128, 2048, 16384]),
         N=st.sampled_from([256]),
-        K=st.sampled_from([256]),
+        K=st.sampled_from([100, 256, 257]),
         warp_specialization=st.sampled_from(
             [True, False] if torch.cuda.is_available() and _HAS_WS_SUPPORT else [False]
         ),
@@ -179,9 +179,14 @@ def test_grouped_gemm_bf16(
         warp_specialization: bool,
         fuse_scatter_add: bool,
     ) -> None:
+        if K % 8 != 0:
+            # When K is not a multiple of 8, using fuse_scatter_add has large numerical discrepancy,
+            # possibly due to atomic add in scatter_add or larger rounding error with negative numbers.
+            fuse_scatter_add = False
+
         torch.manual_seed(0)
 
-        device = torch.device("cuda")
+        device = torch.accelerator.current_accelerator()
         a = torch.randn(M, K, dtype=torch.bfloat16, device=device)
         b = torch.randn(N * G, K, dtype=torch.bfloat16, device=device)
         m_ends, _ = torch.sort(
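The comment added to the test can be illustrated with a small standalone example (not part of the diff): bf16 accumulation is order dependent, so the unordered atomic adds that a fused scatter_add may perform can shift results, and with signed randn inputs cancellation makes the relative discrepancy larger than with positive-only rand inputs.

# Illustration only: accumulating the same values in bf16 in two different orders,
# rounding after every addition (as repeated atomic adds do), usually does not give
# bit-identical results.
import torch

torch.manual_seed(0)
vals = torch.randn(4096)


def bf16_serial_sum(x: torch.Tensor) -> float:
    acc = torch.tensor(0.0, dtype=torch.bfloat16)
    for v in x:
        acc = acc + v.to(torch.bfloat16)  # result is rounded to bf16 at each step
    return float(acc)


print(bf16_serial_sum(vals), bf16_serial_sum(vals.flip(0)))  # typically differ slightly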

fbgemm_gpu/experimental/gemm/triton_gemm/grouped_gemm.py

Lines changed: 44 additions & 14 deletions
@@ -149,11 +149,10 @@ def early_config_prune(configs, named_args, dtsize=None, dtype=None, **kwargs):
             config.num_consumer_groups,
             kw.get("USE_TMA_LOAD_ON_SCALES", False),
         )
-        G, M, N, K = (
+        G, M, N = (
             named_args["G"],
             named_args["M_BUCKET"],
             named_args["N"],
-            named_args["K"],
         )
 
         # 1. make sure we have enough smem
@@ -198,11 +197,7 @@ def early_config_prune(configs, named_args, dtsize=None, dtype=None, **kwargs):
         if BLOCK_N < 128 and M * N_TILES > 2 * num_sm:
             continue
 
-        # 6. make sure K can be evenly divided
-        if K % BLOCK_K != 0:
-            continue
-
-        # 7. make sure we can partition for ws
+        # 6. make sure we can partition for ws
         if use_warp_specialization:
             if num_warps != 4:
                 continue
@@ -302,8 +297,9 @@ def _fbgemm_grouped_gemm(
         tile_n_idx = gidx // num_m_tiles
 
         accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
-        tl.static_assert(K % BLOCK_SIZE_K == 0)
+
         if USE_TMA_LOAD:
+            tl.static_assert(K % BLOCK_SIZE_K == 0)
             m_offset = (M_start_offset + tile_m_idx * BLOCK_SIZE_M).to(tl.int32)
             n_offset = (N_start_offset + tile_n_idx * BLOCK_SIZE_N).to(tl.int32)
             for k_offset in range(0, K, BLOCK_SIZE_K):
@@ -338,8 +334,18 @@ def _fbgemm_grouped_gemm(
                 + offs_k[None, :]
             )
             for k_offset in range(0, K, BLOCK_SIZE_K):
-                a = tl.load(a_ptrs, mask=offs_am[:, None] < m_size)
-                b = tl.load(b_ptrs, mask=offs_bn[:, None] < n_size)
+                updated_k_offset = k_offset + offs_k
+                updated_k_offset_mask = updated_k_offset[None, :] < K  # type: ignore[16]
+                a = tl.load(
+                    a_ptrs,
+                    mask=((offs_am[:, None] < m_size) & updated_k_offset_mask),
+                    other=0.0,
+                )
+                b = tl.load(
+                    b_ptrs,
+                    mask=((offs_bn[:, None] < n_size) & updated_k_offset_mask),
+                    other=0.0,
+                )
                 accumulator += tl.dot(a, b.T)
                 a_ptrs += BLOCK_SIZE_K
                 b_ptrs += BLOCK_SIZE_K
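The other=0.0 in these masked loads is what keeps the ragged final K tile from corrupting the result: masked-off elements load as zero and contribute nothing to the accumulation. A quick PyTorch check of that equivalence (illustrative, not part of the kernel):

import torch
import torch.nn.functional as F

M, N, K, BLOCK_K = 64, 32, 100, 32               # K is not a multiple of BLOCK_K
a = torch.randn(M, K)
b = torch.randn(N, K)

pad = (-K) % BLOCK_K                             # elements needed to reach the next multiple of BLOCK_K
a_pad = F.pad(a, (0, pad))                       # zero-pad the K dimension, like mask + other=0.0
b_pad = F.pad(b, (0, pad))

print(torch.allclose(a @ b.T, a_pad @ b_pad.T))  # True: zero-padded tiles add nothing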
@@ -960,20 +966,28 @@ def _grouped_gemm(
 
     if USE_TMA_LOAD and not utils.HAS_TMA_DESC:
         USE_TMA_LOAD = False
-        warnings.warn("TMA load is disabled as there is no TMA descriptor support!")
+        warnings.warn(
+            "TMA load is disabled as there is no TMA descriptor support!", stacklevel=2
+        )
 
     if USE_TMA_STORE and not utils.HAS_TMA_DESC:
         USE_TMA_STORE = False
-        warnings.warn("TMA store is disabled as there is no TMA descriptor support!")
+        warnings.warn(
+            "TMA store is disabled as there is no TMA descriptor support!", stacklevel=2
+        )
 
     # TODO(shikaili): Check the readniess of WS on ROCm side in Meta's Triton.
     if use_warp_specialization and torch.version.hip:
-        warnings.warn("Warp specialization is disabled as it is not supported on ROCm.")
+        warnings.warn(
+            "Warp specialization is disabled as it is not supported on ROCm.",
+            stacklevel=2,
+        )
         use_warp_specialization = False
 
     if use_warp_specialization and not _HAS_WS_SUPPORT:
         warnings.warn(
-            "Warp specialization is disabled as the Triton build in current environment doesn't have such support. Please build from https://github.com/facebookexperimental/triton/tree/ws-3.2.x to enable it for best performance on Nvidia's SM90 GPUs."
+            "Warp specialization is disabled as the Triton build in current environment doesn't have such support. Please build from https://github.com/facebookexperimental/triton/tree/ws-3.2.x to enable it for best performance on Nvidia's SM90 GPUs.",
+            stacklevel=2,
         )
         use_warp_specialization = False
 
@@ -991,6 +1005,22 @@ def _grouped_gemm(
     N = w.shape[0] // G
     assert K == w.shape[1]
 
+    if K % 8 != 0:
+        use_warp_specialization = False
+        USE_TMA_LOAD = False
+        USE_TMA_STORE = False
+        warnings.warn(
+            f"TMA load and warp specialization are disabled since K is not a multiple of 8: {K=}.",
+            stacklevel=2,
+        )
+        assert (
+            x_scale is None
+        ), f"Quantisation is not supported yet when K is not a multiple of 8: {K=}"
+
+        assert (
+            output_tensor is None
+        ), f"Fused scatter add has large rounding error when K is not a multiple of 8: {K=}"
+
     if output_tensor is None:
         FUSE_SCATTER_ADD = False
         assert scatter_add_indices is None
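Separately from the K handling, the warnings touched in this file now pass stacklevel=2. A minimal standard-library sketch (not FBGEMM code) of what that changes: the reported source line becomes the caller of the helper rather than the warnings.warn call inside it.

import warnings


def helper(flag: bool) -> None:
    if not flag:
        # stacklevel=2 skips helper's own frame when reporting the warning location
        warnings.warn("feature disabled", stacklevel=2)


def caller() -> None:
    helper(False)  # the UserWarning is attributed to this line


caller()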
