 #endif

 #include <fbgemm_gpu/utils/vec_quant.cuh>
+#include "fbgemm_gpu/utils/kernel_launcher.cuh"

 template <typename func_t>
 void set_gpu_max_dynamic_shared_memory(
@@ -1447,23 +1448,27 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> gqa_attn_splitk_wmma_impl(

 #define CALL_GQA_ATTN_SPLITK_WMMA( \
     CACHE_TYPE, NUM_GROUPS, KV_LOAD_T, KV_DATA_TYPE) \
-  const auto gqa_fn = gqa_attn_splitk_wmma_kernel< \
+  auto gqa_fn = gqa_attn_splitk_wmma_kernel< \
       CACHE_TYPE, \
       NUM_GROUPS, \
       KV_LOAD_T, \
       KV_DATA_TYPE>; \
   if (smem > SMEM_ADJUST_THRESHOLD) { \
     set_gpu_max_dynamic_shared_memory(gqa_fn, smem, XQ.get_device()); \
   } \
-  gqa_fn<<<blocks, threads, smem, at::cuda::getCurrentCUDAStream()>>>( \
+  FBGEMM_LAUNCH_KERNEL( \
+      (gqa_fn), \
+      blocks, \
+      threads, \
+      smem, \
+      at::cuda::getCurrentCUDAStream(), \
       XQ.packed_accessor32<at::BFloat16, 4, at::RestrictPtrTraits>(), \
       cache_K.packed_accessor64<CACHE_TYPE, 4, at::RestrictPtrTraits>(), \
       cache_V.packed_accessor64<CACHE_TYPE, 4, at::RestrictPtrTraits>(), \
       out_splitK.packed_accessor32<float, 4, at::RestrictPtrTraits>(), \
       seq_positions.packed_accessor32<int32_t, 1, at::RestrictPtrTraits>(), \
       metadata.packed_accessor32<float, 4, at::RestrictPtrTraits>(), \
-      qk_scale); \
-  C10_CUDA_KERNEL_LAUNCH_CHECK()
+      qk_scale);

   if (cache_K.dtype() == at::kBFloat16) {
     CALL_GQA_ATTN_SPLITK_WMMA(
@@ -1486,16 +1491,16 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> gqa_attn_splitk_wmma_impl(

 #undef CALL_GQA_ATTN_SPLITK_WMMA

-  gqa_attn_splitk_reduce_wmma_kernel<<<
+  FBGEMM_LAUNCH_KERNEL(
+      (gqa_attn_splitk_reduce_wmma_kernel),
       dim3(B, H),
       dim3(kThreadsPerWarp, D_H / kThreadsPerWarp),
       0,
-      at::cuda::getCurrentCUDAStream()>>>(
+      at::cuda::getCurrentCUDAStream(),
       out_splitK.packed_accessor32<float, 4, at::RestrictPtrTraits>(),
       metadata.packed_accessor32<float, 4, at::RestrictPtrTraits>(),
       seq_positions.packed_accessor32<int32_t, 1, at::RestrictPtrTraits>(),
       O.packed_accessor32<at::BFloat16, 4, at::RestrictPtrTraits>());
-  C10_CUDA_KERNEL_LAUNCH_CHECK();
   return {O, out_splitK, metadata};
 }

@@ -1545,30 +1550,32 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> gqa_attn_splitk_impl(
   dim3 threads(kThreadsPerWarp, kWarpsPerBlock);

   if (cache_K.dtype() == at::kBFloat16) {
-    gqa_attn_splitk_qk_kernel<<<
+    FBGEMM_LAUNCH_KERNEL(
+        (gqa_attn_splitk_qk_kernel),
         blocks,
         threads,
         0,
-        at::cuda::getCurrentCUDAStream()>>>(
+        at::cuda::getCurrentCUDAStream(),
         XQ.packed_accessor32<at::BFloat16, 4, at::RestrictPtrTraits>(),
         cache_K.packed_accessor64<at::BFloat16, 4, at::RestrictPtrTraits>(),
         seq_positions.packed_accessor32<int32_t, 1, at::RestrictPtrTraits>(),
         QK_out.packed_accessor32<float, 3, at::RestrictPtrTraits>());
-    C10_CUDA_KERNEL_LAUNCH_CHECK();
   } else {
-#define CALL_MQA_ATTN_SPLITK_QK_INT4_GROUPWISE_KERNEL(NUM_GROUPS, ...) \
-  gqa_attn_splitk_qk_int4_kernel<NUM_GROUPS> \
-      <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>( \
-          XQ.packed_accessor32<at::BFloat16, 4, at::RestrictPtrTraits>(), \
-          cache_K.packed_accessor64<uint8_t, 4, at::RestrictPtrTraits>(), \
-          seq_positions \
-              .packed_accessor32<int32_t, 1, at::RestrictPtrTraits>(), \
-          QK_out.packed_accessor32<float, 3, at::RestrictPtrTraits>());
+#define CALL_MQA_ATTN_SPLITK_QK_INT4_GROUPWISE_KERNEL(NUM_GROUPS, ...) \
+  FBGEMM_LAUNCH_KERNEL( \
+      (gqa_attn_splitk_qk_int4_kernel<NUM_GROUPS>), \
+      blocks, \
+      threads, \
+      0, \
+      at::cuda::getCurrentCUDAStream(), \
+      XQ.packed_accessor32<at::BFloat16, 4, at::RestrictPtrTraits>(), \
+      cache_K.packed_accessor64<uint8_t, 4, at::RestrictPtrTraits>(), \
+      seq_positions.packed_accessor32<int32_t, 1, at::RestrictPtrTraits>(), \
+      QK_out.packed_accessor32<float, 3, at::RestrictPtrTraits>());

     auto num_groups_ = num_groups ? num_groups.value() : 1;
     CALL_INT4_KERNEL_WITH_KV_GROUPWISE_QUANT_CHECK(
         CALL_MQA_ATTN_SPLITK_QK_INT4_GROUPWISE_KERNEL, num_groups_);
-    C10_CUDA_KERNEL_LAUNCH_CHECK();

 #undef CALL_MQA_ATTN_SPLITK_QK_INT4_GROUPWISE_KERNEL
   }
@@ -1589,16 +1596,16 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> gqa_attn_splitk_impl(
           gqa_attn_splitk_attn_kernel, smem, device);
     }

-    gqa_attn_splitk_attn_kernel<<<
+    FBGEMM_LAUNCH_KERNEL(
+        (gqa_attn_splitk_attn_kernel),
         blocks,
         threads,
         smem,
-        at::cuda::getCurrentCUDAStream()>>>(
+        at::cuda::getCurrentCUDAStream(),
         QK_out.packed_accessor32<float, 3, at::RestrictPtrTraits>(),
         attn_out.packed_accessor32<float, 3, at::RestrictPtrTraits>(),
         seq_positions.packed_accessor32<int32_t, 1, at::RestrictPtrTraits>(),
         qk_scale);
-    C10_CUDA_KERNEL_LAUNCH_CHECK();
   }
   auto O = at::empty({split_k, B, 1, H, D_H}, XQ.options().dtype(at::kFloat));
   {
@@ -1614,16 +1621,16 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> gqa_attn_splitk_impl(
       set_gpu_max_dynamic_shared_memory(
           gqa_attn_splitk_v_kernel, smem, device);
     }
-    gqa_attn_splitk_v_kernel<<<
+    FBGEMM_LAUNCH_KERNEL(
+        (gqa_attn_splitk_v_kernel),
         blocks,
         threads,
         smem,
-        at::cuda::getCurrentCUDAStream()>>>(
+        at::cuda::getCurrentCUDAStream(),
         attn_out.packed_accessor32<float, 3, at::RestrictPtrTraits>(),
         cache_V.packed_accessor64<at::BFloat16, 4, at::RestrictPtrTraits>(),
         O.packed_accessor32<float, 5, at::RestrictPtrTraits>(),
         seq_positions.packed_accessor32<int32_t, 1, at::RestrictPtrTraits>());
-    C10_CUDA_KERNEL_LAUNCH_CHECK();
   } else {
 #define CALL_MQA_ATTN_SPLITKV_INT4_GROUPWISE_KERNEL(NUM_GROUPS, ...) \
   if (set_max_dynamic_smem) { \
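
The pattern throughout this diff is mechanical: each raw <<<grid, block, smem, stream>>> launch plus its trailing C10_CUDA_KERNEL_LAUNCH_CHECK() becomes a single FBGEMM_LAUNCH_KERNEL(kernel, grid, block, smem, stream, args...) call. As a rough mental model only, a minimal stand-in could look like the hypothetical SKETCH_LAUNCH_KERNEL below; the real macro lives in fbgemm_gpu/utils/kernel_launcher.cuh and presumably does more than this (e.g. richer launch diagnostics).

// Hypothetical illustration only -- not the FBGEMM implementation.
#include <c10/cuda/CUDAException.h> // C10_CUDA_KERNEL_LAUNCH_CHECK

#define SKETCH_LAUNCH_KERNEL(KERNEL, GRID, BLOCK, SMEM, STREAM, ...)  \
  do {                                                                \
    /* launch on the caller-provided stream */                        \
    (KERNEL)<<<(GRID), (BLOCK), (SMEM), (STREAM)>>>(__VA_ARGS__);     \
    /* surface launch errors, so call sites drop the explicit check */\
    C10_CUDA_KERNEL_LAUNCH_CHECK();                                   \
  } while (0)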