@@ -117,20 +117,39 @@ constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value; // A
using ElementAccumulator = float;                                 // Element type for internal accumulation
using ArchTag = cutlass::arch::Sm90;                              // Tag indicating the minimum SM that supports the intended feature
using OperatorClass = cutlass::arch::OpClassTensorOp;             // Operator class tag
- using TileShape = Shape<_256,_128,_128>;                        // Threadblock-level tile size
- using ClusterShape = Shape<_2,_2,_1>;                           // Shape of the threadblocks in a cluster
using StageCountType = cutlass::gemm::collective::StageCountAuto; // Stage count maximized based on the tile size
- using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperativeFP8FastAccum; // Kernel to launch
- using EpilogueSchedule = cutlass::epilogue::PtrArrayNoSmemWarpSpecialized; // Epilogue to launch

- using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
+ // Different configs for pingpong/cooperative
+ struct CooperativeConfig {
+   using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperativeFP8FastAccum;
+   using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedCooperative;
+   using TileShape = Shape<_256,_128,_128>;
+   using ClusterShape = Shape<_2,_2,_1>;
+ };
+
+ struct PingpongConfig {
+   using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpongFP8FastAccum;
+   using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong;
+   using TileShape = Shape<_128,_128,_128>;
+   using ClusterShape = Shape<_2,_1,_1>;
+ };
+
+ template <typename ScheduleConfig>
+ struct GemmGivenSchedule {
+   using TileShape = typename ScheduleConfig::TileShape;               // Threadblock-level tile size
+   using ClusterShape = typename ScheduleConfig::ClusterShape;         // Shape of the threadblocks in a cluster
+   using KernelSchedule = typename ScheduleConfig::KernelSchedule;     // Kernel to launch
+   using EpilogueSchedule = typename ScheduleConfig::EpilogueSchedule; // Epilogue to launch
+
+   using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
    cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
    TileShape, ClusterShape,
    cutlass::epilogue::collective::EpilogueTileAuto,
    ElementAccumulator, ElementAccumulator,
    ElementC, LayoutC *, AlignmentC,
    ElementC, LayoutC *, AlignmentC,
-   EpilogueSchedule
+   EpilogueSchedule,
+   cutlass::epilogue::fusion::LinearCombination<ElementC, ElementAccumulator>
  >::CollectiveOp;

using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
@@ -144,13 +163,20 @@ using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder
    KernelSchedule
  >::CollectiveOp;

- using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
-   ProblemShape,
-   CollectiveMainloop,
-   CollectiveEpilogue
- >;
+   using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
+     ProblemShape,
+     CollectiveMainloop,
+     CollectiveEpilogue
+   >;

- using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
+   using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
+ };
+
+ using GemmKernel = GemmGivenSchedule<CooperativeConfig>::GemmKernel;
+ using Gemm = GemmGivenSchedule<CooperativeConfig>::Gemm;
+
+ using GemmKernelPingpong = GemmGivenSchedule<PingpongConfig>::GemmKernel;
+ using GemmPingpong = GemmGivenSchedule<PingpongConfig>::Gemm;

/// Reference device GEMM implementation type
using DeviceGemmReference = cutlass::reference::device::Gemm<
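Both instantiations share one `Arguments` layout, so callers can swap schedules freely. As a minimal sketch of selecting one at compile time (the `UsePingpong` flag is hypothetical, not part of this change):

#include <type_traits>

// Hypothetical compile-time selector over the two aliases defined above.
template <bool UsePingpong>
using SelectedGemm = std::conditional_t<UsePingpong, GemmPingpong, Gemm>;

static_assert(std::is_same_v<SelectedGemm<false>, Gemm>, "cooperative is the default");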
@@ -271,10 +297,10 @@ struct Options {
    int n = cmd_line_n;
    int k = cmd_line_k;
    if (m < 1) {
-     m = ((rand() % 512) + 1);
+     m = alignment * ((rand() % 64) + 1);
    }
    if (n < 1) {
-     n = ((rand() % 512) + 1);
+     n = alignment * ((rand() % 64) + 1);
    }
    if (k < 1) {
      k = alignment * ((rand() % 64) + 1);
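The new initializer keeps randomly chosen extents TMA-friendly: `alignment * ((rand() % 64) + 1)` is always a nonzero multiple of `alignment`, in `[alignment, 64 * alignment]`. A small sketch of the invariant, assuming `alignment` is the element count derived from the 128-bit access width:

#include <cassert>
#include <cstdlib>

// Illustrative check of what the new expression guarantees for each extent.
void check_random_extent(int alignment) {
  int extent = alignment * ((std::rand() % 64) + 1);
  assert(extent % alignment == 0 && extent >= alignment && extent <= 64 * alignment);
}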
@@ -521,41 +547,58 @@ void initialize(const Options &options) {
}

/// Populates a Gemm::Arguments structure from the given commandline options
- typename Gemm::Arguments args_from_options(const Options &options, bool host_problem_shapes_available = true)
+ template <typename GemmT>
+ typename GemmT::Arguments args_from_options(const Options &options, bool host_problem_shapes_available = true)
{
  cutlass::KernelHardwareInfo hw_info;
  // Change device_id to another value if you are running on a machine with multiple GPUs and wish
  // to use a GPU other than that with device ID 0.
  hw_info.device_id = 0;
  hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);

- typename Gemm::EpilogueOutputOp::Params params;
+ typename GemmT::Arguments arguments;
+ decltype(arguments.epilogue.thread) fusion_args;
+
  if (options.alpha != FLT_MAX && options.beta != FLT_MAX) {
    // If both alpha/beta are provided (via cmd line args) and are scalar, i.e., same alpha/beta applies to all batches.
-   params = typename Gemm::EpilogueOutputOp::Params(
-     ElementAccumulator(options.alpha), ElementAccumulator(options.beta));
+   fusion_args.alpha = options.alpha;
+   fusion_args.beta = options.beta;
+   fusion_args.alpha_ptr = nullptr;
+   fusion_args.beta_ptr = nullptr;
+   fusion_args.alpha_ptr_array = nullptr;
+   fusion_args.beta_ptr_array = nullptr;
+   // Single alpha and beta for all groups
+   fusion_args.dAlpha = {cute::_0{}, cute::_0{}, 0};
+   fusion_args.dBeta = {cute::_0{}, cute::_0{}, 0};
  }
  else {
    // If pointers to alpha/beta are provided, i.e., alpha/beta can differ between batches/groups.
-   params = typename Gemm::EpilogueOutputOp::Params(alpha_device.get(), beta_device.get());
+   fusion_args.alpha = 0;
+   fusion_args.beta = 0;
+   fusion_args.alpha_ptr = nullptr;
+   fusion_args.beta_ptr = nullptr;
+   fusion_args.alpha_ptr_array = alpha_device.get();
+   fusion_args.beta_ptr_array = beta_device.get();
+   // One alpha and beta per group
+   fusion_args.dAlpha = {cute::_0{}, cute::_0{}, 1};
+   fusion_args.dBeta = {cute::_0{}, cute::_0{}, 1};
  }

- typename Gemm::Arguments arguments;
  if (host_problem_shapes_available) {
-   arguments = typename Gemm::Arguments{
+   arguments = typename GemmT::Arguments{
      cutlass::gemm::GemmUniversalMode::kGrouped,
      {options.groups, problem_sizes.get(), options.problem_sizes_host.data()},
      {ptr_A.get(), stride_A.get(), ptr_B.get(), stride_B.get()},
-     {params, ptr_C.get(), stride_C.get(), ptr_D.get(), stride_D.get()},
+     {fusion_args, ptr_C.get(), stride_C.get(), ptr_D.get(), stride_D.get()},
      hw_info
    };
  }
  else {
-   arguments = typename Gemm::Arguments{
+   arguments = typename GemmT::Arguments{
      cutlass::gemm::GemmUniversalMode::kGrouped,
      {options.groups, problem_sizes.get(), nullptr},
      {ptr_A.get(), stride_A.get(), ptr_B.get(), stride_B.get()},
-     {params, ptr_C.get(), stride_C.get(), ptr_D.get(), stride_D.get()},
+     {fusion_args, ptr_C.get(), stride_C.get(), ptr_D.get(), stride_D.get()},
      hw_info
    };
  }
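The `dAlpha`/`dBeta` triples are CuTe strides over (M, N, group): a group stride of `0` broadcasts the scalar `alpha`/`beta` to every group, while `1` advances one element per group through `alpha_ptr_array`/`beta_ptr_array`. A hypothetical helper (not CUTLASS code) mirroring that addressing:

// stride_l == 0 -> every group reads the same scalar; stride_l == 1 -> one value per group.
float group_alpha(float alpha_scalar, const float* alpha_ptr_array,
                  int group_idx, int stride_l) {
  return alpha_ptr_array ? alpha_ptr_array[group_idx * stride_l] : alpha_scalar;
}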
@@ -605,20 +648,20 @@ bool verify(const Options &options) {
}

/// Execute a given example GEMM computation
- template <typename Gemm>
+ template <typename GemmT>
int run(Options &options, bool host_problem_shapes_available = true)
{
  allocate(options);
  initialize(options);

  // Instantiate CUTLASS kernel depending on templates
- Gemm gemm;
+ GemmT gemm;

  // Create a structure of gemm kernel arguments suitable for invoking an instance of Gemm
- auto arguments = args_from_options(options, host_problem_shapes_available);
+ auto arguments = args_from_options<GemmT>(options, host_problem_shapes_available);

  // Using the arguments, query for extra workspace required for matrix multiplication computation
- size_t workspace_size = Gemm::get_workspace_size(arguments);
+ size_t workspace_size = GemmT::get_workspace_size(arguments);

  // Allocate workspace memory
  cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
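After the workspace allocation, the function presumably continues with the usual CUTLASS 3.x adapter sequence; a sketch (error handling through the examples' `CUTLASS_CHECK` helper is assumed):

// Standard device-adapter flow, continuing from `workspace` above.
CUTLASS_CHECK(gemm.can_implement(arguments));        // reject unsupported shapes/alignments
CUTLASS_CHECK(gemm.initialize(arguments, workspace.get()));
CUTLASS_CHECK(gemm.run());                           // launch the grouped kernel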
@@ -713,8 +756,14 @@ int main(int argc, char const **args) {
  //

#if defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED)
+ std::cout << "\n*** Cooperative schedule ***" << std::endl;
  run<Gemm>(options);
+ std::cout << "\n*** Cooperative schedule (host problem shapes unavailable) ***" << std::endl;
  run<Gemm>(options, false /*host_problem_shapes_available*/);
+ std::cout << "\n*** Pingpong schedule ***" << std::endl;
+ run<GemmPingpong>(options);
+ std::cout << "\n*** Pingpong schedule (host problem shapes unavailable) ***" << std::endl;
+ run<GemmPingpong>(options, false /*host_problem_shapes_available*/);
#endif

  return 0;
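Both schedules run against the same `options`, so their printed performance can be compared directly. If a standalone timing harness around one launch were wanted, a sketch with CUDA events (illustrative only; the example's `run()` already reports its own numbers):

// Illustrative CUDA-event timing around a single launch (not part of this change).
cudaEvent_t start, stop;
cudaEventCreate(&start);
cudaEventCreate(&stop);
cudaEventRecord(start);
CUTLASS_CHECK(gemm.run());
cudaEventRecord(stop);
cudaEventSynchronize(stop);
float ms = 0.f;
cudaEventElapsedTime(&ms, start, stop);
cudaEventDestroy(start);
cudaEventDestroy(stop);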