Skip to content

Commit 5e6dc8e

Browse files
committed
Add hipsolver bindings and enable lobpcg kernel tests for HIP
1 parent dc89057 commit 5e6dc8e

File tree

18 files changed

+520
-10
lines changed

18 files changed

+520
-10
lines changed

cmake/hip.cmake

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,9 @@ find_package(hipsparse REQUIRED)
3333
find_package(rocrand REQUIRED)
3434
find_package(rocthrust REQUIRED)
3535
find_package(ROCTX)
36+
if(GINKGO_BUILD_LAPACK)
37+
find_package(hipsolver REQUIRED)
38+
endif()
3639

3740
if(GINKGO_HIP_AMD_UNSAFE_ATOMIC AND GINKGO_HIP_VERSION VERSION_GREATER_EQUAL 5)
3841
set(CMAKE_HIP_FLAGS

common/cuda_hip/base/dev_lapack_bindings.hpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,12 @@
88

99
#if defined(GKO_COMPILING_CUDA)
1010
#include "cuda/base/cusolver_bindings.hpp"
11+
#define GKO_DEV_LAPACK_ERROR GKO_CUSOLVER_ERROR
12+
#define DEV_LAPACK_INTERNAL_ERROR CUSOLVER_STATUS_INTERNAL_ERROR
1113
#elif defined(GKO_COMPILING_HIP)
1214
#include "hip/base/hipsolver_bindings.hip.hpp"
15+
#define GKO_DEV_LAPACK_ERROR GKO_HIPSOLVER_ERROR
16+
#define DEV_LAPACK_INTERNAL_ERROR HIPSOLVER_STATUS_INTERNAL_ERROR
1317
#else
1418
#error "Executor definition missing"
1519
#endif

common/cuda_hip/eigensolver/lobpcg_kernels.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,7 @@ void symm_eig(std::shared_ptr<const DefaultExecutor> exec,
125125

126126
int32 host_info = exec->copy_val_to_host(dev_info.get_data());
127127
if (host_info != 0) {
128-
throw GKO_CUSOLVER_ERROR(CUSOLVER_STATUS_INTERNAL_ERROR);
128+
throw GKO_DEV_LAPACK_ERROR(DEV_LAPACK_INTERNAL_ERROR);
129129
}
130130
} catch (std::exception& e) {
131131
std::cout << e.what() << std::endl;
@@ -196,7 +196,7 @@ void symm_generalized_eig(std::shared_ptr<const DefaultExecutor> exec,
196196

197197
int32 host_info = exec->copy_val_to_host(dev_info.get_data());
198198
if (host_info != 0) {
199-
throw GKO_CUSOLVER_ERROR(CUSOLVER_STATUS_INTERNAL_ERROR);
199+
throw GKO_DEV_LAPACK_ERROR(DEV_LAPACK_INTERNAL_ERROR);
200200
}
201201
} catch (std::exception& e) {
202202
std::cout << e.what() << std::endl;
@@ -260,7 +260,7 @@ void b_orthonormalize(std::shared_ptr<const DefaultExecutor> exec,
260260

261261
int32 host_info = exec->copy_val_to_host(dev_info.get_data());
262262
if (host_info != 0) {
263-
throw GKO_CUSOLVER_ERROR(CUSOLVER_STATUS_INTERNAL_ERROR);
263+
throw GKO_DEV_LAPACK_ERROR(DEV_LAPACK_INTERNAL_ERROR);
264264
}
265265
} catch (std::exception& e) {
266266
std::cout << e.what() << std::endl;
@@ -287,7 +287,7 @@ void b_orthonormalize(std::shared_ptr<const DefaultExecutor> exec,
287287

288288
int32 host_info = exec->copy_val_to_host(dev_info.get_data());
289289
if (host_info != 0) {
290-
throw GKO_CUSOLVER_ERROR(CUSOLVER_STATUS_INTERNAL_ERROR);
290+
throw GKO_DEV_LAPACK_ERROR(DEV_LAPACK_INTERNAL_ERROR);
291291
}
292292
} catch (std::exception& e) {
293293
std::cout << e.what() << std::endl;

core/device_hooks/hip_hooks.cpp

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// SPDX-FileCopyrightText: 2017 - 2024 The Ginkgo authors
1+
// SPDX-FileCopyrightText: 2017 - 2025 The Ginkgo authors
22
//
33
// SPDX-License-Identifier: BSD-3-Clause
44

@@ -174,6 +174,12 @@ std::string HipsparseError::get_error(int64)
174174
}
175175

176176

177+
std::string HipsolverError::get_error(int64)
178+
{
179+
return "ginkgo HIP module is not compiled";
180+
}
181+
182+
177183
std::string HipfftError::get_error(int64)
178184
{
179185
return "ginkgo HIP module is not compiled";

core/test/base/exception.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,16 @@ TEST(ExceptionClasses, HipsparseErrorReturnsCorrectWhatMessage)
127127
}
128128

129129

130+
#if GKO_HAVE_LAPACK
131+
TEST(ExceptionClasses, HipsolverErrorReturnsCorrectWhatMessage)
132+
{
133+
gko::HipsolverError error("test_file.cpp", 123, "test_func", 1);
134+
std::string expected = "test_file.cpp:123: test_func: ";
135+
ASSERT_EQ(expected, std::string(error.what()).substr(0, expected.size()));
136+
}
137+
#endif
138+
139+
130140
TEST(ExceptionClasses, HipfftErrorReturnsCorrectWhatMessage)
131141
{
132142
gko::HipfftError error("test_file.cpp", 123, "test_func", 1);

core/test/base/exception_helpers.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,16 @@ TEST(HipError, ReturnsHipsparseError)
153153
}
154154

155155

156+
#if GKO_HAVE_LAPACK
157+
void throws_hipsolver_error() { throw GKO_HIPSOLVER_ERROR(0); }
158+
159+
TEST(HipError, ReturnsHipsolverError)
160+
{
161+
ASSERT_THROW(throws_hipsolver_error(), gko::HipsolverError);
162+
}
163+
#endif
164+
165+
156166
void throws_hipfft_error() { throw GKO_HIPFFT_ERROR(0); }
157167

158168
TEST(HipError, ReturnsHipfftError)

hip/CMakeLists.txt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,9 @@ endif()
120120
if(GINKGO_HAVE_ROCTX)
121121
target_link_libraries(ginkgo_hip PRIVATE roc::roctx)
122122
endif()
123+
if(GINKGO_BUILD_LAPACK)
124+
target_link_libraries(ginkgo_hip PRIVATE roc::hipsolver)
125+
endif()
123126

124127
target_compile_options(
125128
ginkgo_hip

hip/base/exception.hip.cpp

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// SPDX-FileCopyrightText: 2017 - 2024 The Ginkgo authors
1+
// SPDX-FileCopyrightText: 2017 - 2025 The Ginkgo authors
22
//
33
// SPDX-License-Identifier: BSD-3-Clause
44

@@ -11,10 +11,16 @@
1111
#include <hipblas/hipblas.h>
1212
#include <hiprand/hiprand.h>
1313
#include <hipsparse/hipsparse.h>
14+
#if GKO_HAVE_LAPACK
15+
#include <hipsolver/hipsolver.h>
16+
#endif
1417
#else
1518
#include <hipblas.h>
1619
#include <hiprand.h>
1720
#include <hipsparse.h>
21+
#if GKO_HAVE_LAPACK
22+
#include <hipsolver.h>
23+
#endif
1824
#endif
1925

2026

@@ -107,4 +113,29 @@ std::string HipsparseError::get_error(int64 error_code)
107113
}
108114

109115

116+
std::string HipsolverError::get_error(int64 error_code)
117+
{
118+
#if GKO_HAVE_LAPACK
119+
#define GKO_REGISTER_HIPSOLVER_ERROR(error_name) \
120+
if (error_code == static_cast<int64>(error_name)) { \
121+
return #error_name; \
122+
}
123+
GKO_REGISTER_HIPSOLVER_ERROR(HIPSOLVER_STATUS_SUCCESS);
124+
GKO_REGISTER_HIPSOLVER_ERROR(HIPSOLVER_STATUS_NOT_INITIALIZED);
125+
GKO_REGISTER_HIPSOLVER_ERROR(HIPSOLVER_STATUS_ALLOC_FAILED);
126+
GKO_REGISTER_HIPSOLVER_ERROR(HIPSOLVER_STATUS_INVALID_VALUE);
127+
GKO_REGISTER_HIPSOLVER_ERROR(HIPSOLVER_STATUS_ARCH_MISMATCH);
128+
GKO_REGISTER_HIPSOLVER_ERROR(HIPSOLVER_STATUS_EXECUTION_FAILED);
129+
GKO_REGISTER_HIPSOLVER_ERROR(HIPSOLVER_STATUS_INTERNAL_ERROR);
130+
GKO_REGISTER_HIPSOLVER_ERROR(HIPSOLVER_STATUS_MATRIX_TYPE_NOT_SUPPORTED);
131+
GKO_REGISTER_HIPSOLVER_ERROR(HIPSOLVER_STATUS_NOT_SUPPORTED);
132+
return "Unknown error";
133+
134+
#undef GKO_REGISTER_HIPSOLVER_ERROR
135+
#else
136+
return "Ginkgo must be built with LAPACK support to enable hipSOLVER";
137+
#endif
138+
}
139+
140+
110141
} // namespace gko

hip/base/executor.hip.cpp

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// SPDX-FileCopyrightText: 2017 - 2024 The Ginkgo authors
1+
// SPDX-FileCopyrightText: 2017 - 2025 The Ginkgo authors
22
//
33
// SPDX-License-Identifier: BSD-3-Clause
44

@@ -14,6 +14,9 @@
1414
#include "common/cuda_hip/base/runtime.hpp"
1515
#include "hip/base/device.hpp"
1616
#include "hip/base/hipblas_handle.hpp"
17+
#if GKO_HAVE_LAPACK
18+
#include "hip/base/hipsolver_handle.hpp"
19+
#endif
1720
#include "hip/base/hipsparse_handle.hpp"
1821
#include "hip/base/scoped_device_id.hip.hpp"
1922

@@ -260,6 +263,14 @@ void HipExecutor::init_handles()
260263
detail::hip_scoped_device_id_guard g(id);
261264
kernels::hip::hipsparse::destroy_hipsparse_handle(handle);
262265
});
266+
#if GKO_HAVE_LAPACK
267+
this->hipsolver_handle_ = handle_manager<hipsolverDnContext>(
268+
kernels::hip::hipsolver::init(this->get_stream()),
269+
[id](hipsolverDnContext* handle) {
270+
detail::hip_scoped_device_id_guard g(id);
271+
kernels::hip::hipsolver::destroy_hipsolver_handle(handle);
272+
});
273+
#endif
263274
}
264275
}
265276

hip/base/hipblas_bindings.hip.hpp

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// SPDX-FileCopyrightText: 2017 - 2024 The Ginkgo authors
1+
// SPDX-FileCopyrightText: 2017 - 2025 The Ginkgo authors
22
//
33
// SPDX-License-Identifier: BSD-3-Clause
44

@@ -128,6 +128,32 @@ GKO_BIND_HIPBLAS_GEAM(ValueType, detail::not_implemented);
128128
#undef GKO_BIND_HIPBLAS_GEAM
129129

130130

131+
#define GKO_BIND_HIPBLAS_TRMM(ValueType, HipblasName) \
132+
inline void trmm(hipblasHandle_t handle, hipblasSideMode_t side, \
133+
hipblasFillMode_t uplo, hipblasOperation_t trans, \
134+
hipblasDiagType_t diag, int m, int n, \
135+
const ValueType* alpha, const ValueType* a, int lda, \
136+
const ValueType* b, int ldb, ValueType* c, int ldc) \
137+
{ \
138+
GKO_ASSERT_NO_HIPBLAS_ERRORS( \
139+
HipblasName(handle, side, uplo, trans, diag, m, n, \
140+
as_hipblas_type(alpha), as_hipblas_type(a), lda, \
141+
as_hipblas_type(b), ldb, as_hipblas_type(c), ldc)); \
142+
} \
143+
static_assert(true, \
144+
"This assert is used to counter the false positive extra " \
145+
"semi-colon warnings")
146+
147+
GKO_BIND_HIPBLAS_TRMM(float, hipblasStrmm);
148+
GKO_BIND_HIPBLAS_TRMM(double, hipblasDtrmm);
149+
GKO_BIND_HIPBLAS_TRMM(std::complex<float>, hipblasCtrmm);
150+
GKO_BIND_HIPBLAS_TRMM(std::complex<double>, hipblasZtrmm);
151+
template <typename ValueType>
152+
GKO_BIND_HIPBLAS_TRMM(ValueType, detail::not_implemented);
153+
154+
#undef GKO_BIND_HIPBLAS_TRMM
155+
156+
131157
#define GKO_BIND_HIPBLAS_SCAL(ValueType, HipblasName) \
132158
inline void scal(hipblasHandle_t handle, int n, const ValueType* alpha, \
133159
ValueType* x, int incx) \
@@ -255,6 +281,12 @@ using namespace hipblas;
255281
#define BLAS_OP_T HIPBLAS_OP_T
256282
#define BLAS_OP_C HIPBLAS_OP_C
257283

284+
#define BLAS_SIDE_LEFT HIPBLAS_SIDE_LEFT
285+
#define BLAS_SIDE_RIGHT HIPBLAS_SIDE_RIGHT
286+
287+
#define BLAS_DIAG_UNIT HIPBLAS_DIAG_UNIT
288+
#define BLAS_DIAG_NONUNIT HIPBLAS_DIAG_NON_UNIT
289+
258290

259291
} // namespace blas
260292
} // namespace hip

0 commit comments

Comments
 (0)