@@ -490,13 +490,15 @@ def print_kernels(kernels: Optional[List[str]]) -> List[QuantizeOpBase]:
490
490
"--total-K" ,
491
491
default = None ,
492
492
help = "If set, adjusts the K values to sum to this number. "
493
- "This can help simulate real grouped workloads in backward wgrad." ,
493
+ "This can help simulate real grouped workloads in backward wgrad. "
494
+ "Comma separated list of total-K values to benchmark." ,
494
495
)
495
496
@click .option (
496
497
"--total-M" ,
497
498
default = None ,
498
499
help = "If set, adjusts the M values to sum to this number. "
499
- "This can help simulate real grouped workloads." ,
500
+ "This can help simulate real grouped workloads."
501
+ "Comma separated list of total-M values to benchmark." ,
500
502
)
501
503
@click .option (
502
504
"--no-cuda-graph" ,
@@ -634,25 +636,29 @@ def invoke_main(
634
636
if groups :
635
637
groups_list = [int (g ) for g in groups .strip ().split ("," )]
636
638
if total_m :
639
+ total_m_list = [int (tm ) for tm in total_m .strip ().split ("," )]
637
640
MNK = [
638
641
[
639
642
[b ] * g ,
640
- generate_group_tensor (g , int ( total_m ) ),
643
+ generate_group_tensor (g , tm ),
641
644
[n ] * g ,
642
645
[k ] * g ,
643
646
]
644
647
for g in groups_list
648
+ for tm in total_m_list
645
649
for b , _ , n , k in MNK
646
650
]
647
651
elif total_k :
652
+ total_k_list = [int (tk ) for tk in total_k .strip ().split ("," )]
648
653
MNK = [
649
654
[
650
655
[b ] * g ,
651
656
[m ] * g ,
652
657
[n ] * g ,
653
- generate_group_tensor (g , int ( total_k ) ),
658
+ generate_group_tensor (g , tk ),
654
659
]
655
660
for g in groups_list
661
+ for tk in total_k_list
656
662
for b , m , n , _ in MNK
657
663
]
658
664
else :
0 commit comments