Commit 5328d27

sunfish2010 authored and facebook-github-bot committed
Triton early prune config fix (#4917)
Summary:
Pull Request resolved: #4917

X-link: facebookresearch/FBGEMM#1941

Fix early prune config for cases not using warp-specializations.

Reviewed By: renganxu, jasonjk-park

Differential Revision: D83019498

fbshipit-source-id: f35883eefe8e437020722a1a738c4f2ef12e8f11
1 parent 1eacb9e commit 5328d27
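
For context, Triton's autotuner consumes these prune functions through the prune_configs_by hook that the hunks below attach to @triton.autotune: the hook receives every candidate config plus the kernel's named arguments and returns the subset worth benchmarking. A minimal, self-contained sketch of that mechanism follows; the kernel _toy_kernel and the hook _prune_small_blocks are hypothetical stand-ins, not code from this commit.

# Minimal sketch of the prune_configs_by mechanism used in grouped_gemm.py.
# _toy_kernel and _prune_small_blocks are illustrative names, not FBGEMM code.
import triton
import triton.language as tl


def _prune_small_blocks(configs, named_args, **kwargs):
    # Called once before benchmarking; returns the configs to actually time.
    return [c for c in configs if c.kwargs["BLOCK_SIZE_M"] >= 32]


@triton.autotune(
    configs=[
        triton.Config({"BLOCK_SIZE_M": 16}, num_stages=2, num_warps=4),
        triton.Config({"BLOCK_SIZE_M": 64}, num_stages=3, num_warps=4),
    ],
    key=["n_elements"],
    prune_configs_by={"early_config_prune": _prune_small_blocks},
)
@triton.jit
def _toy_kernel(x_ptr, n_elements, BLOCK_SIZE_M: tl.constexpr):
    pid = tl.program_id(0)
    offs = pid * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    mask = offs < n_elements
    tl.store(x_ptr + offs, tl.load(x_ptr + offs, mask=mask) + 1, mask=mask)

Only the surviving configs are compiled and timed, so the hook must rely on values known before compilation: the block sizes and num_stages carried by each triton.Config, and the kernel's named arguments (here the commit reads G, M_BUCKET, and N from them).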

File tree

1 file changed: +71 −2 lines changed

fbgemm_gpu/experimental/gemm/triton_gemm/grouped_gemm.py

Lines changed: 71 additions & 2 deletions
@@ -129,6 +129,75 @@ def early_config_prune(configs, named_args, dtsize=None, dtype=None, **kwargs):
     if dtype is None:
         dtype = named_args["c_ptr"].dtype
 
+    pruned_configs = []
+    for config in configs:
+        kw = config.kwargs
+        (
+            BLOCK_M,
+            BLOCK_N,
+            BLOCK_K,
+            num_stages,
+            use_tma_load_on_scales,
+        ) = (
+            kw["BLOCK_SIZE_M"],
+            kw["BLOCK_SIZE_N"],
+            kw["BLOCK_SIZE_K"],
+            config.num_stages,
+            kw.get("USE_TMA_LOAD_ON_SCALES", False),
+        )
+        G, M, N = (
+            named_args["G"],
+            named_args["M_BUCKET"],
+            named_args["N"],
+        )
+
+        # 1. make sure we have enough smem
+        max_shared_memory = driver.active.utils.get_device_properties(device)[
+            "max_shared_mem"
+        ]
+        if torch.version.hip:
+            required_shared_memory = BLOCK_N * BLOCK_K * num_stages * dtsize
+        else:
+            required_shared_memory = (BLOCK_M + BLOCK_N) * BLOCK_K * num_stages * dtsize
+        if required_shared_memory > max_shared_memory:
+            continue
+
+        M_PER_GROUP = M // G
+        MIN_M_TILES = 32 if torch.version.hip else 64
+        # 2. make sure we don't load M tiles that are too big
+        if BLOCK_M > MIN_M_TILES and BLOCK_M > (M_PER_GROUP * 2):
+            continue
+        # 3. make sure we don't load N tiles that are too small
+        if BLOCK_M < 128 and BLOCK_M < (M_PER_GROUP // 2):
+            continue
+
+        num_sm = driver.active.utils.get_device_properties(device)[
+            "multiprocessor_count"
+        ]
+        N_TILES = (N + BLOCK_N - 1) // BLOCK_N
+        MIN_N_TILES = 32 if torch.version.hip else 64
+        # 4. make sure we don't load N tiles that are too big
+        if BLOCK_N > MIN_N_TILES and M * N_TILES < num_sm:
+            continue
+        # 5. make sure we don't load N tiles that are too small
+        if BLOCK_N < 128 and M * N_TILES > 2 * num_sm:
+            continue
+        if dtsize >= 2:
+            if use_tma_load_on_scales:
+                continue
+        pruned_configs.append(config)
+
+    return pruned_configs
+
+
+def early_config_prune_ws(configs, named_args, dtsize=None, dtype=None, **kwargs):
+    device = torch.cuda.current_device()
+    # BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, num_warps, num_stages
+    if dtsize is None:
+        dtsize = named_args["c_ptr"].element_size()
+    if dtype is None:
+        dtype = named_args["c_ptr"].dtype
+
     pruned_configs = []
     for config in configs:
         kw = config.kwargs
@@ -384,7 +453,7 @@ def _fbgemm_grouped_gemm(
 @triton.autotune(
     configs=_NV_WS_CONFIGS,
     key=["G", "M_BUCKET", "N", "K"],
-    prune_configs_by={"early_config_prune": early_config_prune},
+    prune_configs_by={"early_config_prune": early_config_prune_ws},
     restore_value=["c_ptr"],  # restore for scatter_add fusion
 )
 @triton.jit
@@ -712,7 +781,7 @@ def _fbgemm_grouped_gemm_fp8_rowwise(
     key=["G", "M_BUCKET", "N", "K"],
     prune_configs_by={
         "early_config_prune": functools.partial(
-            early_config_prune, dtype=TT_FP8_DTYPE, dtsize=1
+            early_config_prune_ws, dtype=TT_FP8_DTYPE, dtsize=1
         )
     },
     restore_value=["c_ptr"],  # restore for scatter_add fusion
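
As a reading aid for rule 1 in the new early_config_prune body: on the non-HIP path each pipeline stage keeps a BLOCK_M x BLOCK_K tile of A and a BLOCK_K x BLOCK_N tile of B in shared memory, so a config is dropped once (BLOCK_M + BLOCK_N) * BLOCK_K * num_stages * dtsize exceeds the device limit. A rough standalone illustration with made-up block sizes and an assumed ~227 KiB per-SM limit (roughly an H100-class part), not part of the commit:

# Worked example of the rule-1 shared-memory bound (NVIDIA branch).
# Block sizes, stage count, and the 227 KiB limit are illustrative assumptions,
# not FBGEMM's tuned configs or a queried device property.
BLOCK_M, BLOCK_N, BLOCK_K = 128, 256, 64
num_stages = 4
dtsize = 2  # bytes per element for bf16/fp16

# A-tile (BLOCK_M x BLOCK_K) plus B-tile (BLOCK_K x BLOCK_N), staged
# num_stages deep in the software pipeline.
required_shared_memory = (BLOCK_M + BLOCK_N) * BLOCK_K * num_stages * dtsize
max_shared_memory = 227 * 1024  # assumed per-SM shared memory, in bytes

print(required_shared_memory)  # 196608 -> fits, config survives the prune
print(required_shared_memory > max_shared_memory)  # False; num_stages=5 would need 245760 bytes and be pruned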
