Skip to content

Commit 565879a

Browse files
jiawenliu64facebook-github-bot
authored andcommitted
Support multiple total-k and total-m in quantize bench
Summary: X-link: facebookresearch/FBGEMM#1915 Pretty helpful for various grouped gemm tuning, especially wgrad grouped gemm Differential Revision: D82700396
1 parent fd32631 commit 565879a

File tree

1 file changed

+10
-4
lines changed

1 file changed

+10
-4
lines changed

fbgemm_gpu/experimental/gen_ai/bench/quantize_bench.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -490,13 +490,15 @@ def print_kernels(kernels: Optional[List[str]]) -> List[QuantizeOpBase]:
490490
"--total-K",
491491
default=None,
492492
help="If set, adjusts the K values to sum to this number. "
493-
"This can help simulate real grouped workloads in backward wgrad.",
493+
"This can help simulate real grouped workloads in backward wgrad. "
494+
"Comma separated list of total-K values to benchmark.",
494495
)
495496
@click.option(
496497
"--total-M",
497498
default=None,
498499
help="If set, adjusts the M values to sum to this number. "
499-
"This can help simulate real grouped workloads.",
500+
"This can help simulate real grouped workloads."
501+
"Comma separated list of total-M values to benchmark.",
500502
)
501503
@click.option(
502504
"--no-cuda-graph",
@@ -634,25 +636,29 @@ def invoke_main(
634636
if groups:
635637
groups_list = [int(g) for g in groups.strip().split(",")]
636638
if total_m:
639+
total_m_list = [int(tm) for tm in total_m.strip().split(",")]
637640
MNK = [
638641
[
639642
[b] * g,
640-
generate_group_tensor(g, int(total_m)),
643+
generate_group_tensor(g, tm),
641644
[n] * g,
642645
[k] * g,
643646
]
644647
for g in groups_list
648+
for tm in total_m_list
645649
for b, _, n, k in MNK
646650
]
647651
elif total_k:
652+
total_k_list = [int(tk) for tk in total_k.strip().split(",")]
648653
MNK = [
649654
[
650655
[b] * g,
651656
[m] * g,
652657
[n] * g,
653-
generate_group_tensor(g, int(total_k)),
658+
generate_group_tensor(g, tk),
654659
]
655660
for g in groups_list
661+
for tk in total_k_list
656662
for b, m, n, _ in MNK
657663
]
658664
else:

0 commit comments

Comments
 (0)