
Commit 53f9e51

jiawenliu64 authored and facebook-github-bot committed
Back out "Enable CUTLASS grouped GEMM for pretraining wgrad on GB200 and H100" (#4892)
Summary:
Pull Request resolved: #4892
X-link: facebookresearch/FBGEMM#1918

Original commit changeset: 302c7b81a9c0
Original Phabricator Diff: D82325651

Will add it back soon when the relocation issue in fbgemm is resolved.

Reviewed By: cthi

Differential Revision: D82746070

fbshipit-source-id: e2a7a1a72e2f254c7a8be814e299cd47c5e55738
1 parent fd32631 commit 53f9e51

43 files changed: +2 / -2,594 lines changed

fbgemm_gpu/experimental/gen_ai/bench/quantize_bench.py

Lines changed: 1 addition & 27 deletions
@@ -486,16 +486,10 @@ def print_kernels(kernels: Optional[List[str]]) -> List[QuantizeOpBase]:
     default=None,
     help="If set with grouped mode, repeat input shapes this many times. Comma separated list of groups to benchmark",
 )
-@click.option(
-    "--total-K",
-    default=None,
-    help="If set, adjusts the K values to sum to this number. "
-    "This can help simulate real grouped workloads in backward wgrad.",
-)
 @click.option(
     "--total-M",
     default=None,
-    help="If set, adjusts the M values to sum to this number. "
+    help="If set, Adjusts the M values to sum to this number. "
     "This can help simulate real grouped workloads.",
 )
 @click.option(
@@ -548,7 +542,6 @@ def invoke_main(
     pair_nk: bool,
     grouped: bool,
     groups: Optional[str],
-    total_k: Optional[str],
     total_m: Optional[str],
     no_cuda_graph: bool,
     use_rotating_buffer_bench: bool,
@@ -560,14 +553,6 @@ def invoke_main(
 ):
     if enable_amd_env_vars:
         set_amd_env_vars()
-
-    # Validate that total_m and total_k are mutually exclusive
-    if total_m is not None and total_k is not None:
-        raise ValueError(
-            "total_m and total_k cannot be specified at the same time. "
-            "Please provide only one of them."
-        )
-
     # If kernel filter is provided, parse it. Else, benchmark all kernels.
     all_kernels = kernels.strip().split(",") if kernels else None
     quantize_ops = collect_kernels_to_profile(all_kernels)
@@ -644,17 +629,6 @@ def invoke_main(
                 for g in groups_list
                 for b, _, n, k in MNK
             ]
-        elif total_k:
-            MNK = [
-                [
-                    [b] * g,
-                    [m] * g,
-                    [n] * g,
-                    generate_group_tensor(g, int(total_k)),
-                ]
-                for g in groups_list
-                for b, m, n, _ in MNK
-            ]
         else:
            MNK = [
                [[b] * g, [m] * g, [n] * g, [k] * g]
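
For context on the reverted --total-K path above: the removed branch asked generate_group_tensor to produce per-group K values summing to the requested total, so grouped wgrad benchmarks could mimic realistic group-size skew, while B, M, and N stayed replicated across groups. The snippet below is a minimal, hypothetical sketch of that idea; split_total_into_groups is an illustrative stand-in, not the actual generate_group_tensor implementation, which is not part of this diff.

# Hypothetical helper illustrating the kind of split the removed --total-K
# branch relied on: g per-group K values that sum exactly to `total`.
def split_total_into_groups(g: int, total: int) -> list[int]:
    base, rem = divmod(total, g)
    # Spread the remainder over the first `rem` groups so the sizes sum exactly.
    return [base + (1 if i < rem else 0) for i in range(g)]

# Example: 4 groups with a target total K of 4097.
k_per_group = split_total_into_groups(4, 4097)
assert sum(k_per_group) == 4097  # [1025, 1024, 1024, 1024]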

fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py

Lines changed: 1 addition & 47 deletions
@@ -2084,7 +2084,7 @@ def cuda(self) -> bool:
 @register_quantize_op
 class BF16GroupedGrad(QuantizeOpBase):
     """
-    BF16 grouped matmul with dgrad inputs in pretraining backed by cutlass
+    BF16 grouped matmul with grad inputs backed by cutlass
     """
 
     def preprocess(self, x, w):
@@ -2126,52 +2126,6 @@ def cuda(self) -> bool:
         return True
 
 
-@register_quantize_op
-class BF16GroupedWGrad(QuantizeOpBase):
-    """
-    BF16 grouped matmul with wgrad inputs in pretraining backed by cutlass
-    """
-
-    def preprocess(self, x, w):
-        # Get K values for each group
-        k_values = [xi.shape[1] for xi in x]  # K dimension for each group
-
-        # Convert k_values into sizes tensor
-        k_sizes = torch.tensor(k_values).to(dtype=torch.int64, device=x[0].device)
-
-        x = torch.concat(x, dim=1).contiguous()  # shape: (M, G*K)
-        w = torch.concat(w, dim=1).contiguous()  # shape: (N, G*K)
-
-        # Transpose the follows to simulate wgrad shapes
-        x = x.t().contiguous()  # shape: (G*K, M)
-        w = w.t().contiguous()  # shape: (G*K, N)
-
-        # Return processed tensors
-        return x, w, k_sizes
-
-    def quantize(self, x, w, k_sizes):
-        return x, w, k_sizes
-
-    def compute(self, x, w, k_sizes):
-        return torch.ops.fbgemm.bf16bf16bf16_grouped_wgrad(x, w, k_sizes)
-
-    def quantize_and_compute(self, x, w, k_sizes):
-        x, w, k_sizes = self.quantize(x, w, k_sizes)
-        return self.compute(x, w, k_sizes)
-
-    @property
-    def name(self) -> str:
-        return "bf16_grouped_wgrad"
-
-    @property
-    def hip(self) -> bool:
-        return False
-
-    @property
-    def cuda(self) -> bool:
-        return True
-
-
 @register_quantize_op
 class BF16GroupedStacked(QuantizeOpBase):
     """

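To make the reverted wgrad preprocessing concrete, here is a minimal sketch of the shape handling that the removed BF16GroupedWGrad.preprocess performed, run on dummy CPU tensors. Only torch is assumed; the sizes are made-up illustration values, and the backed-out torch.ops.fbgemm.bf16bf16bf16_grouped_wgrad kernel is deliberately not invoked.

import torch

# Per-group activations (M, K_g) and gradients (N, K_g); sizes are illustrative.
M, N, group_ks = 128, 256, [64, 32, 96]
x = [torch.randn(M, k, dtype=torch.bfloat16) for k in group_ks]
w = [torch.randn(N, k, dtype=torch.bfloat16) for k in group_ks]

# Group boundaries along K as an int64 sizes tensor (mirrors the removed code).
k_sizes = torch.tensor(group_ks, dtype=torch.int64)

# Concatenate along K, then transpose so the grouped K dimension leads:
# (sum(K_g), M) and (sum(K_g), N), the layout the wgrad kernel consumed.
x_t = torch.concat(x, dim=1).t().contiguous()
w_t = torch.concat(w, dim=1).t().contiguous()

assert x_t.shape == (sum(group_ks), M) and w_t.shape == (sum(group_ks), N)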