From 673c80bb5e3e5cdfd23cc5ff221e1c5d1cc7c495 Mon Sep 17 00:00:00 2001 From: Sudharshan Govindan Date: Fri, 25 Jul 2025 01:41:07 +0000 Subject: [PATCH 1/5] Changed fetching of warp threads for ROCM --- csrc/selective_scan/reverse_scan.cuh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/selective_scan/reverse_scan.cuh b/csrc/selective_scan/reverse_scan.cuh index d19397879..3c6a10e23 100644 --- a/csrc/selective_scan/reverse_scan.cuh +++ b/csrc/selective_scan/reverse_scan.cuh @@ -101,7 +101,7 @@ struct WarpReverseScan { #ifndef USE_ROCM #define WARP_THREADS CUB_WARP_THREADS(0) #else - #define WARP_THREADS HIPCUB_WARP_THREADS + #define WARP_THREADS HIPCUB_DEVICE_WARP_THREADS #endif static constexpr bool IS_ARCH_WARP = (LOGICAL_WARP_THREADS == WARP_THREADS); /// The number of warp scan steps From f13569434cccad64ecf4e152c3a249391807e8e4 Mon Sep 17 00:00:00 2001 From: Sudharshan Govindan Date: Fri, 25 Jul 2025 17:15:42 +0000 Subject: [PATCH 2/5] Changed comments --- csrc/selective_scan/reverse_scan.cuh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/selective_scan/reverse_scan.cuh b/csrc/selective_scan/reverse_scan.cuh index 3c6a10e23..4e91c0b14 100644 --- a/csrc/selective_scan/reverse_scan.cuh +++ b/csrc/selective_scan/reverse_scan.cuh @@ -96,7 +96,7 @@ struct WarpReverseScan { /// Whether the logical warp size and the PTX warp size coincide - // In hipcub, warp_threads is defined as HIPCUB_WARP_THREADS ::rocprim::warp_size() + // In hipcub, warp_threads is defined as HIPCUB_DEVICE_WARP_THREADS ::rocprim::device_warp_size() // While in cub, it's defined as a macro that takes a redundant unused argument. #ifndef USE_ROCM #define WARP_THREADS CUB_WARP_THREADS(0) From 1cad0419cef76c20e4cf3cfc49a30c7d9f772c8a Mon Sep 17 00:00:00 2001 From: Sudharshan Govindan Date: Fri, 25 Jul 2025 23:35:33 +0000 Subject: [PATCH 3/5] Change because device warp size will also be deprecated --- csrc/selective_scan/reverse_scan.cuh | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/csrc/selective_scan/reverse_scan.cuh b/csrc/selective_scan/reverse_scan.cuh index 4e91c0b14..e1369cc0b 100644 --- a/csrc/selective_scan/reverse_scan.cuh +++ b/csrc/selective_scan/reverse_scan.cuh @@ -96,12 +96,10 @@ struct WarpReverseScan { /// Whether the logical warp size and the PTX warp size coincide - // In hipcub, warp_threads is defined as HIPCUB_DEVICE_WARP_THREADS ::rocprim::device_warp_size() - // While in cub, it's defined as a macro that takes a redundant unused argument. #ifndef USE_ROCM #define WARP_THREADS CUB_WARP_THREADS(0) #else - #define WARP_THREADS HIPCUB_DEVICE_WARP_THREADS + #define WARP_THREADS rocprim::arch::wavefront::max_size() #endif static constexpr bool IS_ARCH_WARP = (LOGICAL_WARP_THREADS == WARP_THREADS); /// The number of warp scan steps From ef8c6ea5a10f0ea1cfdd2f17d9a876322d8483b8 Mon Sep 17 00:00:00 2001 From: Sudharshan Govindan Date: Wed, 30 Jul 2025 18:10:28 +0000 Subject: [PATCH 4/5] Added check for rocm major version to swap between ways of getting WARP_THREADS --- csrc/selective_scan/reverse_scan.cuh | 6 +++++- setup.py | 1 + 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/csrc/selective_scan/reverse_scan.cuh b/csrc/selective_scan/reverse_scan.cuh index e1369cc0b..c40834584 100644 --- a/csrc/selective_scan/reverse_scan.cuh +++ b/csrc/selective_scan/reverse_scan.cuh @@ -99,7 +99,11 @@ struct WarpReverseScan { #ifndef USE_ROCM #define WARP_THREADS CUB_WARP_THREADS(0) #else - #define WARP_THREADS rocprim::arch::wavefront::max_size() + #if ROCM_MAJOR_VERSION >= 7 + #define WARP_THREADS rocprim::arch::wavefront::max_size() + #else + #define WARP_THREADS HIPCUB_WARP_THREAD + #endif #endif static constexpr bool IS_ARCH_WARP = (LOGICAL_WARP_THREADS == WARP_THREADS); /// The number of warp scan steps diff --git a/setup.py b/setup.py index f61ca90d3..2ed98fad9 100755 --- a/setup.py +++ b/setup.py @@ -157,6 +157,7 @@ def append_nvcc_threads(nvcc_extra_args): "Refer to the README.md for detailed instructions.", UserWarning ) + cc_flag.append(f"-DROCM_MAJOR_VERSION={hip_version.major}") cc_flag.append("-DBUILD_PYTHON_PACKAGE") From a495f794532b26441969dc9cde5f9b6c8440dd77 Mon Sep 17 00:00:00 2001 From: Sudharshan Govindan Date: Wed, 30 Jul 2025 18:23:18 +0000 Subject: [PATCH 5/5] bug fix for HIPCUB_WARP_THREADS --- csrc/selective_scan/reverse_scan.cuh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/selective_scan/reverse_scan.cuh b/csrc/selective_scan/reverse_scan.cuh index c40834584..0f4346b20 100644 --- a/csrc/selective_scan/reverse_scan.cuh +++ b/csrc/selective_scan/reverse_scan.cuh @@ -102,7 +102,7 @@ struct WarpReverseScan { #if ROCM_MAJOR_VERSION >= 7 #define WARP_THREADS rocprim::arch::wavefront::max_size() #else - #define WARP_THREADS HIPCUB_WARP_THREAD + #define WARP_THREADS HIPCUB_WARP_THREADS #endif #endif static constexpr bool IS_ARCH_WARP = (LOGICAL_WARP_THREADS == WARP_THREADS);