Commit 87dd0dd

jiawenliu64 authored and facebook-github-bot committed
Enable CUTLASS grouped GEMM for pretraining wgrad on GB200 and H100 (resubmit)
Summary:
- Enable CUTLASS grouped GEMM for llama4x pretraining wgrad on GB200 and H100
- Optimize performance of pretraining MoE shapes on H100
- Support total_K in quantize_bench for wgrad
- A short-term fix for the FBGEMM relocation issue has been released, so resubmitting. Passed all tests in T238469849.

Differential Revision: D83001505
1 parent be84b43 commit 87dd0dd

File tree

43 files changed, +2595 -2 lines


fbgemm_gpu/experimental/gen_ai/bench/quantize_bench.py

Lines changed: 27 additions & 1 deletion
@@ -486,10 +486,16 @@ def print_kernels(kernels: Optional[List[str]]) -> List[QuantizeOpBase]:
     default=None,
     help="If set with grouped mode, repeat input shapes this many times. Comma separated list of groups to benchmark",
 )
+@click.option(
+    "--total-K",
+    default=None,
+    help="If set, adjusts the K values to sum to this number. "
+    "This can help simulate real grouped workloads in backward wgrad.",
+)
 @click.option(
     "--total-M",
     default=None,
-    help="If set, Adjusts the M values to sum to this number. "
+    help="If set, adjusts the M values to sum to this number. "
     "This can help simulate real grouped workloads.",
 )
 @click.option(
@@ -542,6 +548,7 @@ def invoke_main(
     pair_nk: bool,
     grouped: bool,
     groups: Optional[str],
+    total_k: Optional[str],
     total_m: Optional[str],
     no_cuda_graph: bool,
     use_rotating_buffer_bench: bool,
@@ -553,6 +560,14 @@ def invoke_main(
 ):
     if enable_amd_env_vars:
         set_amd_env_vars()
+
+    # Validate that total_m and total_k are mutually exclusive
+    if total_m is not None and total_k is not None:
+        raise ValueError(
+            "total_m and total_k cannot be specified at the same time. "
+            "Please provide only one of them."
+        )
+
     # If kernel filter is provided, parse it. Else, benchmark all kernels.
     all_kernels = kernels.strip().split(",") if kernels else None
     quantize_ops = collect_kernels_to_profile(all_kernels)
@@ -629,6 +644,17 @@ def invoke_main(
                 for g in groups_list
                 for b, _, n, k in MNK
             ]
+        elif total_k:
+            MNK = [
+                [
+                    [b] * g,
+                    [m] * g,
+                    [n] * g,
+                    generate_group_tensor(g, int(total_k)),
+                ]
+                for g in groups_list
+                for b, m, n, _ in MNK
+            ]
         else:
             MNK = [
                 [[b] * g, [m] * g, [n] * g, [k] * g]
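For context on how the new flag feeds the benchmark: generate_group_tensor(g, total), already used for --total-M, produces g per-group sizes that sum to total. A minimal illustrative sketch of that behavior (an assumption for illustration, not FBGEMM's actual helper) might look like:

import random

def split_total_across_groups(g: int, total: int) -> list[int]:
    # Hypothetical stand-in for generate_group_tensor: pick g positive
    # group sizes that sum to `total`, mimicking uneven per-group K
    # dimensions in a wgrad-style grouped workload.
    cuts = sorted(random.sample(range(1, total), g - 1)) if g > 1 else []
    bounds = [0] + cuts + [total]
    return [hi - lo for lo, hi in zip(bounds, bounds[1:])]

# e.g. distribute K across 4 groups, analogous to running
#   quantize_bench ... --grouped --groups 4 --total-K 8192
print(split_total_across_groups(4, 8192))  # four values summing to 8192
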

fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py

Lines changed: 47 additions & 1 deletion
@@ -2084,7 +2084,7 @@ def cuda(self) -> bool:
 @register_quantize_op
 class BF16GroupedGrad(QuantizeOpBase):
     """
-    BF16 grouped matmul with grad inputs backed by cutlass
+    BF16 grouped matmul with dgrad inputs in pretraining backed by cutlass
     """

     def preprocess(self, x, w):
@@ -2126,6 +2126,52 @@ def cuda(self) -> bool:
         return True


+@register_quantize_op
+class BF16GroupedWGrad(QuantizeOpBase):
+    """
+    BF16 grouped matmul with wgrad inputs in pretraining backed by cutlass
+    """
+
+    def preprocess(self, x, w):
+        # Get K values for each group
+        k_values = [xi.shape[1] for xi in x]  # K dimension for each group
+
+        # Convert k_values into sizes tensor
+        k_sizes = torch.tensor(k_values).to(dtype=torch.int64, device=x[0].device)
+
+        x = torch.concat(x, dim=1).contiguous()  # shape: (M, G*K)
+        w = torch.concat(w, dim=1).contiguous()  # shape: (N, G*K)
+
+        # Transpose them to simulate wgrad shapes
+        x = x.t().contiguous()  # shape: (G*K, M)
+        w = w.t().contiguous()  # shape: (G*K, N)
+
+        # Return processed tensors
+        return x, w, k_sizes
+
+    def quantize(self, x, w, k_sizes):
+        return x, w, k_sizes
+
+    def compute(self, x, w, k_sizes):
+        return torch.ops.fbgemm.bf16bf16bf16_grouped_wgrad(x, w, k_sizes)
+
+    def quantize_and_compute(self, x, w, k_sizes):
+        x, w, k_sizes = self.quantize(x, w, k_sizes)
+        return self.compute(x, w, k_sizes)
+
+    @property
+    def name(self) -> str:
+        return "bf16_grouped_wgrad"
+
+    @property
+    def hip(self) -> bool:
+        return False
+
+    @property
+    def cuda(self) -> bool:
+        return True
+
+
 @register_quantize_op
 class BF16GroupedStacked(QuantizeOpBase):
     """
