
Commit dbdae51

Support TMA epilogue for grouped GEMM and add pingpong ptr-array & grouped GEMM (#1795)

1 parent 21d0534 · commit dbdae51

23 files changed: +2359 −347 lines

examples/56_hopper_ptr_array_batched_gemm/56_hopper_ptr_array_batched_gemm.cu

Lines changed: 75 additions & 45 deletions
@@ -95,40 +95,66 @@ constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value; // M
 using ElementAccumulator = float;                                   // Element type for internal accumulation
 using ArchTag = cutlass::arch::Sm90;                                // Tag indicating the minimum SM that supports the intended feature
 using OperatorClass = cutlass::arch::OpClassTensorOp;               // Operator class tag
-using TileShape = Shape<_256,_128,_64>;                             // Threadblock-level tile size
-using ClusterShape = Shape<_1,_2,_1>;                               // Shape of the threadblocks in a cluster
 using StageCountType = cutlass::gemm::collective::StageCountAuto;   // Stage count maximized based on the tile size
-using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperative;   // Kernel to launch
-using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedCooperative;   // Epilogue to launch
-
-using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
-    cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
-    TileShape, ClusterShape,
-    cutlass::epilogue::collective::EpilogueTileAuto,
-    ElementAccumulator, ElementAccumulator,
-    ElementC, LayoutC, AlignmentC,
-    ElementC, LayoutC, AlignmentC,
-    EpilogueSchedule
-  >::CollectiveOp;
-
-using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
-    ArchTag, OperatorClass,
-    ElementA, LayoutA, AlignmentA,
-    ElementB, LayoutB, AlignmentB,
-    ElementAccumulator,
-    TileShape, ClusterShape,
-    cutlass::gemm::collective::StageCountAutoCarveout<
-      static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
-    KernelSchedule
-  >::CollectiveOp;
-
-using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
-    cutlass::gemm::ArrayProblemShape<Shape<int,int,int,int>>,
-    CollectiveMainloop,
-    CollectiveEpilogue
->;
-
-using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
+
+// Different configs for pingpong/cooperative
+struct CooperativeConfig {
+  using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperative;
+  using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedCooperative;
+  using TileShape = Shape<_256,_128,_64>;
+  using ClusterShape = Shape<_1,_2,_1>;
+};
+
+struct PingpongConfig {
+  using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpong;
+  using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong;
+  using TileShape = Shape<_64,_128,_64>;
+  using ClusterShape = Shape<_1,_1,_1>;
+};
+
+template <typename ScheduleConfig>
+struct GemmGivenSchedule {
+  using TileShape = typename ScheduleConfig::TileShape;               // Threadblock-level tile size
+  using ClusterShape = typename ScheduleConfig::ClusterShape;         // Shape of the threadblocks in a cluster
+  using KernelSchedule = typename ScheduleConfig::KernelSchedule;     // Kernel to launch
+  using EpilogueSchedule = typename ScheduleConfig::EpilogueSchedule; // Epilogue to launch
+
+  using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
+      cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
+      TileShape, ClusterShape,
+      cutlass::epilogue::collective::EpilogueTileAuto,
+      ElementAccumulator, ElementAccumulator,
+      ElementC, LayoutC, AlignmentC,
+      ElementC, LayoutC, AlignmentC,
+      EpilogueSchedule
+    >::CollectiveOp;
+
+  using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
+      ArchTag, OperatorClass,
+      ElementA, LayoutA, AlignmentA,
+      ElementB, LayoutB, AlignmentB,
+      ElementAccumulator,
+      TileShape, ClusterShape,
+      cutlass::gemm::collective::StageCountAutoCarveout<
+        static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
+      KernelSchedule
+    >::CollectiveOp;
+
+  using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
+      cutlass::gemm::ArrayProblemShape<Shape<int,int,int,int>>,
+      CollectiveMainloop,
+      CollectiveEpilogue
+  >;
+
+  using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
+};
+
+using GemmKernel = GemmGivenSchedule<CooperativeConfig>::GemmKernel;
+using Gemm = GemmGivenSchedule<CooperativeConfig>::Gemm;
+
+using GemmKernelPingpong = GemmGivenSchedule<PingpongConfig>::GemmKernel;
+using GemmPingpong = GemmGivenSchedule<PingpongConfig>::Gemm;
+

 // Reference device GEMM implementation type
 using DeviceGemmReference = cutlass::reference::device::Gemm<
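The refactor above replaces the single hard-wired schedule with small config structs; a minimal sketch of how the pattern is consumed (MyGemm/MyGemmPingpong are illustrative names, not part of the diff):

// A schedule is now chosen by plugging a config struct into GemmGivenSchedule.
using MyGemm         = GemmGivenSchedule<CooperativeConfig>::Gemm;  // 256x128x64 tile, 1x2x1 cluster
using MyGemmPingpong = GemmGivenSchedule<PingpongConfig>::Gemm;     // 64x128x64 tile, 1x1x1 cluster
// Adding another variant only requires a new struct exposing KernelSchedule,
// EpilogueSchedule, TileShape, and ClusterShape.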
@@ -261,14 +287,14 @@ bool initialize_block(
   int bits_input = cutlass::sizeof_bits<Element>::value;

   if (bits_input == 1) {
-    scope_max = 2;
-    scope_min = 0;
+    scope_max = static_cast<Element>(2);
+    scope_min = static_cast<Element>(0);
   } else if (bits_input <= 8) {
-    scope_max = 2;
-    scope_min = -2;
+    scope_max = static_cast<Element>(2);
+    scope_min = static_cast<Element>(-2);
   } else {
-    scope_max = 8;
-    scope_min = -8;
+    scope_max = static_cast<Element>(8);
+    scope_min = static_cast<Element>(-8);
   }

   cutlass::reference::device::BlockFillRandomUniform(
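The new static_casts presumably keep the assignments to the Element-typed bounds well-formed and warning-free for narrow CUTLASS numeric types as well as for float; a minimal sketch of the pattern (set_bounds is a hypothetical helper, not part of the diff):

#include "cutlass/numeric_types.h"

// Hypothetical helper: bounds stored in Element are constructed explicitly so the
// same code compiles cleanly for float as well as narrow types such as cutlass::half_t.
template <typename Element>
void set_bounds(Element &scope_max, Element &scope_min) {
  scope_max = static_cast<Element>(8);
  scope_min = static_cast<Element>(-8);
}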
@@ -351,15 +377,16 @@ void initialize(const Options &options) {
 }

 /// Populates a Gemm::Arguments structure from the given commandline options
-typename Gemm::Arguments args_from_options(const Options &options)
+template <typename GemmT>
+typename GemmT::Arguments args_from_options(const Options &options)
 {
   cutlass::KernelHardwareInfo hw_info;
   // Change device_id to another value if you are running on a machine with multiple GPUs and wish
   // to use a GPU other than that with device ID 0.
   hw_info.device_id = 0;
   hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);

-  typename Gemm::Arguments arguments{
+  typename GemmT::Arguments arguments{
     cutlass::gemm::GemmUniversalMode::kArray,
     {{options.m, options.n, options.k, options.l}},
     {ptr_A.get(), stride_A, ptr_B.get(), stride_B},
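With the helper templated on the kernel type, the same function builds arguments for either schedule; a minimal usage sketch, mirroring how run<GemmT>() calls it further down:

// Build argument structs for both kernel variants from the same Options instance.
auto args_cooperative = args_from_options<Gemm>(options);
auto args_pingpong    = args_from_options<GemmPingpong>(options);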
@@ -405,20 +432,20 @@ bool verify(const Options &options) {
 }

 /// Execute a given example GEMM computation
-template <typename Gemm>
+template <typename GemmT>
 int run(Options &options)
 {
   allocate(options);
   initialize(options);

   // Instantiate CUTLASS kernel depending on templates
-  Gemm gemm;
+  GemmT gemm;

   // Create a structure of gemm kernel arguments suitable for invoking an instance of Gemm
-  auto arguments = args_from_options(options);
+  auto arguments = args_from_options<GemmT>(options);

   // Using the arguments, query for extra workspace required for matrix multiplication computation
-  size_t workspace_size = Gemm::get_workspace_size(arguments);
+  size_t workspace_size = GemmT::get_workspace_size(arguments);

   // Allocate workspace memory
   cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
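After the workspace is allocated, run() in these examples follows the usual CUTLASS device-adapter sequence; a minimal sketch, assuming the CUTLASS_CHECK helper used throughout the examples:

// Check, initialize, and launch the selected kernel (cooperative or pingpong).
CUTLASS_CHECK(gemm.can_implement(arguments));               // does this kernel support the problem?
CUTLASS_CHECK(gemm.initialize(arguments, workspace.get())); // bind arguments and workspace
CUTLASS_CHECK(gemm.run());                                  // launch on the current stream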
@@ -510,7 +537,10 @@ int main(int argc, char const **args) {
   //

 #if defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED)
+  std::cout << "\n*** Cooperative schedule ***" << std::endl;
   run<Gemm>(options);
+  std::cout << "\n*** Pingpong schedule ***" << std::endl;
+  run<GemmPingpong>(options);
 #endif

   return 0;

examples/57_hopper_grouped_gemm/57_hopper_grouped_gemm.cu

Lines changed: 77 additions & 28 deletions
@@ -117,20 +117,39 @@ constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value; // A
 using ElementAccumulator = float;                                   // Element type for internal accumulation
 using ArchTag = cutlass::arch::Sm90;                                // Tag indicating the minimum SM that supports the intended feature
 using OperatorClass = cutlass::arch::OpClassTensorOp;               // Operator class tag
-using TileShape = Shape<_256,_128,_128>;                            // Threadblock-level tile size
-using ClusterShape = Shape<_2,_2,_1>;                               // Shape of the threadblocks in a cluster
 using StageCountType = cutlass::gemm::collective::StageCountAuto;   // Stage count maximized based on the tile size
-using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperativeFP8FastAccum;   // Kernel to launch
-using EpilogueSchedule = cutlass::epilogue::PtrArrayNoSmemWarpSpecialized;                       // Epilogue to launch

-using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
+// Different configs for pingpong/cooperative
+struct CooperativeConfig {
+  using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperativeFP8FastAccum;
+  using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedCooperative;
+  using TileShape = Shape<_256,_128,_128>;
+  using ClusterShape = Shape<_2,_2,_1>;
+};
+
+struct PingpongConfig {
+  using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpongFP8FastAccum;
+  using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong;
+  using TileShape = Shape<_128,_128,_128>;
+  using ClusterShape = Shape<_2,_1,_1>;
+};
+
+template <typename ScheduleConfig>
+struct GemmGivenSchedule {
+  using TileShape = typename ScheduleConfig::TileShape;               // Threadblock-level tile size
+  using ClusterShape = typename ScheduleConfig::ClusterShape;         // Shape of the threadblocks in a cluster
+  using KernelSchedule = typename ScheduleConfig::KernelSchedule;     // Kernel to launch
+  using EpilogueSchedule = typename ScheduleConfig::EpilogueSchedule; // Epilogue to launch
+
+  using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
     cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
     TileShape, ClusterShape,
     cutlass::epilogue::collective::EpilogueTileAuto,
     ElementAccumulator, ElementAccumulator,
     ElementC, LayoutC *, AlignmentC,
     ElementC, LayoutC *, AlignmentC,
-    EpilogueSchedule
+    EpilogueSchedule,
+    cutlass::epilogue::fusion::LinearCombination<ElementC, ElementAccumulator>
   >::CollectiveOp;

 using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
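The epilogue now uses the TMA warp-specialized schedule together with an explicit LinearCombination fusion; a scalar sketch of the math it applies to each output element (illustration only, the real epilogue operates on accumulator fragments in registers):

#include <cstdio>

int main() {
  // Stand-ins for one accumulator value (ElementAccumulator) and one C value (ElementC).
  float accumulator = 3.0f;
  float c = 2.0f;
  float alpha = 0.5f, beta = 0.5f;
  float d = alpha * accumulator + beta * c;  // D = alpha * (A @ B) + beta * C
  std::printf("D = %f\n", d);                // prints D = 2.500000
  return 0;
}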
@@ -144,13 +163,20 @@ using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder
     KernelSchedule
   >::CollectiveOp;

-using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
-    ProblemShape,
-    CollectiveMainloop,
-    CollectiveEpilogue
->;
+  using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
+      ProblemShape,
+      CollectiveMainloop,
+      CollectiveEpilogue
+  >;

-using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
+  using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
+};
+
+using GemmKernel = GemmGivenSchedule<CooperativeConfig>::GemmKernel;
+using Gemm = GemmGivenSchedule<CooperativeConfig>::Gemm;
+
+using GemmKernelPingpong = GemmGivenSchedule<PingpongConfig>::GemmKernel;
+using GemmPingpong = GemmGivenSchedule<PingpongConfig>::Gemm;

 // Reference device GEMM implementation type
 using DeviceGemmReference = cutlass::reference::device::Gemm<
@@ -271,10 +297,10 @@ struct Options {
     int n = cmd_line_n;
     int k = cmd_line_k;
     if (m < 1) {
-      m = ((rand() % 512) + 1);
+      m = alignment * ((rand() % 64) + 1);
     }
     if (n < 1) {
-      n = ((rand() % 512) + 1);
+      n = alignment * ((rand() % 64) + 1);
     }
     if (k < 1) {
       k = alignment * ((rand() % 64) + 1);
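Random M/N extents are now drawn as multiples of alignment rather than arbitrary values in [1, 512], matching how K was already generated, so every randomized problem satisfies the kernel's alignment requirement. A small sketch of the resulting range (assuming only that alignment is a positive integer, as in this example's Options):

#include <cstdlib>

// (rand() % 64) + 1 lies in [1, 64], so the result is a multiple of `alignment`
// in [alignment, 64 * alignment].
int random_aligned_extent(int alignment) {
  return alignment * ((std::rand() % 64) + 1);
}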
@@ -521,41 +547,58 @@ void initialize(const Options &options) {
 }

 /// Populates a Gemm::Arguments structure from the given commandline options
-typename Gemm::Arguments args_from_options(const Options &options, bool host_problem_shapes_available = true)
+template <typename GemmT>
+typename GemmT::Arguments args_from_options(const Options &options, bool host_problem_shapes_available = true)
 {
   cutlass::KernelHardwareInfo hw_info;
   // Change device_id to another value if you are running on a machine with multiple GPUs and wish
   // to use a GPU other than that with device ID 0.
   hw_info.device_id = 0;
   hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);

-  typename Gemm::EpilogueOutputOp::Params params;
+  typename GemmT::Arguments arguments;
+  decltype(arguments.epilogue.thread) fusion_args;
+
   if (options.alpha != FLT_MAX && options.beta != FLT_MAX) {
     // If both alpha/beta are provided (via cmd line args) and are scalar, i.e., same alpha/beta applies to all batches.
-    params = typename Gemm::EpilogueOutputOp::Params(
-      ElementAccumulator(options.alpha), ElementAccumulator(options.beta));
+    fusion_args.alpha = options.alpha;
+    fusion_args.beta = options.beta;
+    fusion_args.alpha_ptr = nullptr;
+    fusion_args.beta_ptr = nullptr;
+    fusion_args.alpha_ptr_array = nullptr;
+    fusion_args.beta_ptr_array = nullptr;
+    // Single alpha and beta for all groups
+    fusion_args.dAlpha = {cute::_0{}, cute::_0{}, 0};
+    fusion_args.dBeta = {cute::_0{}, cute::_0{}, 0};
   }
   else {
     // If pointers to alpha/beta are provided, i.e., alpha/beta can differ between batches/groups.
-    params = typename Gemm::EpilogueOutputOp::Params(alpha_device.get(), beta_device.get());
+    fusion_args.alpha = 0;
+    fusion_args.beta = 0;
+    fusion_args.alpha_ptr = nullptr;
+    fusion_args.beta_ptr = nullptr;
+    fusion_args.alpha_ptr_array = alpha_device.get();
+    fusion_args.beta_ptr_array = beta_device.get();
+    // One alpha and beta per each group
+    fusion_args.dAlpha = {cute::_0{}, cute::_0{}, 1};
+    fusion_args.dBeta = {cute::_0{}, cute::_0{}, 1};
   }

-  typename Gemm::Arguments arguments;
   if (host_problem_shapes_available) {
-    arguments = typename Gemm::Arguments {
+    arguments = typename GemmT::Arguments {
       cutlass::gemm::GemmUniversalMode::kGrouped,
       {options.groups, problem_sizes.get(), options.problem_sizes_host.data()},
       {ptr_A.get(), stride_A.get(), ptr_B.get(), stride_B.get()},
-      {params, ptr_C.get(), stride_C.get(), ptr_D.get(), stride_D.get()},
+      {fusion_args, ptr_C.get(), stride_C.get(), ptr_D.get(), stride_D.get()},
       hw_info
     };
   }
   else {
-    arguments = typename Gemm::Arguments {
+    arguments = typename GemmT::Arguments {
       cutlass::gemm::GemmUniversalMode::kGrouped,
       {options.groups, problem_sizes.get(), nullptr},
       {ptr_A.get(), stride_A.get(), ptr_B.get(), stride_B.get()},
-      {params, ptr_C.get(), stride_C.get(), ptr_D.get(), stride_D.get()},
+      {fusion_args, ptr_C.get(), stride_C.get(), ptr_D.get(), stride_D.get()},
       hw_info
     };
   }
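With the TMA epilogue, scaling factors are passed through the fusion arguments instead of EpilogueOutputOp::Params; the last entry of dAlpha/dBeta is the group-mode stride (0 broadcasts one scalar to every group, 1 steps to the next per-group value). A minimal sketch under those assumptions (set_scaling is a hypothetical helper; field names follow the diff above, and the surrounding example already pulls in the required CUTLASS/CuTe headers):

// Hypothetical helper switching between one scalar for all groups and one value per group.
template <typename FusionArgs, typename ElementScalar>
void set_scaling(FusionArgs &fusion_args,
                 ElementScalar alpha, ElementScalar beta,
                 ElementScalar const *alpha_per_group = nullptr,  // device array, one entry per group
                 ElementScalar const *beta_per_group  = nullptr) {
  fusion_args.alpha = alpha_per_group ? ElementScalar(0) : alpha;
  fusion_args.beta  = beta_per_group  ? ElementScalar(0) : beta;
  fusion_args.alpha_ptr = nullptr;
  fusion_args.beta_ptr  = nullptr;
  fusion_args.alpha_ptr_array = alpha_per_group;
  fusion_args.beta_ptr_array  = beta_per_group;
  int group_stride = alpha_per_group ? 1 : 0;  // 0: broadcast, 1: advance per group
  fusion_args.dAlpha = {cute::_0{}, cute::_0{}, group_stride};
  fusion_args.dBeta  = {cute::_0{}, cute::_0{}, group_stride};
}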
@@ -605,20 +648,20 @@ bool verify(const Options &options) {
 }

 /// Execute a given example GEMM computation
-template <typename Gemm>
+template <typename GemmT>
 int run(Options &options, bool host_problem_shapes_available = true)
 {
   allocate(options);
   initialize(options);

   // Instantiate CUTLASS kernel depending on templates
-  Gemm gemm;
+  GemmT gemm;

   // Create a structure of gemm kernel arguments suitable for invoking an instance of Gemm
-  auto arguments = args_from_options(options, host_problem_shapes_available);
+  auto arguments = args_from_options<GemmT>(options, host_problem_shapes_available);

   // Using the arguments, query for extra workspace required for matrix multiplication computation
-  size_t workspace_size = Gemm::get_workspace_size(arguments);
+  size_t workspace_size = GemmT::get_workspace_size(arguments);

   // Allocate workspace memory
   cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
@@ -713,8 +756,14 @@ int main(int argc, char const **args) {
   //

 #if defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED)
+  std::cout << "\n*** Cooperative schedule ***" << std::endl;
   run<Gemm>(options);
+  std::cout << "\n*** Cooperative schedule (host problem shapes unavailable) ***" << std::endl;
   run<Gemm>(options, false /*host_problem_shapes_available*/);
+  std::cout << "\n*** Pingpong schedule ***" << std::endl;
+  run<GemmPingpong>(options);
+  std::cout << "\n*** Pingpong schedule (host problem shapes unavailable) ***" << std::endl;
+  run<GemmPingpong>(options, false /*host_problem_shapes_available*/);
 #endif

   return 0;

examples/57_hopper_grouped_gemm/CMakeLists.txt

Lines changed: 2 additions & 2 deletions
@@ -32,10 +32,10 @@
 set(TEST_RANDOM --iterations=0)                            # Random problem sizes
 set(TEST_RANDOM_LARGE_GROUP --groups=500 --iterations=0)   # Random problem sizes

-set(TEST_EPILOGUE --alpha=0.5 --beta=0.7 --iterations=0)   # Random problem sizes
+set(TEST_EPILOGUE --alpha=0.5 --beta=0.5 --iterations=0)   # Random problem sizes
 set(TEST_EPILOGUE_LARGE_GROUP --alpha=1.5 --beta=2.0 --groups=500 --iterations=0)  # Random problem sizes

-set(TEST_EPILOGUE_OP --beta=0.7 --iterations=1)            # Random problem sizes
+set(TEST_EPILOGUE_OP --beta=0.5 --iterations=1)            # Random problem sizes
 set(TEST_EPILOGUE_OP_LARGE_GROUP --alpha=1.5 --iterations=1)  # Random problem sizes

 set(TEST_FIXED --m=2048 --n=5120 --k=8192 --groups=50 --iterations=0)  # Fixed problem sizes
