diff --git a/fbgemm_gpu/src/metric_ops/metric_ops.cu b/fbgemm_gpu/src/metric_ops/metric_ops.cu index 70f9b700d3..4a3b0346c9 100644 --- a/fbgemm_gpu/src/metric_ops/metric_ops.cu +++ b/fbgemm_gpu/src/metric_ops/metric_ops.cu @@ -17,6 +17,7 @@ #include "fbgemm_gpu/utils/cuda_prelude.cuh" #include "fbgemm_gpu/utils/dispatch_macros.h" #include "fbgemm_gpu/utils/inclusive_sum_scan.cuh" +#include "fbgemm_gpu/utils/kernel_launcher.cuh" #include "metric_ops.h" constexpr int MAX_ENTRIES_PER_BLOCK = 512; @@ -251,28 +252,29 @@ at::Tensor batch_auc( auto max_smem_size = at::cuda::getCurrentDeviceProperties()->sharedMemPerBlock; -#define LAUNCH_AUC_KERNEL(pad) \ - typedef cub::BlockScan BlockScan; \ - TORCH_CHECK( \ - sizeof(BlockScan::TempStorage) + \ - ((MAX_ENTRIES_PER_BLOCK * 2 + 3) * sizeof(acc_t)) <= \ - max_smem_size) \ - auc_kernel \ - <<>>( \ - output.data_ptr(), \ - indices.data_ptr(), \ - labels.data_ptr(), \ - weights.data_ptr(), \ - num_blocks > 1 ? block_flags.data_ptr() : nullptr, \ - num_blocks > 1 ? block_sums.data_ptr() : nullptr, \ - num_entries, \ - last_block_num_entries, \ - padded_num_entries_per_block, \ - num_blocks); \ - C10_CUDA_KERNEL_LAUNCH_CHECK(); +#define LAUNCH_AUC_KERNEL(pad) \ + typedef cub::BlockScan BlockScan; \ + TORCH_CHECK( \ + sizeof(BlockScan::TempStorage) + \ + ((MAX_ENTRIES_PER_BLOCK * 2 + 3) * sizeof(acc_t)) <= \ + max_smem_size) \ + \ + FBGEMM_LAUNCH_KERNEL( \ + (auc_kernel), \ + dim3(grid_size), \ + dim3(NUM_THREADS_PER_BLOCK), \ + 0, \ + at::cuda::getCurrentCUDAStream(), \ + output.data_ptr(), \ + indices.data_ptr(), \ + labels.data_ptr(), \ + weights.data_ptr(), \ + num_blocks > 1 ? block_flags.data_ptr() : nullptr, \ + num_blocks > 1 ? block_sums.data_ptr() : nullptr, \ + num_entries, \ + last_block_num_entries, \ + padded_num_entries_per_block, \ + num_blocks); AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "auc_wrapper_1", [&] { FBGEMM_DISPATCH_ALL_TYPES(labels.scalar_type(), "auc_wrapper_2", [&] { @@ -285,7 +287,6 @@ at::Tensor batch_auc( } else { LAUNCH_AUC_KERNEL(2) } - C10_CUDA_KERNEL_LAUNCH_CHECK(); }); }); });