Skip to content

Commit 58118b3

Browse files
q10 authored and facebook-github-bot committed
Migrate GenAI gqa attn splitk kernels to FBGEMM_LAUNCH_KERNEL, pt 1
Summary: - Migrate GenAI gqa attn splitk kernels to `FBGEMM_LAUNCH_KERNEL`, pt 1 Reviewed By: r-barnes Differential Revision: D81817544
1 parent 53f9e51 commit 58118b3

File tree

2 files changed

+35
-28
lines changed

2 files changed

+35
-28
lines changed

fbgemm_gpu/experimental/gen_ai/src/attention/cuda/gqa_attn_splitk.cu

Lines changed: 32 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
#endif
2525

2626
#include <fbgemm_gpu/utils/vec_quant.cuh>
27+
#include "fbgemm_gpu/utils/kernel_launcher.cuh"
2728

2829
template <typename func_t>
2930
void set_gpu_max_dynamic_shared_memory(
@@ -1447,23 +1448,27 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> gqa_attn_splitk_wmma_impl(
14471448

14481449
#define CALL_GQA_ATTN_SPLITK_WMMA( \
14491450
CACHE_TYPE, NUM_GROUPS, KV_LOAD_T, KV_DATA_TYPE) \
1450-
const auto gqa_fn = gqa_attn_splitk_wmma_kernel< \
1451+
auto gqa_fn = gqa_attn_splitk_wmma_kernel< \
14511452
CACHE_TYPE, \
14521453
NUM_GROUPS, \
14531454
KV_LOAD_T, \
14541455
KV_DATA_TYPE>; \
14551456
if (smem > SMEM_ADJUST_THRESHOLD) { \
14561457
set_gpu_max_dynamic_shared_memory(gqa_fn, smem, XQ.get_device()); \
14571458
} \
1458-
gqa_fn<<<blocks, threads, smem, at::cuda::getCurrentCUDAStream()>>>( \
1459+
FBGEMM_LAUNCH_KERNEL( \
1460+
(gqa_fn), \
1461+
blocks, \
1462+
threads, \
1463+
smem, \
1464+
at::cuda::getCurrentCUDAStream(), \
14591465
XQ.packed_accessor32<at::BFloat16, 4, at::RestrictPtrTraits>(), \
14601466
cache_K.packed_accessor64<CACHE_TYPE, 4, at::RestrictPtrTraits>(), \
14611467
cache_V.packed_accessor64<CACHE_TYPE, 4, at::RestrictPtrTraits>(), \
14621468
out_splitK.packed_accessor32<float, 4, at::RestrictPtrTraits>(), \
14631469
seq_positions.packed_accessor32<int32_t, 1, at::RestrictPtrTraits>(), \
14641470
metadata.packed_accessor32<float, 4, at::RestrictPtrTraits>(), \
1465-
qk_scale); \
1466-
C10_CUDA_KERNEL_LAUNCH_CHECK()
1471+
qk_scale);
14671472

14681473
if (cache_K.dtype() == at::kBFloat16) {
14691474
CALL_GQA_ATTN_SPLITK_WMMA(
@@ -1486,16 +1491,16 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> gqa_attn_splitk_wmma_impl(
14861491

14871492
#undef CALL_GQA_ATTN_SPLITK_WMMA
14881493

1489-
gqa_attn_splitk_reduce_wmma_kernel<<<
1494+
FBGEMM_LAUNCH_KERNEL(
1495+
(gqa_attn_splitk_reduce_wmma_kernel),
14901496
dim3(B, H),
14911497
dim3(kThreadsPerWarp, D_H / kThreadsPerWarp),
14921498
0,
1493-
at::cuda::getCurrentCUDAStream()>>>(
1499+
at::cuda::getCurrentCUDAStream(),
14941500
out_splitK.packed_accessor32<float, 4, at::RestrictPtrTraits>(),
14951501
metadata.packed_accessor32<float, 4, at::RestrictPtrTraits>(),
14961502
seq_positions.packed_accessor32<int32_t, 1, at::RestrictPtrTraits>(),
14971503
O.packed_accessor32<at::BFloat16, 4, at::RestrictPtrTraits>());
1498-
C10_CUDA_KERNEL_LAUNCH_CHECK();
14991504
return {O, out_splitK, metadata};
15001505
}
15011506

@@ -1545,30 +1550,32 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> gqa_attn_splitk_impl(
15451550
dim3 threads(kThreadsPerWarp, kWarpsPerBlock);
15461551

15471552
if (cache_K.dtype() == at::kBFloat16) {
1548-
gqa_attn_splitk_qk_kernel<<<
1553+
FBGEMM_LAUNCH_KERNEL(
1554+
(gqa_attn_splitk_qk_kernel),
15491555
blocks,
15501556
threads,
15511557
0,
1552-
at::cuda::getCurrentCUDAStream()>>>(
1558+
at::cuda::getCurrentCUDAStream(),
15531559
XQ.packed_accessor32<at::BFloat16, 4, at::RestrictPtrTraits>(),
15541560
cache_K.packed_accessor64<at::BFloat16, 4, at::RestrictPtrTraits>(),
15551561
seq_positions.packed_accessor32<int32_t, 1, at::RestrictPtrTraits>(),
15561562
QK_out.packed_accessor32<float, 3, at::RestrictPtrTraits>());
1557-
C10_CUDA_KERNEL_LAUNCH_CHECK();
15581563
} else {
1559-
#define CALL_MQA_ATTN_SPLITK_QK_INT4_GROUPWISE_KERNEL(NUM_GROUPS, ...) \
1560-
gqa_attn_splitk_qk_int4_kernel<NUM_GROUPS> \
1561-
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>( \
1562-
XQ.packed_accessor32<at::BFloat16, 4, at::RestrictPtrTraits>(), \
1563-
cache_K.packed_accessor64<uint8_t, 4, at::RestrictPtrTraits>(), \
1564-
seq_positions \
1565-
.packed_accessor32<int32_t, 1, at::RestrictPtrTraits>(), \
1566-
QK_out.packed_accessor32<float, 3, at::RestrictPtrTraits>());
1564+
#define CALL_MQA_ATTN_SPLITK_QK_INT4_GROUPWISE_KERNEL(NUM_GROUPS, ...) \
1565+
FBGEMM_LAUNCH_KERNEL( \
1566+
(gqa_attn_splitk_qk_int4_kernel<NUM_GROUPS>), \
1567+
blocks, \
1568+
threads, \
1569+
0, \
1570+
at::cuda::getCurrentCUDAStream(), \
1571+
XQ.packed_accessor32<at::BFloat16, 4, at::RestrictPtrTraits>(), \
1572+
cache_K.packed_accessor64<uint8_t, 4, at::RestrictPtrTraits>(), \
1573+
seq_positions.packed_accessor32<int32_t, 1, at::RestrictPtrTraits>(), \
1574+
QK_out.packed_accessor32<float, 3, at::RestrictPtrTraits>());
15671575

15681576
auto num_groups_ = num_groups ? num_groups.value() : 1;
15691577
CALL_INT4_KERNEL_WITH_KV_GROUPWISE_QUANT_CHECK(
15701578
CALL_MQA_ATTN_SPLITK_QK_INT4_GROUPWISE_KERNEL, num_groups_);
1571-
C10_CUDA_KERNEL_LAUNCH_CHECK();
15721579

15731580
#undef CALL_MQA_ATTN_SPLITK_QK_INT4_GROUPWISE_KERNEL
15741581
}
@@ -1589,16 +1596,16 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> gqa_attn_splitk_impl(
15891596
gqa_attn_splitk_attn_kernel, smem, device);
15901597
}
15911598

1592-
gqa_attn_splitk_attn_kernel<<<
1599+
FBGEMM_LAUNCH_KERNEL(
1600+
(gqa_attn_splitk_attn_kernel),
15931601
blocks,
15941602
threads,
15951603
smem,
1596-
at::cuda::getCurrentCUDAStream()>>>(
1604+
at::cuda::getCurrentCUDAStream(),
15971605
QK_out.packed_accessor32<float, 3, at::RestrictPtrTraits>(),
15981606
attn_out.packed_accessor32<float, 3, at::RestrictPtrTraits>(),
15991607
seq_positions.packed_accessor32<int32_t, 1, at::RestrictPtrTraits>(),
16001608
qk_scale);
1601-
C10_CUDA_KERNEL_LAUNCH_CHECK();
16021609
}
16031610
auto O = at::empty({split_k, B, 1, H, D_H}, XQ.options().dtype(at::kFloat));
16041611
{
@@ -1614,16 +1621,16 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> gqa_attn_splitk_impl(
16141621
set_gpu_max_dynamic_shared_memory(
16151622
gqa_attn_splitk_v_kernel, smem, device);
16161623
}
1617-
gqa_attn_splitk_v_kernel<<<
1624+
FBGEMM_LAUNCH_KERNEL(
1625+
(gqa_attn_splitk_v_kernel),
16181626
blocks,
16191627
threads,
16201628
smem,
1621-
at::cuda::getCurrentCUDAStream()>>>(
1629+
at::cuda::getCurrentCUDAStream(),
16221630
attn_out.packed_accessor32<float, 3, at::RestrictPtrTraits>(),
16231631
cache_V.packed_accessor64<at::BFloat16, 4, at::RestrictPtrTraits>(),
16241632
O.packed_accessor32<float, 5, at::RestrictPtrTraits>(),
16251633
seq_positions.packed_accessor32<int32_t, 1, at::RestrictPtrTraits>());
1626-
C10_CUDA_KERNEL_LAUNCH_CHECK();
16271634
} else {
16281635
#define CALL_MQA_ATTN_SPLITKV_INT4_GROUPWISE_KERNEL(NUM_GROUPS, ...) \
16291636
if (set_max_dynamic_smem) { \

fbgemm_gpu/include/fbgemm_gpu/utils/kernel_launcher.cuh

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -443,7 +443,7 @@ struct KernelLauncher {
443443
#define FBGEMM_LAUNCH_KERNEL(KERNEL, GRID, BLOCK, SMEM, STREAM, ...) \
444444
([&] { \
445445
constexpr auto context = SOURCE_CONTEXT_CURRENT(KERNEL); \
446-
const auto& kernel = KERNEL; \
446+
auto& kernel = KERNEL; \
447447
\
448448
return fbgemm_gpu::utils:: \
449449
KernelLauncher<false, _FKL_BLOCKING_, _FKL_TENSORCHECK_>(context) \
@@ -453,7 +453,7 @@ struct KernelLauncher {
453453
#define FBGEMM_LAUNCH_DSA_KERNEL(KERNEL, GRID, BLOCK, SMEM, STREAM, ...) \
454454
([&] { \
455455
constexpr auto context = SOURCE_CONTEXT_CURRENT(KERNEL); \
456-
decltype(KERNEL)& kernel = KERNEL; \
456+
auto& kernel = KERNEL; \
457457
\
458458
return fbgemm_gpu::utils:: \
459459
KernelLauncher<true, _FKL_BLOCKING_, _FKL_TENSORCHECK_>(context) \
@@ -463,7 +463,7 @@ struct KernelLauncher {
463463
#define FBGEMM_TIME_KERNEL_RUN(KERNEL, GRID, BLOCK, SMEM, STREAM, ...) \
464464
([&] { \
465465
constexpr auto context = SOURCE_CONTEXT_CURRENT(KERNEL); \
466-
decltype(KERNEL)& kernel = KERNEL; \
466+
auto& kernel = KERNEL; \
467467
\
468468
return fbgemm_gpu::utils:: \
469469
KernelLauncher<false, _FKL_BLOCKING_, _FKL_TENSORCHECK_, true>( \

0 commit comments

Comments (0)