diff --git a/CMakeLists.txt b/CMakeLists.txt index f452002dc78..daaed4882e2 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -144,6 +144,7 @@ option( ) option(BUILD_SHARED_LIBS "Build shared (.so, .dylib, .dll) libraries" ON) option(GINKGO_BUILD_HWLOC "Build Ginkgo with HWLOC. Default is OFF." OFF) +option(GINKGO_BUILD_LAPACK "Build Ginkgo with LAPACK. Default is OFF." OFF) option( GINKGO_BUILD_PAPI_SDE "Build Ginkgo with PAPI SDE. Enabled if a system installation is found." @@ -359,6 +360,17 @@ if(GINKGO_BUILD_HWLOC AND (MSVC OR WIN32 OR CYGWIN OR APPLE)) ) endif() +set(GINKGO_HAVE_LAPACK 0) +if(GINKGO_BUILD_LAPACK) + # LAPACK::LAPACK needs CMake 3.18+. Check the version directly instead of + # calling cmake_minimum_required() again, which would also reset policies. + if(CMAKE_VERSION VERSION_LESS 3.18) + message(FATAL_ERROR "GINKGO_BUILD_LAPACK requires CMake 3.18 or newer (found ${CMAKE_VERSION})") + endif() + find_package(LAPACK REQUIRED) + set(GINKGO_HAVE_LAPACK 1) +endif() + set(GINKGO_HAVE_GPU_AWARE_MPI OFF) set(GINKGO_HAVE_OPENMPI_PRE_4_1_X OFF) if(GINKGO_BUILD_MPI) diff --git a/cmake/get_info.cmake b/cmake/get_info.cmake index b0a690847ea..aa4ff803a36 100644 --- a/cmake/get_info.cmake +++ b/cmake/get_info.cmake @@ -160,7 +160,7 @@ foreach(log_type ${log_types}) ) ginkgo_print_module_footer(${${log_type}} " Enabled features:") ginkgo_print_foreach_variable(${${log_type}} - "GINKGO_MIXED_PRECISION;GINKGO_HAVE_GPU_AWARE_MPI;GINKGO_ENABLE_HALF" + "GINKGO_MIXED_PRECISION;GINKGO_HAVE_GPU_AWARE_MPI;GINKGO_ENABLE_HALF;GINKGO_HAVE_LAPACK" ) ginkgo_print_module_footer(${${log_type}} " Tests, benchmarks and examples:") ginkgo_print_foreach_variable(${${log_type}} diff --git a/cmake/hip.cmake b/cmake/hip.cmake index 3640bc276da..e3a4dd458f2 100644 --- a/cmake/hip.cmake +++ b/cmake/hip.cmake @@ -33,6 +33,9 @@ find_package(hipsparse REQUIRED) find_package(rocrand REQUIRED) find_package(rocthrust REQUIRED) find_package(ROCTX) +if(GINKGO_BUILD_LAPACK) + find_package(hipsolver REQUIRED) +endif() if(GINKGO_HIP_AMD_UNSAFE_ATOMIC AND GINKGO_HIP_VERSION VERSION_GREATER_EQUAL 5) set(CMAKE_HIP_FLAGS diff --git a/common/cuda_hip/CMakeLists.txt
b/common/cuda_hip/CMakeLists.txt index 10f3b857d82..9ca9c86fdae 100644 --- a/common/cuda_hip/CMakeLists.txt +++ b/common/cuda_hip/CMakeLists.txt @@ -10,6 +10,7 @@ set(CUDA_HIP_SOURCES distributed/partition_helpers_kernels.cpp distributed/partition_kernels.cpp distributed/vector_kernels.cpp + eigensolver/lobpcg_kernels.cpp factorization/cholesky_kernels.cpp factorization/factorization_kernels.cpp factorization/elimination_forest_kernels.cpp diff --git a/common/cuda_hip/base/dev_lapack_bindings.hpp b/common/cuda_hip/base/dev_lapack_bindings.hpp new file mode 100644 index 00000000000..f214861d696 --- /dev/null +++ b/common/cuda_hip/base/dev_lapack_bindings.hpp @@ -0,0 +1,22 @@ +// SPDX-FileCopyrightText: 2025 The Ginkgo authors +// +// SPDX-License-Identifier: BSD-3-Clause + +#ifndef GKO_COMMON_CUDA_HIP_BASE_DEV_LAPACK_BINDINGS_HPP_ +#define GKO_COMMON_CUDA_HIP_BASE_DEV_LAPACK_BINDINGS_HPP_ + + +#if defined(GKO_COMPILING_CUDA) +#include "cuda/base/cusolver_bindings.hpp" +#define GKO_DEV_LAPACK_ERROR GKO_CUSOLVER_ERROR +#define DEV_LAPACK_INTERNAL_ERROR CUSOLVER_STATUS_INTERNAL_ERROR +#elif defined(GKO_COMPILING_HIP) +#include "hip/base/hipsolver_bindings.hip.hpp" +#define GKO_DEV_LAPACK_ERROR GKO_HIPSOLVER_ERROR +#define DEV_LAPACK_INTERNAL_ERROR HIPSOLVER_STATUS_INTERNAL_ERROR +#else +#error "Executor definition missing" +#endif + + +#endif // GKO_COMMON_CUDA_HIP_BASE_DEV_LAPACK_BINDINGS_HPP_ diff --git a/common/cuda_hip/eigensolver/lobpcg_kernels.cpp b/common/cuda_hip/eigensolver/lobpcg_kernels.cpp new file mode 100644 index 00000000000..d8e1ad61495 --- /dev/null +++ b/common/cuda_hip/eigensolver/lobpcg_kernels.cpp @@ -0,0 +1,328 @@ +// SPDX-FileCopyrightText: 2025 The Ginkgo authors +// +// SPDX-License-Identifier: BSD-3-Clause + +#include "core/eigensolver/lobpcg_kernels.hpp" + +#include + +#include +#include + +#include "common/cuda_hip/base/blas_bindings.hpp" +#include "common/cuda_hip/base/dev_lapack_bindings.hpp" +#include 
"common/cuda_hip/base/pointer_mode_guard.hpp" +#include "common/cuda_hip/components/thread_ids.hpp" + +#if GKO_HAVE_LAPACK + + +namespace gko { +namespace kernels { +namespace GKO_DEVICE_NAMESPACE { +namespace lobpcg { + + +constexpr int default_block_size = 512; + + +namespace kernel { + + +template +__global__ __launch_bounds__(default_block_size) void matrix_conj( + const int32 n, ValueType* a, const int32 a_stride) +{ + const auto tidx = thread::get_thread_id_flat(); + const auto row = tidx / n; + const auto col = tidx % n; + const ValueType zero = gko::zero(); + if (row < n && col < n) { + a[row * a_stride + col] = conj(a[row * a_stride + col]); + } +} + + +template +__global__ __launch_bounds__(default_block_size) void two_matrix_conj( + const int32 n, ValueType* a, const int32 a_stride, ValueType* b, + const int32 b_stride) +{ + const auto tidx = thread::get_thread_id_flat(); + const auto row = tidx / n; + const auto col = tidx % n; + const ValueType zero = gko::zero(); + if (row < n && col < n) { + a[row * a_stride + col] = conj(a[row * a_stride + col]); + b[row * b_stride + col] = conj(b[row * b_stride + col]); + } +} + + +template +__global__ __launch_bounds__(default_block_size) void fill_lower_col_major( + const int32 n, const ValueType* source, const int32 source_stride, + ValueType* dest, const int32 dest_stride) +{ + const auto tidx = thread::get_thread_id_flat(); + const auto row = tidx % n; + const auto col = tidx / n; + const ValueType zero = gko::zero(); + if (row < n && col < n) { + dest[col * dest_stride + row] = + (row >= col) ? 
source[col * source_stride + row] : zero; + } +} + + +} // namespace kernel + + +template +void symm_eig(std::shared_ptr exec, + matrix::Dense* a, + array>* e_vals, array* workspace) +{ + const auto id = exec->get_device_id(); + auto handle = exec->get_dev_lapack_handle(); + + constexpr auto max = std::numeric_limits::max(); + if (a->get_size()[1] > max) { + throw OverflowError(__FILE__, __LINE__, + name_demangling::get_type_name(typeid(int32))); + } + if (a->get_stride() > max) { + throw OverflowError(__FILE__, __LINE__, + name_demangling::get_type_name(typeid(int32))); + } + int32 n = static_cast(a->get_size()[0]); + int32 lda = static_cast(a->get_stride()); + // The dev_lapack routine expects column-major data, so we take the + // conjugate to perform A = A^T. + if constexpr (gko::is_complex_s::value) { + const auto grid_dim = ceildiv(n * n, default_block_size); + if (grid_dim > 0) { + kernel::matrix_conj<<get_stream()>>>( + n, as_device_type(a->get_values()), lda); + } + } + + int32 fp_buffer_num_elems; + dev_lapack::syevd_buffersize(handle, LAPACK_EIG_VECTOR, LAPACK_FILL_LOWER, + n, a->get_values(), lda, e_vals->get_data(), + &fp_buffer_num_elems); + size_type total_bytes = sizeof(ValueType) * fp_buffer_num_elems; + if (workspace->get_size() < total_bytes) { + workspace->resize_and_reset(total_bytes); + } + array dev_info(exec, 1); + try { + dev_lapack::syevd(handle, LAPACK_EIG_VECTOR, LAPACK_FILL_LOWER, n, + a->get_values(), lda, e_vals->get_data(), + reinterpret_cast(workspace->get_data()), + fp_buffer_num_elems, dev_info.get_data()); + + int32 host_info = exec->copy_val_to_host(dev_info.get_data()); + if (host_info != 0) { + throw GKO_DEV_LAPACK_ERROR(DEV_LAPACK_INTERNAL_ERROR); + } + } catch (std::exception& e) { + std::cout << e.what() << std::endl; + int32 host_info = exec->copy_val_to_host(dev_info.get_data()); + std::cout << "devInfo was " << host_info << std::endl; + throw; + } +} + 
+GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(GKO_DECLARE_LOBPCG_SYMM_EIG_KERNEL); + + +template +void symm_generalized_eig(std::shared_ptr exec, + matrix::Dense* a, + matrix::Dense* b, + array>* e_vals, + array* workspace) +{ + const auto id = exec->get_device_id(); + auto handle = exec->get_dev_lapack_handle(); + + constexpr auto max = std::numeric_limits::max(); + if (a->get_size()[1] > max) { + throw OverflowError(__FILE__, __LINE__, + name_demangling::get_type_name(typeid(int32))); + } + if (a->get_stride() > max) { + throw OverflowError(__FILE__, __LINE__, + name_demangling::get_type_name(typeid(int32))); + } + if (b->get_stride() > max) { + throw OverflowError(__FILE__, __LINE__, + name_demangling::get_type_name(typeid(int32))); + } + + int32 n = static_cast(a->get_size()[0]); + int32 lda = static_cast(a->get_stride()); + int32 ldb = static_cast(b->get_stride()); + // The dev_lapack routine expects column-major data, so we take the + // conjugate to perform A = A^T. + if constexpr (gko::is_complex_s::value) { + const auto grid_dim = ceildiv(n * n, default_block_size); + if (grid_dim > 0) { + kernel::two_matrix_conj<<get_stream()>>>( + n, as_device_type(a->get_values()), lda, + as_device_type(b->get_values()), ldb); + } + } + + int32 fp_buffer_num_elems; + dev_lapack::sygvd_buffersize(handle, LAPACK_EIG_TYPE_1, LAPACK_EIG_VECTOR, + LAPACK_FILL_LOWER, n, a->get_values(), lda, + b->get_values(), ldb, e_vals->get_data(), + &fp_buffer_num_elems); + size_type total_bytes = sizeof(ValueType) * fp_buffer_num_elems; + if (workspace->get_size() < total_bytes) { + workspace->resize_and_reset(total_bytes); + } + array dev_info(exec, 1); + try { + dev_lapack::sygvd(handle, LAPACK_EIG_TYPE_1, LAPACK_EIG_VECTOR, + LAPACK_FILL_LOWER, n, a->get_values(), lda, + b->get_values(), ldb, e_vals->get_data(), + reinterpret_cast(workspace->get_data()), + fp_buffer_num_elems, dev_info.get_data()); + + int32 host_info = exec->copy_val_to_host(dev_info.get_data()); + if (host_info != 0) { + 
throw GKO_DEV_LAPACK_ERROR(DEV_LAPACK_INTERNAL_ERROR); + } + } catch (std::exception& e) { + std::cout << e.what() << std::endl; + int32 host_info = exec->copy_val_to_host(dev_info.get_data()); + std::cout << "devInfo was " << host_info << std::endl; + throw; + } +} + +GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE( + GKO_DECLARE_LOBPCG_SYMM_GENERALIZED_EIG_KERNEL); + + +template +void b_orthonormalize(std::shared_ptr exec, + matrix::Dense* a, LinOp* b, + array* workspace) +{ + const auto id = exec->get_device_id(); + auto handle = exec->get_dev_lapack_handle(); + + constexpr auto max = std::numeric_limits::max(); + if (a->get_size()[0] > max) { + throw OverflowError(__FILE__, __LINE__, + name_demangling::get_type_name(typeid(int32))); + } + if (a->get_stride() > max) { + throw OverflowError(__FILE__, __LINE__, + name_demangling::get_type_name(typeid(int32))); + } + const int32 lda = static_cast(a->get_stride()); + + // Compute A^H * B * A + auto b_a = matrix::Dense::create( + exec, gko::dim<2>{b->get_size()[0], a->get_size()[1]}); + b->apply(a, b_a); + auto aH_b_a = matrix::Dense::create( + exec, gko::dim<2>{a->get_size()[1], a->get_size()[1]}); + gko::as>(a->conj_transpose())->apply(b_a, aH_b_a); + + const int32 n = static_cast(aH_b_a->get_size()[0]); + const int32 ldaH_b_a = static_cast(aH_b_a->get_stride()); + + // Cholesky factorization + int32 fp_buffer_num_elems; + dev_lapack::potrf_buffersize(handle, LAPACK_FILL_LOWER, n, + aH_b_a->get_values(), ldaH_b_a, + &fp_buffer_num_elems); + size_type total_bytes = sizeof(ValueType) * fp_buffer_num_elems; + if (workspace->get_size() < total_bytes) { + workspace->resize_and_reset(total_bytes); + } + // LAPACK uses column-major, so using LOWER produces LL^H = A^T: + // L (col-major) is the complex conjugate of the lower factor for A. 
+ array dev_info(exec, 1); + try { + dev_lapack::potrf(handle, LAPACK_FILL_LOWER, n, aH_b_a->get_values(), + ldaH_b_a, + reinterpret_cast(workspace->get_data()), + fp_buffer_num_elems, dev_info.get_data()); + + int32 host_info = exec->copy_val_to_host(dev_info.get_data()); + if (host_info != 0) { + throw GKO_DEV_LAPACK_ERROR(DEV_LAPACK_INTERNAL_ERROR); + } + } catch (std::exception& e) { + std::cout << e.what() << std::endl; + int32 host_info = exec->copy_val_to_host(dev_info.get_data()); + std::cout << "error in potrf: devInfo was " << host_info << std::endl; + throw; + } + + // Solve with L as the right hand side: LL^H X = L --> X = L^{-H}. + // Recall L is the complex conjugate of the "true" L, so really + // we will have (L_true)^{-T}, stored in column-major format, after potrs. + auto factor = matrix::Dense::create(exec, aH_b_a->get_size()); + const auto grid_dim = ceildiv(n * n, default_block_size); + if (grid_dim > 0) { + kernel::fill_lower_col_major<<get_stream()>>>( + n, as_device_type(aH_b_a->get_const_values()), ldaH_b_a, + as_device_type(factor->get_values()), ldaH_b_a); + } + try { + dev_lapack::potrs(handle, LAPACK_FILL_LOWER, n, factor->get_size()[1], + aH_b_a->get_values(), ldaH_b_a, factor->get_values(), + ldaH_b_a, dev_info.get_data()); + + int32 host_info = exec->copy_val_to_host(dev_info.get_data()); + if (host_info != 0) { + throw GKO_DEV_LAPACK_ERROR(DEV_LAPACK_INTERNAL_ERROR); + } + } catch (std::exception& e) { + std::cout << e.what() << std::endl; + int32 host_info = exec->copy_val_to_host(dev_info.get_data()); + std::cout << "error in potrs: devInfo was " << host_info << std::endl; + throw; + } + + // A = A * (L^{-1})^H + // A will be seen by BLAS as column-major, or A^T. The BLAS operation + // A^T_{ij} = F^H_{ik} A^T_{kj}, with F being the "factor" variable, + // is equivalent to A_{ji} = A_{jk} conj(F)_{ki} --> A = A * conj(F). + // Since F = (L_true)^{-T}, we have A = A * (L_true)^{-H} in row-major + // storage upon exit. 
+ auto blas_handle = exec->get_blas_handle(); + blas::pointer_mode_guard pm_guard(blas_handle); + const ValueType alpha = gko::one(); + const int32 m = static_cast(a->get_size()[0]); + if constexpr (!gko::is_complex_s::value) { + blas::trmm(blas_handle, BLAS_SIDE_LEFT, LAPACK_FILL_UPPER, BLAS_OP_T, + BLAS_DIAG_NONUNIT, n, m, &alpha, factor->get_const_values(), + ldaH_b_a, a->get_values(), lda, a->get_values(), lda); + } else { + blas::trmm(blas_handle, BLAS_SIDE_LEFT, LAPACK_FILL_UPPER, BLAS_OP_C, + BLAS_DIAG_NONUNIT, n, m, &alpha, factor->get_const_values(), + ldaH_b_a, a->get_values(), lda, a->get_values(), lda); + } +} + +GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(GKO_DECLARE_LOBPCG_B_ORTHONORMALIZE_KERNEL); + + +} // namespace lobpcg +} // namespace GKO_DEVICE_NAMESPACE +} // namespace kernels +} // namespace gko + +#endif // GKO_HAVE_LAPACK diff --git a/core/device_hooks/common_kernels.inc.cpp b/core/device_hooks/common_kernels.inc.cpp index 1abe27e9558..a176767636e 100644 --- a/core/device_hooks/common_kernels.inc.cpp +++ b/core/device_hooks/common_kernels.inc.cpp @@ -25,6 +25,7 @@ #include "core/distributed/partition_helpers_kernels.hpp" #include "core/distributed/partition_kernels.hpp" #include "core/distributed/vector_kernels.hpp" +#include "core/eigensolver/lobpcg_kernels.hpp" #include "core/factorization/cholesky_kernels.hpp" #include "core/factorization/elimination_forest_kernels.hpp" #include "core/factorization/factorization_kernels.hpp" @@ -719,6 +720,19 @@ GKO_STUB_VALUE_TYPE(GKO_DECLARE_MINRES_STEP_2_KERNEL); } // namespace minres +#if GKO_HAVE_LAPACK +namespace lobpcg { + + +GKO_STUB_VALUE_TYPE(GKO_DECLARE_LOBPCG_SYMM_EIG_KERNEL); +GKO_STUB_VALUE_TYPE(GKO_DECLARE_LOBPCG_SYMM_GENERALIZED_EIG_KERNEL); +GKO_STUB_VALUE_TYPE(GKO_DECLARE_LOBPCG_B_ORTHONORMALIZE_KERNEL); + + +} // namespace lobpcg +#endif + + namespace sparsity_csr { diff --git a/core/device_hooks/cuda_hooks.cpp b/core/device_hooks/cuda_hooks.cpp index 4124ac2bea5..d321356c9c4 100644 --- 
a/core/device_hooks/cuda_hooks.cpp +++ b/core/device_hooks/cuda_hooks.cpp @@ -1,4 +1,4 @@ -// SPDX-FileCopyrightText: 2017 - 2024 The Ginkgo authors +// SPDX-FileCopyrightText: 2017 - 2025 The Ginkgo authors // // SPDX-License-Identifier: BSD-3-Clause @@ -175,6 +175,12 @@ std::string CusparseError::get_error(int64) } +std::string CusolverError::get_error(int64) +{ + return "ginkgo CUDA module is not compiled"; +} + + std::string CufftError::get_error(int64) { return "ginkgo CUDA module is not compiled"; diff --git a/core/device_hooks/hip_hooks.cpp b/core/device_hooks/hip_hooks.cpp index 7f3497e8020..612732a9d50 100644 --- a/core/device_hooks/hip_hooks.cpp +++ b/core/device_hooks/hip_hooks.cpp @@ -1,4 +1,4 @@ -// SPDX-FileCopyrightText: 2017 - 2024 The Ginkgo authors +// SPDX-FileCopyrightText: 2017 - 2025 The Ginkgo authors // // SPDX-License-Identifier: BSD-3-Clause @@ -174,6 +174,12 @@ std::string HipsparseError::get_error(int64) } +std::string HipsolverError::get_error(int64) +{ + return "ginkgo HIP module is not compiled"; +} + + std::string HipfftError::get_error(int64) { return "ginkgo HIP module is not compiled"; diff --git a/core/eigensolver/lobpcg_kernels.hpp b/core/eigensolver/lobpcg_kernels.hpp new file mode 100644 index 00000000000..39bf549ec76 --- /dev/null +++ b/core/eigensolver/lobpcg_kernels.hpp @@ -0,0 +1,63 @@ +// SPDX-FileCopyrightText: 2025 The Ginkgo authors +// +// SPDX-License-Identifier: BSD-3-Clause + +#ifndef GKO_CORE_EIGENSOLVER_LOBPCG_KERNELS_HPP_ +#define GKO_CORE_EIGENSOLVER_LOBPCG_KERNELS_HPP_ + + +#include +#include +#include +#include + +#include "core/base/kernel_declaration.hpp" + + +namespace gko { +namespace kernels { +namespace lobpcg { + + +#define GKO_DECLARE_LOBPCG_SYMM_EIG_KERNEL(_type) \ + void symm_eig( \ + std::shared_ptr exec, matrix::Dense<_type>* a, \ + array>* e_vals, array* workspace) + + +#define GKO_DECLARE_LOBPCG_SYMM_GENERALIZED_EIG_KERNEL(_type) \ + void symm_generalized_eig( \ + std::shared_ptr exec, 
matrix::Dense<_type>* a, \ + matrix::Dense<_type>* b, array>* e_vals, \ + array* workspace) + + +#define GKO_DECLARE_LOBPCG_B_ORTHONORMALIZE_KERNEL(_type) \ + void b_orthonormalize(std::shared_ptr exec, \ + matrix::Dense<_type>* a, LinOp* b, \ + array* workspace) + + +#define GKO_DECLARE_ALL_AS_TEMPLATES \ + template \ + GKO_DECLARE_LOBPCG_SYMM_EIG_KERNEL(ValueType); \ + template \ + GKO_DECLARE_LOBPCG_SYMM_GENERALIZED_EIG_KERNEL(ValueType); \ + template \ + GKO_DECLARE_LOBPCG_B_ORTHONORMALIZE_KERNEL(ValueType) + + +} // namespace lobpcg + + +GKO_DECLARE_FOR_ALL_EXECUTOR_NAMESPACES(lobpcg, GKO_DECLARE_ALL_AS_TEMPLATES); + + +#undef GKO_DECLARE_ALL_AS_TEMPLATES + + +} // namespace kernels +} // namespace gko + + +#endif // GKO_CORE_EIGENSOLVER_LOBPCG_KERNELS_HPP_ diff --git a/core/test/base/exception.cpp b/core/test/base/exception.cpp index ec5d4bf5763..2beef9b22f2 100644 --- a/core/test/base/exception.cpp +++ b/core/test/base/exception.cpp @@ -1,4 +1,4 @@ -// SPDX-FileCopyrightText: 2017 - 2024 The Ginkgo authors +// SPDX-FileCopyrightText: 2017 - 2025 The Ginkgo authors // // SPDX-License-Identifier: BSD-3-Clause @@ -77,6 +77,16 @@ TEST(ExceptionClasses, CusparseErrorReturnsCorrectWhatMessage) } +#if GKO_HAVE_LAPACK +TEST(ExceptionClasses, CusolverErrorReturnsCorrectWhatMessage) +{ + gko::CusolverError error("test_file.cpp", 123, "test_func", 1); + std::string expected = "test_file.cpp:123: test_func: "; + ASSERT_EQ(expected, std::string(error.what()).substr(0, expected.size())); +} +#endif + + TEST(ExceptionClasses, CufftErrorReturnsCorrectWhatMessage) { gko::CufftError error("test_file.cpp", 123, "test_func", 1); @@ -117,6 +127,16 @@ TEST(ExceptionClasses, HipsparseErrorReturnsCorrectWhatMessage) } +#if GKO_HAVE_LAPACK +TEST(ExceptionClasses, HipsolverErrorReturnsCorrectWhatMessage) +{ + gko::HipsolverError error("test_file.cpp", 123, "test_func", 1); + std::string expected = "test_file.cpp:123: test_func: "; + ASSERT_EQ(expected, 
std::string(error.what()).substr(0, expected.size())); +} +#endif + + TEST(ExceptionClasses, HipfftErrorReturnsCorrectWhatMessage) { gko::HipfftError error("test_file.cpp", 123, "test_func", 1); @@ -125,6 +145,16 @@ TEST(ExceptionClasses, HipfftErrorReturnsCorrectWhatMessage) } +#if GKO_HAVE_LAPACK +TEST(ExceptionClasses, LapackErrorReturnsCorrectWhatMessage) +{ + gko::LapackError error("test_file.cpp", 123, "test_func", 1); + std::string expected = "test_file.cpp:123: test_func: "; + ASSERT_EQ(expected, std::string(error.what()).substr(0, expected.size())); +} +#endif + + TEST(ExceptionClasses, DimensionMismatchReturnsCorrectWhatMessage) { gko::DimensionMismatch error("test_file.cpp", 243, "test_func", "a", 3, 4, diff --git a/core/test/base/exception_helpers.cpp b/core/test/base/exception_helpers.cpp index 50f81707ead..a57b4f9b2c8 100644 --- a/core/test/base/exception_helpers.cpp +++ b/core/test/base/exception_helpers.cpp @@ -1,4 +1,4 @@ -// SPDX-FileCopyrightText: 2017 - 2024 The Ginkgo authors +// SPDX-FileCopyrightText: 2017 - 2025 The Ginkgo authors // // SPDX-License-Identifier: BSD-3-Clause @@ -103,6 +103,16 @@ TEST(CudaError, ReturnsCusparseError) } +#if GKO_HAVE_LAPACK +void throws_cusolver_error() { throw GKO_CUSOLVER_ERROR(0); } + +TEST(CudaError, ReturnsCusolverError) +{ + ASSERT_THROW(throws_cusolver_error(), gko::CusolverError); +} +#endif + + void throws_cufft_error() { throw GKO_CUFFT_ERROR(0); } TEST(CudaError, ReturnsCufftError) @@ -143,6 +153,16 @@ TEST(HipError, ReturnsHipsparseError) } +#if GKO_HAVE_LAPACK +void throws_hipsolver_error() { throw GKO_HIPSOLVER_ERROR(0); } + +TEST(HipError, ReturnsHipsolverError) +{ + ASSERT_THROW(throws_hipsolver_error(), gko::HipsolverError); +} +#endif + + void throws_hipfft_error() { throw GKO_HIPFFT_ERROR(0); } TEST(HipError, ReturnsHipfftError) @@ -150,6 +170,15 @@ TEST(HipError, ReturnsHipfftError) ASSERT_THROW(throws_hipfft_error(), gko::HipfftError); } +#if GKO_HAVE_LAPACK +void throws_lapack_error() { 
throw GKO_LAPACK_ERROR(0); } + +TEST(LapackError, ReturnsLapackError) +{ + ASSERT_THROW(throws_lapack_error(), gko::LapackError); +} +#endif + TEST(AssertIsSquareMatrix, DoesNotThrowWhenIsSquareMatrix) { diff --git a/cuda/CMakeLists.txt b/cuda/CMakeLists.txt index 8394b3aefde..56fe0496d6b 100644 --- a/cuda/CMakeLists.txt +++ b/cuda/CMakeLists.txt @@ -131,6 +131,9 @@ target_link_libraries( CUDA::cufft nvtx::nvtx ) +if(GINKGO_BUILD_LAPACK) + target_link_libraries(ginkgo_cuda PRIVATE CUDA::cusolver) +endif() # NVTX3 is header-only and requires dlopen/dlclose in static builds target_link_libraries(ginkgo_cuda PUBLIC ginkgo_device ${CMAKE_DL_LIBS}) diff --git a/cuda/base/cublas_bindings.hpp b/cuda/base/cublas_bindings.hpp index 9a8b4070b03..01fea136d0b 100644 --- a/cuda/base/cublas_bindings.hpp +++ b/cuda/base/cublas_bindings.hpp @@ -1,4 +1,4 @@ -// SPDX-FileCopyrightText: 2017 - 2024 The Ginkgo authors +// SPDX-FileCopyrightText: 2017 - 2025 The Ginkgo authors // // SPDX-License-Identifier: BSD-3-Clause @@ -118,6 +118,32 @@ GKO_BIND_CUBLAS_GEAM(ValueType, detail::not_implemented); #undef GKO_BIND_CUBLAS_GEAM +#define GKO_BIND_CUBLAS_TRMM(ValueType, CublasName) \ + inline void trmm(cublasHandle_t handle, cublasSideMode_t side, \ + cublasFillMode_t uplo, cublasOperation_t trans, \ + cublasDiagType_t diag, int m, int n, \ + const ValueType* alpha, const ValueType* a, int lda, \ + const ValueType* b, int ldb, ValueType* c, int ldc) \ + { \ + GKO_ASSERT_NO_CUBLAS_ERRORS( \ + CublasName(handle, side, uplo, trans, diag, m, n, \ + as_culibs_type(alpha), as_culibs_type(a), lda, \ + as_culibs_type(b), ldb, as_culibs_type(c), ldc)); \ + } \ + static_assert(true, \ + "This assert is used to counter the false positive extra " \ + "semi-colon warnings") + +GKO_BIND_CUBLAS_TRMM(float, cublasStrmm); +GKO_BIND_CUBLAS_TRMM(double, cublasDtrmm); +GKO_BIND_CUBLAS_TRMM(std::complex, cublasCtrmm); +GKO_BIND_CUBLAS_TRMM(std::complex, cublasZtrmm); +template +GKO_BIND_CUBLAS_TRMM(ValueType, 
detail::not_implemented); + +#undef GKO_BIND_CUBLAS_TRMM + + #define GKO_BIND_CUBLAS_SCAL(ValueType, CublasName) \ inline void scal(cublasHandle_t handle, int n, const ValueType* alpha, \ ValueType* x, int incx) \ @@ -241,6 +267,12 @@ using namespace cublas; #define BLAS_OP_T CUBLAS_OP_T #define BLAS_OP_C CUBLAS_OP_C +#define BLAS_SIDE_LEFT CUBLAS_SIDE_LEFT +#define BLAS_SIDE_RIGHT CUBLAS_SIDE_RIGHT + +#define BLAS_DIAG_UNIT CUBLAS_DIAG_UNIT +#define BLAS_DIAG_NONUNIT CUBLAS_DIAG_NON_UNIT + } // namespace blas } // namespace cuda diff --git a/cuda/base/cusolver_bindings.hpp b/cuda/base/cusolver_bindings.hpp new file mode 100644 index 00000000000..f151d944da8 --- /dev/null +++ b/cuda/base/cusolver_bindings.hpp @@ -0,0 +1,259 @@ +// SPDX-FileCopyrightText: 2025 The Ginkgo authors +// +// SPDX-License-Identifier: BSD-3-Clause + +#ifndef GKO_CUDA_BASE_CUSOLVER_BINDINGS_HPP_ +#define GKO_CUDA_BASE_CUSOLVER_BINDINGS_HPP_ + + +#include +#include + +#include + +#include "common/cuda_hip/base/types.hpp" + + +namespace gko { +namespace kernels { +namespace cuda { +/** + * @brief The CUSOLVER namespace. + * + * @ingroup cusolver + */ +namespace cusolver { +/** + * @brief The detail namespace. + * + * @ingroup detail + */ +namespace detail { + + +template +inline int64 not_implemented(Args...) 
+{ + return static_cast(CUSOLVER_STATUS_MATRIX_TYPE_NOT_SUPPORTED); +} + + +} // namespace detail + + +template +struct is_supported : std::false_type {}; + +template <> +struct is_supported : std::true_type {}; + +template <> +struct is_supported : std::true_type {}; + +template <> +struct is_supported> : std::true_type {}; + +template <> +struct is_supported> : std::true_type {}; + + +#define GKO_BIND_CUSOLVER_SYEVD_BUFFERSIZE(ValueType, CusolverName) \ + inline void syevd_buffersize( \ + cusolverDnHandle_t handle, cusolverEigMode_t jobz, \ + cublasFillMode_t uplo, int32 n, ValueType* a, int32 lda, \ + remove_complex* w, int32* buffer_num_elems) \ + { \ + GKO_ASSERT_NO_CUSOLVER_ERRORS( \ + CusolverName(handle, jobz, uplo, n, as_culibs_type(a), lda, \ + as_culibs_type(w), buffer_num_elems)); \ + } \ + static_assert(true, \ + "This assert is used to counter the false positive extra " \ + "semi-colon warnings") + +GKO_BIND_CUSOLVER_SYEVD_BUFFERSIZE(float, cusolverDnSsyevd_bufferSize); +GKO_BIND_CUSOLVER_SYEVD_BUFFERSIZE(double, cusolverDnDsyevd_bufferSize); +GKO_BIND_CUSOLVER_SYEVD_BUFFERSIZE(std::complex, + cusolverDnCheevd_bufferSize); +GKO_BIND_CUSOLVER_SYEVD_BUFFERSIZE(std::complex, + cusolverDnZheevd_bufferSize); +template +GKO_BIND_CUSOLVER_SYEVD_BUFFERSIZE(ValueType, detail::not_implemented); + +#undef GKO_BIND_CUSOLVER_SYEVD_BUFFERSIZE + + +#define GKO_BIND_CUSOLVER_SYEVD(ValueType, CusolverName) \ + inline void syevd(cusolverDnHandle_t handle, cusolverEigMode_t jobz, \ + cublasFillMode_t uplo, int32 n, ValueType* a, int32 lda, \ + remove_complex* w, ValueType* work, \ + int32 buffer_num_elems, int32* dev_info) \ + { \ + GKO_ASSERT_NO_CUSOLVER_ERRORS(CusolverName( \ + handle, jobz, uplo, n, as_culibs_type(a), lda, as_culibs_type(w), \ + as_culibs_type(work), buffer_num_elems, dev_info)); \ + } \ + static_assert(true, \ + "This assert is used to counter the false positive extra " \ + "semi-colon warnings") + +GKO_BIND_CUSOLVER_SYEVD(float, cusolverDnSsyevd); 
+GKO_BIND_CUSOLVER_SYEVD(double, cusolverDnDsyevd); +GKO_BIND_CUSOLVER_SYEVD(std::complex, cusolverDnCheevd); +GKO_BIND_CUSOLVER_SYEVD(std::complex, cusolverDnZheevd); +template +GKO_BIND_CUSOLVER_SYEVD(ValueType, detail::not_implemented); + +#undef GKO_BIND_CUSOLVER_SYEVD + + +#define GKO_BIND_CUSOLVER_SYGVD_BUFFERSIZE(ValueType, CusolverName) \ + inline void sygvd_buffersize( \ + cusolverDnHandle_t handle, cusolverEigType_t itype, \ + cusolverEigMode_t jobz, cublasFillMode_t uplo, int32 n, ValueType* a, \ + int32 lda, ValueType* b, int32 ldb, remove_complex* w, \ + int32* buffer_num_elems) \ + { \ + GKO_ASSERT_NO_CUSOLVER_ERRORS(CusolverName( \ + handle, itype, jobz, uplo, n, as_culibs_type(a), lda, \ + as_culibs_type(b), ldb, as_culibs_type(w), buffer_num_elems)); \ + } \ + static_assert(true, \ + "This assert is used to counter the false positive extra " \ + "semi-colon warnings") + +GKO_BIND_CUSOLVER_SYGVD_BUFFERSIZE(float, cusolverDnSsygvd_bufferSize); +GKO_BIND_CUSOLVER_SYGVD_BUFFERSIZE(double, cusolverDnDsygvd_bufferSize); +GKO_BIND_CUSOLVER_SYGVD_BUFFERSIZE(std::complex, + cusolverDnChegvd_bufferSize); +GKO_BIND_CUSOLVER_SYGVD_BUFFERSIZE(std::complex, + cusolverDnZhegvd_bufferSize); +template +GKO_BIND_CUSOLVER_SYGVD_BUFFERSIZE(ValueType, detail::not_implemented); + +#undef GKO_BIND_CUSOLVER_SYGVD_BUFFERSIZE + + +#define GKO_BIND_CUSOLVER_SYGVD(ValueType, CusolverName) \ + inline void sygvd(cusolverDnHandle_t handle, cusolverEigType_t itype, \ + cusolverEigMode_t jobz, cublasFillMode_t uplo, int32 n, \ + ValueType* a, int32 lda, ValueType* b, int32 ldb, \ + remove_complex* w, ValueType* work, \ + int32 buffer_num_elems, int32* dev_info) \ + { \ + GKO_ASSERT_NO_CUSOLVER_ERRORS( \ + CusolverName(handle, itype, jobz, uplo, n, as_culibs_type(a), lda, \ + as_culibs_type(b), ldb, as_culibs_type(w), \ + as_culibs_type(work), buffer_num_elems, dev_info)); \ + } \ + static_assert(true, \ + "This assert is used to counter the false positive extra " \ + "semi-colon 
warnings") + +GKO_BIND_CUSOLVER_SYGVD(float, cusolverDnSsygvd); +GKO_BIND_CUSOLVER_SYGVD(double, cusolverDnDsygvd); +GKO_BIND_CUSOLVER_SYGVD(std::complex, cusolverDnChegvd); +GKO_BIND_CUSOLVER_SYGVD(std::complex, cusolverDnZhegvd); +template +GKO_BIND_CUSOLVER_SYGVD(ValueType, detail::not_implemented); + +#undef GKO_BIND_CUSOLVER_SYGVD + + +#define GKO_BIND_CUSOLVER_POTRF_BUFFERSIZE(ValueType, CusolverName) \ + inline void potrf_buffersize(cusolverDnHandle_t handle, \ + cublasFillMode_t uplo, int32 n, ValueType* a, \ + int32 lda, int32* buffer_num_elems) \ + { \ + GKO_ASSERT_NO_CUSOLVER_ERRORS(CusolverName( \ + handle, uplo, n, as_culibs_type(a), lda, buffer_num_elems)); \ + } \ + static_assert(true, \ + "This assert is used to counter the false positive extra " \ + "semi-colon warnings") + +GKO_BIND_CUSOLVER_POTRF_BUFFERSIZE(float, cusolverDnSpotrf_bufferSize); +GKO_BIND_CUSOLVER_POTRF_BUFFERSIZE(double, cusolverDnDpotrf_bufferSize); +GKO_BIND_CUSOLVER_POTRF_BUFFERSIZE(std::complex, + cusolverDnCpotrf_bufferSize); +GKO_BIND_CUSOLVER_POTRF_BUFFERSIZE(std::complex, + cusolverDnZpotrf_bufferSize); +template +GKO_BIND_CUSOLVER_POTRF_BUFFERSIZE(ValueType, detail::not_implemented); + +#undef GKO_BIND_CUSOLVER_POTRF_BUFFERSIZE + + +#define GKO_BIND_CUSOLVER_POTRF(ValueType, CusolverName) \ + inline void potrf(cusolverDnHandle_t handle, cublasFillMode_t uplo, \ + int32 n, ValueType* a, int32 lda, ValueType* work, \ + int32 buffer_num_elems, int32* dev_info) \ + { \ + GKO_ASSERT_NO_CUSOLVER_ERRORS( \ + CusolverName(handle, uplo, n, as_culibs_type(a), lda, \ + as_culibs_type(work), buffer_num_elems, dev_info)); \ + } \ + static_assert(true, \ + "This assert is used to counter the false positive extra " \ + "semi-colon warnings") + +GKO_BIND_CUSOLVER_POTRF(float, cusolverDnSpotrf); +GKO_BIND_CUSOLVER_POTRF(double, cusolverDnDpotrf); +GKO_BIND_CUSOLVER_POTRF(std::complex, cusolverDnCpotrf); +GKO_BIND_CUSOLVER_POTRF(std::complex, cusolverDnZpotrf); +template 
+GKO_BIND_CUSOLVER_POTRF(ValueType, detail::not_implemented); + +#undef GKO_BIND_CUSOLVER_POTRF + + +#define GKO_BIND_CUSOLVER_POTRS(ValueType, CusolverName) \ + inline void potrs(cusolverDnHandle_t handle, cublasFillMode_t uplo, \ + int32 n, int32 nrhs, ValueType* a, int32 lda, \ + ValueType* b, int32 ldb, int32* dev_info) \ + { \ + GKO_ASSERT_NO_CUSOLVER_ERRORS( \ + CusolverName(handle, uplo, n, nrhs, as_culibs_type(a), lda, \ + as_culibs_type(b), ldb, dev_info)); \ + } \ + static_assert(true, \ + "This assert is used to counter the false positive extra " \ + "semi-colon warnings") + +GKO_BIND_CUSOLVER_POTRS(float, cusolverDnSpotrs); +GKO_BIND_CUSOLVER_POTRS(double, cusolverDnDpotrs); +GKO_BIND_CUSOLVER_POTRS(std::complex, cusolverDnCpotrs); +GKO_BIND_CUSOLVER_POTRS(std::complex, cusolverDnZpotrs); +template +GKO_BIND_CUSOLVER_POTRS(ValueType, detail::not_implemented); + +#undef GKO_BIND_CUSOLVER_POTRS + + +} // namespace cusolver + + +namespace dev_lapack { + + +using namespace cusolver; + + +#define LAPACK_EIG_TYPE_1 CUSOLVER_EIG_TYPE_1 +#define LAPACK_EIG_TYPE_2 CUSOLVER_EIG_TYPE_2 +#define LAPACK_EIG_TYPE_3 CUSOLVER_EIG_TYPE_3 + +#define LAPACK_EIG_VECTOR CUSOLVER_EIG_MODE_VECTOR +#define LAPACK_EIG_NOVECTOR CUSOLVER_EIG_MODE_NOVECTOR + +#define LAPACK_FILL_UPPER CUBLAS_FILL_MODE_UPPER +#define LAPACK_FILL_LOWER CUBLAS_FILL_MODE_LOWER + + +} // namespace dev_lapack +} // namespace cuda +} // namespace kernels +} // namespace gko + + +#endif // GKO_CUDA_BASE_CUSOLVER_BINDINGS_HPP_ diff --git a/cuda/base/cusolver_handle.hpp b/cuda/base/cusolver_handle.hpp new file mode 100644 index 00000000000..4ec308d557c --- /dev/null +++ b/cuda/base/cusolver_handle.hpp @@ -0,0 +1,46 @@ +// SPDX-FileCopyrightText: 2025 The Ginkgo authors +// +// SPDX-License-Identifier: BSD-3-Clause + +#ifndef GKO_CUDA_BASE_CUSOLVER_HANDLE_HPP_ +#define GKO_CUDA_BASE_CUSOLVER_HANDLE_HPP_ + +#include +#include + +#include + + +namespace gko { +namespace kernels { +namespace cuda { +/** + * 
@brief The CUSOLVER namespace. + * + * @ingroup cusolver + */ +namespace cusolver { + + +inline cusolverDnHandle_t init(cudaStream_t stream) +{ + cusolverDnHandle_t handle; + GKO_ASSERT_NO_CUSOLVER_ERRORS(cusolverDnCreate(&handle)); + GKO_ASSERT_NO_CUSOLVER_ERRORS(cusolverDnSetStream(handle, stream)); + return handle; +} + + +inline void destroy(cusolverDnHandle_t handle) +{ + GKO_ASSERT_NO_CUSOLVER_ERRORS(cusolverDnDestroy(handle)); +} + + +} // namespace cusolver +} // namespace cuda +} // namespace kernels +} // namespace gko + + +#endif // GKO_CUDA_BASE_CUSOLVER_HANDLE_HPP_ diff --git a/cuda/base/exception.cpp b/cuda/base/exception.cpp index 7bb7fae5bd5..5acae01ebe3 100644 --- a/cuda/base/exception.cpp +++ b/cuda/base/exception.cpp @@ -1,4 +1,4 @@ -// SPDX-FileCopyrightText: 2017 - 2024 The Ginkgo authors +// SPDX-FileCopyrightText: 2017 - 2025 The Ginkgo authors // // SPDX-License-Identifier: BSD-3-Clause @@ -10,6 +10,9 @@ #include #include #include +#if GKO_HAVE_LAPACK +#include +#endif #include #include @@ -98,6 +101,31 @@ std::string CusparseError::get_error(int64 error_code) } +std::string CusolverError::get_error(int64 error_code) +{ +#if GKO_HAVE_LAPACK +#define GKO_REGISTER_CUSOLVER_ERROR(error_name) \ + if (error_code == static_cast(error_name)) { \ + return #error_name; \ + } + GKO_REGISTER_CUSOLVER_ERROR(CUSOLVER_STATUS_SUCCESS); + GKO_REGISTER_CUSOLVER_ERROR(CUSOLVER_STATUS_NOT_INITIALIZED); + GKO_REGISTER_CUSOLVER_ERROR(CUSOLVER_STATUS_ALLOC_FAILED); + GKO_REGISTER_CUSOLVER_ERROR(CUSOLVER_STATUS_INVALID_VALUE); + GKO_REGISTER_CUSOLVER_ERROR(CUSOLVER_STATUS_ARCH_MISMATCH); + GKO_REGISTER_CUSOLVER_ERROR(CUSOLVER_STATUS_EXECUTION_FAILED); + GKO_REGISTER_CUSOLVER_ERROR(CUSOLVER_STATUS_INTERNAL_ERROR); + GKO_REGISTER_CUSOLVER_ERROR(CUSOLVER_STATUS_MATRIX_TYPE_NOT_SUPPORTED); + GKO_REGISTER_CUSOLVER_ERROR(CUSOLVER_STATUS_NOT_SUPPORTED); + return "Unknown error"; + +#undef GKO_REGISTER_CUSOLVER_ERROR +#else + return "Ginkgo must be built with LAPACK 
support to enable cuSOLVER"; +#endif +} + + std::string CufftError::get_error(int64 error_code) { #define GKO_REGISTER_CUFFT_ERROR(error_name) \ diff --git a/cuda/base/executor.cpp b/cuda/base/executor.cpp index 8380eddcf1b..4c25fd90716 100644 --- a/cuda/base/executor.cpp +++ b/cuda/base/executor.cpp @@ -1,4 +1,4 @@ -// SPDX-FileCopyrightText: 2017 - 2024 The Ginkgo authors +// SPDX-FileCopyrightText: 2017 - 2025 The Ginkgo authors // // SPDX-License-Identifier: BSD-3-Clause @@ -20,6 +20,9 @@ #include "common/cuda_hip/base/config.hpp" #include "common/cuda_hip/base/executor.hpp.inc" #include "cuda/base/cublas_handle.hpp" +#if GKO_HAVE_LAPACK +#include "cuda/base/cusolver_handle.hpp" +#endif #include "cuda/base/cusparse_handle.hpp" #include "cuda/base/device.hpp" #include "cuda/base/scoped_device_id.hpp" @@ -256,6 +259,14 @@ void CudaExecutor::init_handles() detail::cuda_scoped_device_id_guard g(id); kernels::cuda::cusparse::destroy(handle); }); +#if GKO_HAVE_LAPACK + this->cusolver_handle_ = handle_manager( + kernels::cuda::cusolver::init(this->get_stream()), + [id](cusolverDnHandle_t handle) { + detail::cuda_scoped_device_id_guard g(id); + kernels::cuda::cusolver::destroy(handle); + }); +#endif } } diff --git a/cuda/test/base/exception_helpers.cu b/cuda/test/base/exception_helpers.cu index 7ee7ca0e8f0..9c78c5320b2 100644 --- a/cuda/test/base/exception_helpers.cu +++ b/cuda/test/base/exception_helpers.cu @@ -1,4 +1,4 @@ -// SPDX-FileCopyrightText: 2017 - 2024 The Ginkgo authors +// SPDX-FileCopyrightText: 2017 - 2025 The Ginkgo authors // // SPDX-License-Identifier: BSD-3-Clause @@ -12,6 +12,10 @@ #include +#if GKO_HAVE_LAPACK // Must be after Ginkgo include for GKO_HAVE_LAPACK def +#include +#endif + namespace { @@ -64,6 +68,20 @@ TEST(AssertNoCusparseErrors, DoesNotThrowOnSuccess) } +#if GKO_HAVE_LAPACK +TEST(AssertNoCusolverErrors, ThrowsOnError) +{ + ASSERT_THROW(GKO_ASSERT_NO_CUSOLVER_ERRORS(1), gko::CusolverError); +} + + +TEST(AssertNoCusolverErrors, 
DoesNotThrowOnSuccess) +{ + ASSERT_NO_THROW(GKO_ASSERT_NO_CUSOLVER_ERRORS(CUSOLVER_STATUS_SUCCESS)); +} +#endif + + TEST(AssertNoCufftErrors, ThrowsOnError) { ASSERT_THROW(GKO_ASSERT_NO_CUFFT_ERRORS(1), gko::CufftError); diff --git a/hip/CMakeLists.txt b/hip/CMakeLists.txt index 605cf6db123..3de5e50c866 100644 --- a/hip/CMakeLists.txt +++ b/hip/CMakeLists.txt @@ -120,6 +120,9 @@ endif() if(GINKGO_HAVE_ROCTX) target_link_libraries(ginkgo_hip PRIVATE roc::roctx) endif() +if(GINKGO_BUILD_LAPACK) + target_link_libraries(ginkgo_hip PRIVATE roc::hipsolver) +endif() target_compile_options( ginkgo_hip diff --git a/hip/base/exception.hip.cpp b/hip/base/exception.hip.cpp index c83778951d0..6d4da859a0d 100644 --- a/hip/base/exception.hip.cpp +++ b/hip/base/exception.hip.cpp @@ -1,4 +1,4 @@ -// SPDX-FileCopyrightText: 2017 - 2024 The Ginkgo authors +// SPDX-FileCopyrightText: 2017 - 2025 The Ginkgo authors // // SPDX-License-Identifier: BSD-3-Clause @@ -11,10 +11,16 @@ #include #include #include +#if GKO_HAVE_LAPACK +#include +#endif #else #include #include #include +#if GKO_HAVE_LAPACK +#include +#endif #endif @@ -107,4 +113,29 @@ std::string HipsparseError::get_error(int64 error_code) } +std::string HipsolverError::get_error(int64 error_code) +{ +#if GKO_HAVE_LAPACK +#define GKO_REGISTER_HIPSOLVER_ERROR(error_name) \ + if (error_code == static_cast(error_name)) { \ + return #error_name; \ + } + GKO_REGISTER_HIPSOLVER_ERROR(HIPSOLVER_STATUS_SUCCESS); + GKO_REGISTER_HIPSOLVER_ERROR(HIPSOLVER_STATUS_NOT_INITIALIZED); + GKO_REGISTER_HIPSOLVER_ERROR(HIPSOLVER_STATUS_ALLOC_FAILED); + GKO_REGISTER_HIPSOLVER_ERROR(HIPSOLVER_STATUS_INVALID_VALUE); + GKO_REGISTER_HIPSOLVER_ERROR(HIPSOLVER_STATUS_ARCH_MISMATCH); + GKO_REGISTER_HIPSOLVER_ERROR(HIPSOLVER_STATUS_EXECUTION_FAILED); + GKO_REGISTER_HIPSOLVER_ERROR(HIPSOLVER_STATUS_INTERNAL_ERROR); + GKO_REGISTER_HIPSOLVER_ERROR(HIPSOLVER_STATUS_MATRIX_TYPE_NOT_SUPPORTED); + GKO_REGISTER_HIPSOLVER_ERROR(HIPSOLVER_STATUS_NOT_SUPPORTED); + 
return "Unknown error"; + +#undef GKO_REGISTER_HIPSOLVER_ERROR +#else + return "Ginkgo must be built with LAPACK support to enable hipSOLVER"; +#endif +} + + } // namespace gko diff --git a/hip/base/executor.hip.cpp b/hip/base/executor.hip.cpp index 769d650d984..62f4b31eef5 100644 --- a/hip/base/executor.hip.cpp +++ b/hip/base/executor.hip.cpp @@ -1,4 +1,4 @@ -// SPDX-FileCopyrightText: 2017 - 2024 The Ginkgo authors +// SPDX-FileCopyrightText: 2017 - 2025 The Ginkgo authors // // SPDX-License-Identifier: BSD-3-Clause @@ -14,6 +14,9 @@ #include "common/cuda_hip/base/runtime.hpp" #include "hip/base/device.hpp" #include "hip/base/hipblas_handle.hpp" +#if GKO_HAVE_LAPACK +#include "hip/base/hipsolver_handle.hpp" +#endif #include "hip/base/hipsparse_handle.hpp" #include "hip/base/scoped_device_id.hip.hpp" @@ -260,6 +263,14 @@ void HipExecutor::init_handles() detail::hip_scoped_device_id_guard g(id); kernels::hip::hipsparse::destroy_hipsparse_handle(handle); }); +#if GKO_HAVE_LAPACK + this->hipsolver_handle_ = handle_manager( + kernels::hip::hipsolver::init(this->get_stream()), + [id](hipsolverDnContext* handle) { + detail::hip_scoped_device_id_guard g(id); + kernels::hip::hipsolver::destroy_hipsolver_handle(handle); + }); +#endif } } diff --git a/hip/base/hipblas_bindings.hip.hpp b/hip/base/hipblas_bindings.hip.hpp index 04c1610c0cc..4be0dd13651 100644 --- a/hip/base/hipblas_bindings.hip.hpp +++ b/hip/base/hipblas_bindings.hip.hpp @@ -1,4 +1,4 @@ -// SPDX-FileCopyrightText: 2017 - 2024 The Ginkgo authors +// SPDX-FileCopyrightText: 2017 - 2025 The Ginkgo authors // // SPDX-License-Identifier: BSD-3-Clause @@ -119,13 +119,41 @@ GKO_BIND_HIPBLAS_GEMM(ValueType, detail::not_implemented); GKO_BIND_HIPBLAS_GEAM(float, hipblasSgeam); GKO_BIND_HIPBLAS_GEAM(double, hipblasDgeam); -// Hipblas does not provide geam complex version yet. 
+GKO_BIND_HIPBLAS_GEAM(std::complex, hipblasCgeam); +GKO_BIND_HIPBLAS_GEAM(std::complex, hipblasZgeam); + template GKO_BIND_HIPBLAS_GEAM(ValueType, detail::not_implemented); #undef GKO_BIND_HIPBLAS_GEAM +#define GKO_BIND_HIPBLAS_TRMM(ValueType, HipblasName) \ + inline void trmm(hipblasHandle_t handle, hipblasSideMode_t side, \ + hipblasFillMode_t uplo, hipblasOperation_t trans, \ + hipblasDiagType_t diag, int m, int n, \ + const ValueType* alpha, const ValueType* a, int lda, \ + const ValueType* b, int ldb, ValueType* c, int ldc) \ + { \ + GKO_ASSERT_NO_HIPBLAS_ERRORS( \ + HipblasName(handle, side, uplo, trans, diag, m, n, \ + as_hipblas_type(alpha), as_hipblas_type(a), lda, \ + as_hipblas_type(b), ldb, as_hipblas_type(c), ldc)); \ + } \ + static_assert(true, \ + "This assert is used to counter the false positive extra " \ + "semi-colon warnings") + +GKO_BIND_HIPBLAS_TRMM(float, hipblasStrmm); +GKO_BIND_HIPBLAS_TRMM(double, hipblasDtrmm); +GKO_BIND_HIPBLAS_TRMM(std::complex, hipblasCtrmm); +GKO_BIND_HIPBLAS_TRMM(std::complex, hipblasZtrmm); +template +GKO_BIND_HIPBLAS_TRMM(ValueType, detail::not_implemented); + +#undef GKO_BIND_HIPBLAS_TRMM + + #define GKO_BIND_HIPBLAS_SCAL(ValueType, HipblasName) \ inline void scal(hipblasHandle_t handle, int n, const ValueType* alpha, \ ValueType* x, int incx) \ @@ -253,6 +281,12 @@ using namespace hipblas; #define BLAS_OP_T HIPBLAS_OP_T #define BLAS_OP_C HIPBLAS_OP_C +#define BLAS_SIDE_LEFT HIPBLAS_SIDE_LEFT +#define BLAS_SIDE_RIGHT HIPBLAS_SIDE_RIGHT + +#define BLAS_DIAG_UNIT HIPBLAS_DIAG_UNIT +#define BLAS_DIAG_NONUNIT HIPBLAS_DIAG_NON_UNIT + } // namespace blas } // namespace hip diff --git a/hip/base/hipsolver_bindings.hip.hpp b/hip/base/hipsolver_bindings.hip.hpp new file mode 100644 index 00000000000..54ee63f64da --- /dev/null +++ b/hip/base/hipsolver_bindings.hip.hpp @@ -0,0 +1,266 @@ +// SPDX-FileCopyrightText: 2025 The Ginkgo authors +// +// SPDX-License-Identifier: BSD-3-Clause + +#ifndef 
GKO_HIP_BASE_HIPSOLVER_BINDINGS_HPP_ +#define GKO_HIP_BASE_HIPSOLVER_BINDINGS_HPP_ + + +#if HIP_VERSION >= 50200000 +#include +#else +#include +#endif + + +#include + +#include "common/cuda_hip/base/types.hpp" + + +namespace gko { +namespace kernels { +namespace hip { +/** + * @brief The HIPSOLVER namespace. + * + * @ingroup hipsolver + */ +namespace hipsolver { +/** + * @brief The detail namespace. + * + * @ingroup detail + */ +namespace detail { + + +template +inline int64 not_implemented(Args...) +{ + return static_cast(HIPSOLVER_STATUS_MATRIX_TYPE_NOT_SUPPORTED); +} + + +} // namespace detail + + +template +struct is_supported : std::false_type {}; + +template <> +struct is_supported : std::true_type {}; + +template <> +struct is_supported : std::true_type {}; + +template <> +struct is_supported> : std::true_type {}; + +template <> +struct is_supported> : std::true_type {}; + + +#define GKO_BIND_HIPSOLVER_SYEVD_BUFFERSIZE(ValueType, HipsolverName) \ + inline void syevd_buffersize( \ + hipsolverDnHandle_t handle, hipsolverEigMode_t jobz, \ + hipsolverFillMode_t uplo, int32 n, ValueType* a, int32 lda, \ + remove_complex* w, int32* buffer_num_elems) \ + { \ + GKO_ASSERT_NO_HIPSOLVER_ERRORS( \ + HipsolverName(handle, jobz, uplo, n, as_hiplibs_type(a), lda, \ + as_hiplibs_type(w), buffer_num_elems)); \ + } \ + static_assert(true, \ + "This assert is used to counter the false positive extra " \ + "semi-colon warnings") + +GKO_BIND_HIPSOLVER_SYEVD_BUFFERSIZE(float, hipsolverDnSsyevd_bufferSize); +GKO_BIND_HIPSOLVER_SYEVD_BUFFERSIZE(double, hipsolverDnDsyevd_bufferSize); +GKO_BIND_HIPSOLVER_SYEVD_BUFFERSIZE(std::complex, + hipsolverDnCheevd_bufferSize); +GKO_BIND_HIPSOLVER_SYEVD_BUFFERSIZE(std::complex, + hipsolverDnZheevd_bufferSize); +template +GKO_BIND_HIPSOLVER_SYEVD_BUFFERSIZE(ValueType, detail::not_implemented); + +#undef GKO_BIND_HIPSOLVER_SYEVD_BUFFERSIZE + + +#define GKO_BIND_HIPSOLVER_SYEVD(ValueType, HipsolverName) \ + inline void syevd(hipsolverDnHandle_t 
handle, hipsolverEigMode_t jobz, \ + hipsolverFillMode_t uplo, int32 n, ValueType* a, \ + int32 lda, remove_complex* w, \ + ValueType* work, int32 buffer_num_elems, \ + int32* dev_info) \ + { \ + GKO_ASSERT_NO_HIPSOLVER_ERRORS( \ + HipsolverName(handle, jobz, uplo, n, as_hiplibs_type(a), lda, \ + as_hiplibs_type(w), as_hiplibs_type(work), \ + buffer_num_elems, dev_info)); \ + } \ + static_assert(true, \ + "This assert is used to counter the false positive extra " \ + "semi-colon warnings") + +GKO_BIND_HIPSOLVER_SYEVD(float, hipsolverDnSsyevd); +GKO_BIND_HIPSOLVER_SYEVD(double, hipsolverDnDsyevd); +GKO_BIND_HIPSOLVER_SYEVD(std::complex, hipsolverDnCheevd); +GKO_BIND_HIPSOLVER_SYEVD(std::complex, hipsolverDnZheevd); +template +GKO_BIND_HIPSOLVER_SYEVD(ValueType, detail::not_implemented); + +#undef GKO_BIND_HIPSOLVER_SYEVD + + +#define GKO_BIND_HIPSOLVER_SYGVD_BUFFERSIZE(ValueType, HipsolverName) \ + inline void sygvd_buffersize( \ + hipsolverDnHandle_t handle, hipsolverEigType_t itype, \ + hipsolverEigMode_t jobz, hipsolverFillMode_t uplo, int32 n, \ + ValueType* a, int32 lda, ValueType* b, int32 ldb, \ + remove_complex* w, int32* buffer_num_elems) \ + { \ + GKO_ASSERT_NO_HIPSOLVER_ERRORS(HipsolverName( \ + handle, itype, jobz, uplo, n, as_hiplibs_type(a), lda, \ + as_hiplibs_type(b), ldb, as_hiplibs_type(w), buffer_num_elems)); \ + } \ + static_assert(true, \ + "This assert is used to counter the false positive extra " \ + "semi-colon warnings") + +GKO_BIND_HIPSOLVER_SYGVD_BUFFERSIZE(float, hipsolverDnSsygvd_bufferSize); +GKO_BIND_HIPSOLVER_SYGVD_BUFFERSIZE(double, hipsolverDnDsygvd_bufferSize); +GKO_BIND_HIPSOLVER_SYGVD_BUFFERSIZE(std::complex, + hipsolverDnChegvd_bufferSize); +GKO_BIND_HIPSOLVER_SYGVD_BUFFERSIZE(std::complex, + hipsolverDnZhegvd_bufferSize); +template +GKO_BIND_HIPSOLVER_SYGVD_BUFFERSIZE(ValueType, detail::not_implemented); + +#undef GKO_BIND_HIPSOLVER_SYGVD_BUFFERSIZE + + +#define GKO_BIND_HIPSOLVER_SYGVD(ValueType, HipsolverName) \ + inline void 
sygvd(hipsolverDnHandle_t handle, hipsolverEigType_t itype, \ + hipsolverEigMode_t jobz, hipsolverFillMode_t uplo, \ + int32 n, ValueType* a, int32 lda, ValueType* b, \ + int32 ldb, remove_complex* w, \ + ValueType* work, int32 buffer_num_elems, \ + int32* dev_info) \ + { \ + GKO_ASSERT_NO_HIPSOLVER_ERRORS( \ + HipsolverName(handle, itype, jobz, uplo, n, as_hiplibs_type(a), \ + lda, as_hiplibs_type(b), ldb, as_hiplibs_type(w), \ + as_hiplibs_type(work), buffer_num_elems, dev_info)); \ + } \ + static_assert(true, \ + "This assert is used to counter the false positive extra " \ + "semi-colon warnings") + +GKO_BIND_HIPSOLVER_SYGVD(float, hipsolverDnSsygvd); +GKO_BIND_HIPSOLVER_SYGVD(double, hipsolverDnDsygvd); +GKO_BIND_HIPSOLVER_SYGVD(std::complex, hipsolverDnChegvd); +GKO_BIND_HIPSOLVER_SYGVD(std::complex, hipsolverDnZhegvd); +template +GKO_BIND_HIPSOLVER_SYGVD(ValueType, detail::not_implemented); + +#undef GKO_BIND_HIPSOLVER_SYGVD + + +#define GKO_BIND_HIPSOLVER_POTRF_BUFFERSIZE(ValueType, HipsolverName) \ + inline void potrf_buffersize( \ + hipsolverDnHandle_t handle, hipsolverFillMode_t uplo, int32 n, \ + ValueType* a, int32 lda, int32* buffer_num_elems) \ + { \ + GKO_ASSERT_NO_HIPSOLVER_ERRORS(HipsolverName( \ + handle, uplo, n, as_hiplibs_type(a), lda, buffer_num_elems)); \ + } \ + static_assert(true, \ + "This assert is used to counter the false positive extra " \ + "semi-colon warnings") + +GKO_BIND_HIPSOLVER_POTRF_BUFFERSIZE(float, hipsolverDnSpotrf_bufferSize); +GKO_BIND_HIPSOLVER_POTRF_BUFFERSIZE(double, hipsolverDnDpotrf_bufferSize); +GKO_BIND_HIPSOLVER_POTRF_BUFFERSIZE(std::complex, + hipsolverDnCpotrf_bufferSize); +GKO_BIND_HIPSOLVER_POTRF_BUFFERSIZE(std::complex, + hipsolverDnZpotrf_bufferSize); +template +GKO_BIND_HIPSOLVER_POTRF_BUFFERSIZE(ValueType, detail::not_implemented); + +#undef GKO_BIND_HIPSOLVER_POTRF_BUFFERSIZE + + +#define GKO_BIND_HIPSOLVER_POTRF(ValueType, HipsolverName) \ + inline void potrf(hipsolverDnHandle_t handle, 
hipsolverFillMode_t uplo, \ + int32 n, ValueType* a, int32 lda, ValueType* work, \ + int32 buffer_num_elems, int32* dev_info) \ + { \ + GKO_ASSERT_NO_HIPSOLVER_ERRORS( \ + HipsolverName(handle, uplo, n, as_hiplibs_type(a), lda, \ + as_hiplibs_type(work), buffer_num_elems, dev_info)); \ + } \ + static_assert(true, \ + "This assert is used to counter the false positive extra " \ + "semi-colon warnings") + +GKO_BIND_HIPSOLVER_POTRF(float, hipsolverDnSpotrf); +GKO_BIND_HIPSOLVER_POTRF(double, hipsolverDnDpotrf); +GKO_BIND_HIPSOLVER_POTRF(std::complex, hipsolverDnCpotrf); +GKO_BIND_HIPSOLVER_POTRF(std::complex, hipsolverDnZpotrf); +template +GKO_BIND_HIPSOLVER_POTRF(ValueType, detail::not_implemented); + +#undef GKO_BIND_HIPSOLVER_POTRF + + +#define GKO_BIND_HIPSOLVER_POTRS(ValueType, HipsolverName) \ + inline void potrs(hipsolverDnHandle_t handle, hipsolverFillMode_t uplo, \ + int32 n, int32 nrhs, ValueType* a, int32 lda, \ + ValueType* b, int32 ldb, int32* dev_info) \ + { \ + GKO_ASSERT_NO_HIPSOLVER_ERRORS( \ + HipsolverName(handle, uplo, n, nrhs, as_hiplibs_type(a), lda, \ + as_hiplibs_type(b), ldb, dev_info)); \ + } \ + static_assert(true, \ + "This assert is used to counter the false positive extra " \ + "semi-colon warnings") + +GKO_BIND_HIPSOLVER_POTRS(float, hipsolverDnSpotrs); +GKO_BIND_HIPSOLVER_POTRS(double, hipsolverDnDpotrs); +GKO_BIND_HIPSOLVER_POTRS(std::complex, hipsolverDnCpotrs); +GKO_BIND_HIPSOLVER_POTRS(std::complex, hipsolverDnZpotrs); +template +GKO_BIND_HIPSOLVER_POTRS(ValueType, detail::not_implemented); + +#undef GKO_BIND_HIPSOLVER_POTRS + + +} // namespace hipsolver + + +namespace dev_lapack { + + +using namespace hipsolver; + + +#define LAPACK_EIG_TYPE_1 HIPSOLVER_EIG_TYPE_1 +#define LAPACK_EIG_TYPE_2 HIPSOLVER_EIG_TYPE_2 +#define LAPACK_EIG_TYPE_3 HIPSOLVER_EIG_TYPE_3 + +#define LAPACK_EIG_VECTOR HIPSOLVER_EIG_MODE_VECTOR +#define LAPACK_EIG_NOVECTOR HIPSOLVER_EIG_MODE_NOVECTOR + +#define LAPACK_FILL_UPPER HIPSOLVER_FILL_MODE_UPPER +#define 
LAPACK_FILL_LOWER HIPSOLVER_FILL_MODE_LOWER + + +} // namespace dev_lapack +} // namespace hip +} // namespace kernels +} // namespace gko + + +#endif // GKO_HIP_BASE_HIPSOLVER_BINDINGS_HPP_ diff --git a/hip/base/hipsolver_handle.hpp b/hip/base/hipsolver_handle.hpp new file mode 100644 index 00000000000..530742f825b --- /dev/null +++ b/hip/base/hipsolver_handle.hpp @@ -0,0 +1,50 @@ +// SPDX-FileCopyrightText: 2025 The Ginkgo authors +// +// SPDX-License-Identifier: BSD-3-Clause + +#ifndef GKO_HIP_BASE_HIPSOLVER_HANDLE_HPP_ +#define GKO_HIP_BASE_HIPSOLVER_HANDLE_HPP_ + +#if HIP_VERSION >= 50200000 +#include +#else +#include +#endif + +#include + + +namespace gko { +namespace kernels { +namespace hip { +/** + * @brief The HIPSOLVER namespace. + * + * @ingroup hipsolver + */ +namespace hipsolver { + + +inline hipsolverDnContext* init(hipStream_t stream) +{ + hipsolverDnHandle_t handle; + GKO_ASSERT_NO_HIPSOLVER_ERRORS(hipsolverDnCreate(&handle)); + GKO_ASSERT_NO_HIPSOLVER_ERRORS(hipsolverDnSetStream(handle, stream)); + return reinterpret_cast(handle); +} + + +inline void destroy_hipsolver_handle(hipsolverDnHandle_t handle) +{ + GKO_ASSERT_NO_HIPSOLVER_ERRORS( + hipsolverDnDestroy(reinterpret_cast(handle))); +} + + +} // namespace hipsolver +} // namespace hip +} // namespace kernels +} // namespace gko + + +#endif // GKO_HIP_BASE_HIPSOLVER_HANDLE_HPP_ diff --git a/hip/test/base/exception_helpers.hip.cpp b/hip/test/base/exception_helpers.hip.cpp index 85a28fc1c41..32e4b6c3136 100644 --- a/hip/test/base/exception_helpers.hip.cpp +++ b/hip/test/base/exception_helpers.hip.cpp @@ -1,4 +1,4 @@ -// SPDX-FileCopyrightText: 2017 - 2024 The Ginkgo authors +// SPDX-FileCopyrightText: 2017 - 2025 The Ginkgo authors // // SPDX-License-Identifier: BSD-3-Clause @@ -9,10 +9,16 @@ #include #include #include +#if GKO_HAVE_LAPACK +#include +#endif #else #include #include #include +#if GKO_HAVE_LAPACK +#include +#endif #endif @@ -70,4 +76,18 @@ TEST(AssertNoHipsparseErrors, 
DoesNotThrowOnSuccess) } +#if GKO_HAVE_LAPACK +TEST(AssertNoHipsolverErrors, ThrowsOnError) +{ + ASSERT_THROW(GKO_ASSERT_NO_HIPSOLVER_ERRORS(1), gko::HipsolverError); +} + + +TEST(AssertNoHipsolverErrors, DoesNotThrowOnSuccess) +{ + ASSERT_NO_THROW(GKO_ASSERT_NO_HIPSOLVER_ERRORS(HIPSOLVER_STATUS_SUCCESS)); +} +#endif + + } // namespace diff --git a/include/ginkgo/config.hpp.in b/include/ginkgo/config.hpp.in index 93dced98f5b..6a652987101 100644 --- a/include/ginkgo/config.hpp.in +++ b/include/ginkgo/config.hpp.in @@ -115,6 +115,12 @@ // clang-format on +/* Is LAPACK available? */ +// clang-format off +#define GKO_HAVE_LAPACK @GINKGO_HAVE_LAPACK@ +// clang-format on + + /* Is HWLOC available ? */ // clang-format off #define GKO_HAVE_HWLOC @GINKGO_HAVE_HWLOC@ diff --git a/include/ginkgo/core/base/exception.hpp b/include/ginkgo/core/base/exception.hpp index febc5e17034..ea8fcddc439 100644 --- a/include/ginkgo/core/base/exception.hpp +++ b/include/ginkgo/core/base/exception.hpp @@ -1,4 +1,4 @@ -// SPDX-FileCopyrightText: 2017 - 2024 The Ginkgo authors +// SPDX-FileCopyrightText: 2017 - 2025 The Ginkgo authors // // SPDX-License-Identifier: BSD-3-Clause @@ -258,6 +258,29 @@ class CusparseError : public Error { }; +/** + * CusolverError is thrown when a cuSOLVER routine throws a non-zero error code. + */ +class CusolverError : public Error { +public: + /** + * Initializes a cuSOLVER error. + * + * @param file The name of the offending source file + * @param line The source code line number where the error occurred + * @param func The name of the cuSOLVER routine that failed + * @param error_code The resulting cuSOLVER error code + */ + CusolverError(const std::string& file, int line, const std::string& func, + int64 error_code) + : Error(file, line, func + ": " + get_error(error_code)) + {} + +private: + static std::string get_error(int64 error_code); +}; + + /** * CufftError is thrown when a cuFFT routine throws a non-zero error code. 
*/ @@ -374,6 +397,30 @@ class HipsparseError : public Error { }; +/** + * HipsolverError is thrown when a hipSOLVER routine throws a non-zero error + * code. + */ +class HipsolverError : public Error { +public: + /** + * Initializes a hipSOLVER error. + * + * @param file The name of the offending source file + * @param line The source code line number where the error occurred + * @param func The name of the hipSOLVER routine that failed + * @param error_code The resulting hipSOLVER error code + */ + HipsolverError(const std::string& file, int line, const std::string& func, + int64 error_code) + : Error(file, line, func + ": " + get_error(error_code)) + {} + +private: + static std::string get_error(int64 error_code); +}; + + /** * HipfftError is thrown when a hipFFT routine throws a non-zero error code. */ @@ -397,6 +444,26 @@ class HipfftError : public Error { }; +/** + * LapackError is thrown when a LAPACK routine throws a non-zero error code. + */ +class LapackError : public Error { +public: + /** + * Initializes a LAPACK error. + * + * @param file The name of the offending source file + * @param line The source code line number where the error occurred + * @param func The name of the LAPACK routine that failed + * @param error_code The resulting LAPACK error code + */ + LapackError(const std::string& file, int line, const std::string& func, + int64 error_code) + : Error(file, line, func + ": " + std::to_string(error_code)) + {} +}; + + /** * MetisError is thrown when METIS routine throws an error code. 
*/ diff --git a/include/ginkgo/core/base/exception_helpers.hpp b/include/ginkgo/core/base/exception_helpers.hpp index f0104ba1a7c..c5441681ffc 100644 --- a/include/ginkgo/core/base/exception_helpers.hpp +++ b/include/ginkgo/core/base/exception_helpers.hpp @@ -1,4 +1,4 @@ -// SPDX-FileCopyrightText: 2017 - 2024 The Ginkgo authors +// SPDX-FileCopyrightText: 2017 - 2025 The Ginkgo authors // // SPDX-License-Identifier: BSD-3-Clause @@ -494,6 +494,15 @@ inline size_type get_num_batch_items(const T& obj) ::gko::CusparseError(__FILE__, __LINE__, __func__, _errcode) +/** + * Instantiates a CusolverError. + * + * @param errcode The error code returned from the cuSOLVER routine. + */ +#define GKO_CUSOLVER_ERROR(_errcode) \ + ::gko::CusolverError(__FILE__, __LINE__, __func__, _errcode) + + /** * Instantiates a CufftError. * @@ -503,6 +512,15 @@ inline size_type get_num_batch_items(const T& obj) ::gko::CufftError(__FILE__, __LINE__, __func__, _errcode) +/** + * Instantiates an LapackError. + * + * @param errcode The error code returned from the LAPACK routine. + */ +#define GKO_LAPACK_ERROR(_errcode) \ + ::gko::LapackError(__FILE__, __LINE__, __func__, _errcode) + + /** * Asserts that a CUDA library call completed without errors. * @@ -559,6 +577,20 @@ inline size_type get_num_batch_items(const T& obj) } while (false) +/** + * Asserts that a cuSOLVER library call completed without errors. + * + * @param _cusolver_call a library call expression + */ +#define GKO_ASSERT_NO_CUSOLVER_ERRORS(_cusolver_call) \ + do { \ + auto _errcode = _cusolver_call; \ + if (_errcode != CUSOLVER_STATUS_SUCCESS) { \ + throw GKO_CUSOLVER_ERROR(_errcode); \ + } \ + } while (false) + + /** * Asserts that a cuFFT library call completed without errors. * @@ -609,6 +641,15 @@ inline size_type get_num_batch_items(const T& obj) ::gko::HipsparseError(__FILE__, __LINE__, __func__, _errcode) +/** + * Instantiates a HipsolverError. + * + * @param errcode The error code returned from the hipSOLVER routine. 
+ */ +#define GKO_HIPSOLVER_ERROR(_errcode) \ + ::gko::HipsolverError(__FILE__, __LINE__, __func__, _errcode) + + /** * Instantiates a HipfftError. * @@ -674,6 +715,20 @@ inline size_type get_num_batch_items(const T& obj) } while (false) +/** + * Asserts that a hipSOLVER library call completed without errors. + * + * @param _hipsolver_call a library call expression + */ +#define GKO_ASSERT_NO_HIPSOLVER_ERRORS(_hipsolver_call) \ + do { \ + auto _errcode = _hipsolver_call; \ + if (_errcode != HIPSOLVER_STATUS_SUCCESS) { \ + throw GKO_HIPSOLVER_ERROR(_errcode); \ + } \ + } while (false) + + /** * Asserts that a hipFFT library call completed without errors. * @@ -702,6 +757,22 @@ inline size_type get_num_batch_items(const T& obj) } while (false) +/** + * Asserts that an LAPACK library call completed without errors. + * + * @param _lapack_call a library call expression + * @param _info_var the name of the variable passed as LAPACK's + * INFO argument in _lapack_call + */ +#define GKO_ASSERT_NO_LAPACK_ERRORS(_lapack_call, _info_var) \ + do { \ + _lapack_call; \ + if (_info_var != 0) { \ + throw GKO_LAPACK_ERROR(_info_var); \ + } \ + } while (false) + + namespace detail { diff --git a/include/ginkgo/core/base/executor.hpp b/include/ginkgo/core/base/executor.hpp index 224860b72b7..5533ff35e07 100644 --- a/include/ginkgo/core/base/executor.hpp +++ b/include/ginkgo/core/base/executor.hpp @@ -1,4 +1,4 @@ -// SPDX-FileCopyrightText: 2017 - 2024 The Ginkgo authors +// SPDX-FileCopyrightText: 2017 - 2025 The Ginkgo authors // // SPDX-License-Identifier: BSD-3-Clause @@ -1688,6 +1688,20 @@ class CudaExecutor : public detail::ExecutorBase, return cusparse_handle_.get(); } + /** + * Get the cusolver handle for this executor + * + * @return the cusolver handle (cusolverDnContext*) for this executor + */ + cusolverDnContext* get_dev_lapack_handle() const + { +#if GKO_HAVE_LAPACK + return cusolver_handle_.get(); +#else + return nullptr; +#endif + } + /** * Get the closest PUs * @@ 
-1756,6 +1770,7 @@ class CudaExecutor : public detail::ExecutorBase, using handle_manager = std::unique_ptr>; handle_manager cublas_handle_; handle_manager cusparse_handle_; + handle_manager cusolver_handle_; std::shared_ptr alloc_; CUstream_st* stream_; }; @@ -1912,6 +1927,20 @@ class HipExecutor : public detail::ExecutorBase, return hipsparse_handle_.get(); } + /** + * Get the hipsolver handle for this executor + * + * @return the hipsolver handle (hipsolverDnContext*) for this executor + */ + hipsolverDnContext* get_dev_lapack_handle() const + { +#if GKO_HAVE_LAPACK + return hipsolver_handle_.get(); +#else + return nullptr; +#endif + } + /** * Get the closest NUMA node * @@ -1974,6 +2003,7 @@ class HipExecutor : public detail::ExecutorBase, using handle_manager = std::unique_ptr>; handle_manager hipblas_handle_; handle_manager hipsparse_handle_; + handle_manager hipsolver_handle_; std::shared_ptr alloc_; GKO_HIP_STREAM_STRUCT* stream_; }; diff --git a/include/ginkgo/core/base/fwd_decls.hpp b/include/ginkgo/core/base/fwd_decls.hpp index 84e579058c4..0e05a649ddc 100644 --- a/include/ginkgo/core/base/fwd_decls.hpp +++ b/include/ginkgo/core/base/fwd_decls.hpp @@ -1,4 +1,4 @@ -// SPDX-FileCopyrightText: 2017 - 2024 The Ginkgo authors +// SPDX-FileCopyrightText: 2017 - 2025 The Ginkgo authors // // SPDX-License-Identifier: BSD-3-Clause @@ -13,6 +13,8 @@ struct cublasContext; struct cusparseContext; +struct cusolverDnContext; + struct CUstream_st; struct CUevent_st; @@ -21,6 +23,8 @@ struct hipblasContext; struct hipsparseContext; +struct hipsolverDnContext; + #if GINKGO_HIP_PLATFORM_HCC struct ihipStream_t; struct ihipEvent_t; diff --git a/reference/CMakeLists.txt b/reference/CMakeLists.txt index 87858d18812..628954f237b 100644 --- a/reference/CMakeLists.txt +++ b/reference/CMakeLists.txt @@ -20,6 +20,7 @@ target_sources( distributed/partition_helpers_kernels.cpp distributed/partition_kernels.cpp distributed/vector_kernels.cpp + eigensolver/lobpcg_kernels.cpp 
factorization/cholesky_kernels.cpp factorization/elimination_forest_kernels.cpp factorization/factorization_kernels.cpp @@ -73,6 +74,9 @@ target_sources( stop/residual_norm_kernels.cpp ) +if(GINKGO_BUILD_LAPACK) + target_link_libraries(ginkgo_reference PUBLIC LAPACK::LAPACK) +endif() target_link_libraries(ginkgo_reference PUBLIC ginkgo_device) target_compile_definitions( ginkgo_reference diff --git a/reference/base/blas_bindings.hpp b/reference/base/blas_bindings.hpp new file mode 100644 index 00000000000..1d64aecbd14 --- /dev/null +++ b/reference/base/blas_bindings.hpp @@ -0,0 +1,109 @@ +// SPDX-FileCopyrightText: 2025 The Ginkgo authors +// +// SPDX-License-Identifier: BSD-3-Clause + +#ifndef GKO_REFERENCE_BASE_BLAS_BINDINGS_HPP_ +#define GKO_REFERENCE_BASE_BLAS_BINDINGS_HPP_ + +#include + + +#if GKO_HAVE_LAPACK + + +extern "C" { + + +// Triangular matrix-matrix multiplication +void strmm(const char* side, const char* uplo, const char* transa, + const char* diag, const std::int32_t* m, const std::int32_t* n, + const float* alpha, const float* A, const std::int32_t* lda, + float* B, const std::int32_t* ldb); + +void dtrmm(const char* side, const char* uplo, const char* transa, + const char* diag, const std::int32_t* m, const std::int32_t* n, + const double* alpha, const double* A, const std::int32_t* lda, + double* B, const std::int32_t* ldb); + +void ctrmm(const char* side, const char* uplo, const char* transa, + const char* diag, const std::int32_t* m, const std::int32_t* n, + const std::complex* alpha, const std::complex* A, + const std::int32_t* lda, std::complex* B, + const std::int32_t* ldb); + +void ztrmm(const char* side, const char* uplo, const char* transa, + const char* diag, const std::int32_t* m, const std::int32_t* n, + const std::complex* alpha, const std::complex* A, + const std::int32_t* lda, std::complex* B, + const std::int32_t* ldb); +} + + +namespace gko { +namespace kernels { +namespace reference { +/** + * @brief The BLAS namespace. 
+ * + * @ingroup lapack + */ +namespace blas { + + +template +struct is_supported : std::false_type {}; + +template <> +struct is_supported : std::true_type {}; + +template <> +struct is_supported : std::true_type {}; + +template <> +struct is_supported> : std::true_type {}; + +template <> +struct is_supported> : std::true_type {}; + + +#define GKO_BIND_TRMM(ValueType, BlasName) \ + inline void trmm(const char* side, const char* uplo, const char* transa, \ + const char* diag, const int32* m, const int32* n, \ + const ValueType* alpha, const ValueType* a, \ + const int32* lda, ValueType* b, const int32* ldb) \ + { \ + BlasName(side, uplo, transa, diag, m, n, alpha, a, lda, b, ldb); \ + } \ + static_assert(true, \ + "This assert is used to counter the false positive extra " \ + "semi-colon warnings") + +GKO_BIND_TRMM(float, strmm); +GKO_BIND_TRMM(double, dtrmm); +GKO_BIND_TRMM(std::complex, ctrmm); +GKO_BIND_TRMM(std::complex, ztrmm); +template +inline void trmm(const char* side, const char* uplo, const char* transa, + const char* diag, const int32* m, const int32* n, + const ValueType* alpha, const ValueType* a, const int32* lda, + ValueType* b, const int32* ldb) GKO_NOT_IMPLEMENTED; + +#undef GKO_BIND_TRMM + + +#define BLAS_OP_N 'N' +#define BLAS_OP_T 'T' +#define BLAS_OP_C 'C' + +#define BLAS_SIDE_LEFT 'L' +#define BLAS_SIDE_RIGHT 'R' + + +} // namespace blas +} // namespace reference +} // namespace kernels +} // namespace gko + +#endif // GKO_HAVE_LAPACK + +#endif // GKO_REFERENCE_BASE_BLAS_BINDINGS_HPP_ diff --git a/reference/base/lapack_bindings.hpp b/reference/base/lapack_bindings.hpp new file mode 100644 index 00000000000..8f1b2f092f6 --- /dev/null +++ b/reference/base/lapack_bindings.hpp @@ -0,0 +1,460 @@ +// SPDX-FileCopyrightText: 2025 The Ginkgo authors +// +// SPDX-License-Identifier: BSD-3-Clause + +#ifndef GKO_REFERENCE_BASE_LAPACK_BINDINGS_HPP_ +#define GKO_REFERENCE_BASE_LAPACK_BINDINGS_HPP_ + +#include +#include +#include + + +#if GKO_HAVE_LAPACK + 
+ +extern "C" { + + +// Symmetric eigenvalue problem +void ssyevd(const char* jobz, const char* uplo, const std::int32_t* n, float* A, + const std::int32_t* lda, float* w, float* work, std::int32_t* lwork, + std::int32_t* iwork, std::int32_t* liwork, std::int32_t* info); + +void dsyevd(const char* jobz, const char* uplo, const std::int32_t* n, + double* A, const std::int32_t* lda, double* w, double* work, + std::int32_t* lwork, std::int32_t* iwork, std::int32_t* liwork, + std::int32_t* info); + +void cheevd(const char* jobz, const char* uplo, const std::int32_t* n, + std::complex* A, const std::int32_t* lda, float* w, + std::complex* work, std::int32_t* lwork, float* rwork, + std::int32_t* lrwork, std::int32_t* iwork, std::int32_t* liwork, + std::int32_t* info); + +void zheevd(const char* jobz, const char* uplo, const std::int32_t* n, + std::complex* A, const std::int32_t* lda, double* w, + std::complex* work, std::int32_t* lwork, double* rwork, + std::int32_t* lrwork, std::int32_t* iwork, std::int32_t* liwork, + std::int32_t* info); + + +// Symmetric generalized eigenvalue problem +void ssygvd(const std::int32_t* itype, const char* jobz, const char* uplo, + const std::int32_t* n, float* A, const std::int32_t* lda, float* B, + const std::int32_t* ldb, float* w, float* work, std::int32_t* lwork, + std::int32_t* iwork, std::int32_t* liwork, std::int32_t* info); + +void dsygvd(const std::int32_t* itype, const char* jobz, const char* uplo, + const std::int32_t* n, double* A, const std::int32_t* lda, + double* B, const std::int32_t* ldb, double* w, double* work, + std::int32_t* lwork, std::int32_t* iwork, std::int32_t* liwork, + std::int32_t* info); + +void chegvd(const std::int32_t* itype, const char* jobz, const char* uplo, + const std::int32_t* n, std::complex* A, + const std::int32_t* lda, std::complex* B, + const std::int32_t* ldb, float* w, std::complex* work, + std::int32_t* lwork, float* rwork, std::int32_t* lrwork, + std::int32_t* iwork, std::int32_t* liwork, 
std::int32_t* info); + +void zhegvd(const std::int32_t* itype, const char* jobz, const char* uplo, + const std::int32_t* n, std::complex* A, + const std::int32_t* lda, std::complex* B, + const std::int32_t* ldb, double* w, std::complex* work, + std::int32_t* lwork, double* rwork, std::int32_t* lrwork, + std::int32_t* iwork, std::int32_t* liwork, std::int32_t* info); + + +// Cholesky factorization +void spotrf(const char* uplo, const std::int32_t* n, float* A, + const std::int32_t* lda, std::int32_t* info); + +void dpotrf(const char* uplo, const std::int32_t* n, double* A, + const std::int32_t* lda, std::int32_t* info); + +void cpotrf(const char* uplo, const std::int32_t* n, std::complex* A, + const std::int32_t* lda, std::int32_t* info); + +void zpotrf(const char* uplo, const std::int32_t* n, std::complex* A, + const std::int32_t* lda, std::int32_t* info); + + +// Triangular matrix inverse +void strtri(const char* uplo, const char* diag, const std::int32_t* n, float* A, + const std::int32_t* lda, std::int32_t* info); + +void dtrtri(const char* uplo, const char* diag, const std::int32_t* n, + double* A, const std::int32_t* lda, std::int32_t* info); + +void ctrtri(const char* uplo, const char* diag, const std::int32_t* n, + std::complex* A, const std::int32_t* lda, + std::int32_t* info); + +void ztrtri(const char* uplo, const char* diag, const std::int32_t* n, + std::complex* A, const std::int32_t* lda, + std::int32_t* info); +} + + +namespace gko { +namespace kernels { +namespace reference { +/** + * @brief The LAPACK namespace. 
+ * + * @ingroup lapack + */ +namespace lapack { + + +template +struct is_supported : std::false_type {}; + +template <> +struct is_supported : std::true_type {}; + +template <> +struct is_supported : std::true_type {}; + +template <> +struct is_supported> : std::true_type {}; + +template <> +struct is_supported> : std::true_type {}; + + +#define GKO_BIND_SYEVD_BUFFERSIZES(ValueType, LapackName) \ + inline void syevd_buffersizes( \ + const char* jobz, const char* uplo, const int32* n, ValueType* a, \ + const int32* lda, ValueType* w, ValueType* work, \ + int32* fp_buffer_num_elems, int32* iwork, int32* int_buffer_num_elems) \ + { \ + int32 info; \ + *fp_buffer_num_elems = -1; \ + *int_buffer_num_elems = -1; \ + GKO_ASSERT_NO_LAPACK_ERRORS( \ + LapackName(jobz, uplo, n, a, lda, w, work, fp_buffer_num_elems, \ + iwork, int_buffer_num_elems, &info), \ + info); \ + *fp_buffer_num_elems = static_cast(work[0]); \ + *int_buffer_num_elems = iwork[0]; \ + } \ + static_assert(true, \ + "This assert is used to counter the false positive extra " \ + "semi-colon warnings") + +GKO_BIND_SYEVD_BUFFERSIZES(float, ssyevd); +GKO_BIND_SYEVD_BUFFERSIZES(double, dsyevd); +template +inline void syevd_buffersizes(const char* jobz, const char* uplo, + const int32* n, ValueType* a, const int32* lda, + ValueType* w, ValueType* work, + int32* fp_buffer_num_elems, int32* iwork, + int32* int_buffer_num_elems) GKO_NOT_IMPLEMENTED; + +#undef GKO_BIND_SYEVD_BUFFERSIZES + + +#define GKO_BIND_SYEVD(ValueType, LapackName) \ + inline void syevd(const char* jobz, const char* uplo, const int32* n, \ + ValueType* a, const int32* lda, ValueType* w, \ + ValueType* work, int32* fp_buffer_num_elems, \ + int32* iwork, int32* int_buffer_num_elems) \ + { \ + int32 info; \ + GKO_ASSERT_NO_LAPACK_ERRORS( \ + LapackName(jobz, uplo, n, a, lda, w, work, fp_buffer_num_elems, \ + iwork, int_buffer_num_elems, &info), \ + info); \ + } \ + static_assert(true, \ + "This assert is used to counter the false positive extra " 
\ + "semi-colon warnings") + +GKO_BIND_SYEVD(float, ssyevd); +GKO_BIND_SYEVD(double, dsyevd); +template +inline void syevd(const char* jobz, const char* uplo, const int32* n, + ValueType* a, const int32* lda, ValueType* w, ValueType* work, + int32* fp_buffer_num_elems, int32* iwork, + int32* int_buffer_num_elems) GKO_NOT_IMPLEMENTED; + +#undef GKO_BIND_SYEVD + + +#define GKO_BIND_HEEVD_BUFFERSIZES(ValueType, LapackName) \ + inline void heevd_buffersizes( \ + const char* jobz, const char* uplo, const int32* n, ValueType* a, \ + const int32* lda, gko::remove_complex* w, ValueType* work, \ + int32* fp_buffer_num_elems, gko::remove_complex* rwork, \ + int32* rfp_buffer_num_elems, int32* iwork, \ + int32* int_buffer_num_elems) \ + { \ + int32 info; \ + *fp_buffer_num_elems = -1; \ + *rfp_buffer_num_elems = -1; \ + *int_buffer_num_elems = -1; \ + GKO_ASSERT_NO_LAPACK_ERRORS( \ + LapackName(jobz, uplo, n, a, lda, w, work, fp_buffer_num_elems, \ + rwork, rfp_buffer_num_elems, iwork, \ + int_buffer_num_elems, &info), \ + info); \ + *fp_buffer_num_elems = static_cast(work[0].real()); \ + *rfp_buffer_num_elems = static_cast(rwork[0]); \ + *int_buffer_num_elems = iwork[0]; \ + } \ + static_assert(true, \ + "This assert is used to counter the false positive extra " \ + "semi-colon warnings") + +GKO_BIND_HEEVD_BUFFERSIZES(std::complex, cheevd); +GKO_BIND_HEEVD_BUFFERSIZES(std::complex, zheevd); +template +inline void heevd_buffersizes(const char* jobz, const char* uplo, + const int32* n, ValueType* a, const int32* lda, + gko::remove_complex* w, + ValueType* work, int32* fp_buffer_num_elems, + gko::remove_complex* rwork, + int32* rfp_buffer_num_elems, int32* iwork, + int32* int_buffer_num_elems) GKO_NOT_IMPLEMENTED; + +#undef GKO_BIND_HEEVD_BUFFERSIZES + + +#define GKO_BIND_HEEVD(ValueType, LapackName) \ + inline void heevd( \ + const char* jobz, const char* uplo, const int32* n, ValueType* a, \ + const int32* lda, gko::remove_complex* w, ValueType* work, \ + int32* 
fp_buffer_num_elems, gko::remove_complex* rwork, \ + int32* rfp_buffer_num_elems, int32* iwork, \ + int32* int_buffer_num_elems) \ + { \ + int32 info; \ + GKO_ASSERT_NO_LAPACK_ERRORS( \ + LapackName(jobz, uplo, n, a, lda, w, work, fp_buffer_num_elems, \ + rwork, rfp_buffer_num_elems, iwork, \ + int_buffer_num_elems, &info), \ + info); \ + } \ + static_assert(true, \ + "This assert is used to counter the false positive extra " \ + "semi-colon warnings") + +GKO_BIND_HEEVD(std::complex, cheevd); +GKO_BIND_HEEVD(std::complex, zheevd); +template +inline void heevd(const char* jobz, const char* uplo, const int32* n, + ValueType* a, const int32* lda, remove_complex* w, + ValueType* work, int32* fp_buffer_num_elems, + remove_complex* rwork, int32* rfp_buffer_num_elems, + int32* iwork, + int32* int_buffer_num_elems) GKO_NOT_IMPLEMENTED; + +#undef GKO_BIND_HEEVD + + +#define GKO_BIND_SYGVD_BUFFERSIZES(ValueType, LapackName) \ + inline void sygvd_buffersizes( \ + const int32* itype, const char* jobz, const char* uplo, \ + const int32* n, ValueType* a, const int32* lda, ValueType* b, \ + const int32* ldb, ValueType* w, ValueType* work, \ + int32* fp_buffer_num_elems, int32* iwork, int32* int_buffer_num_elems) \ + { \ + int32 info; \ + *fp_buffer_num_elems = -1; \ + *int_buffer_num_elems = -1; \ + GKO_ASSERT_NO_LAPACK_ERRORS( \ + LapackName(itype, jobz, uplo, n, a, lda, b, ldb, w, work, \ + fp_buffer_num_elems, iwork, int_buffer_num_elems, \ + &info), \ + info); \ + *fp_buffer_num_elems = static_cast(work[0]); \ + *int_buffer_num_elems = iwork[0]; \ + } \ + static_assert(true, \ + "This assert is used to counter the false positive extra " \ + "semi-colon warnings") + +GKO_BIND_SYGVD_BUFFERSIZES(float, ssygvd); +GKO_BIND_SYGVD_BUFFERSIZES(double, dsygvd); +template +inline void sygvd_buffersizes(const int32* itype, const char* jobz, + const char* uplo, const int32* n, ValueType* a, + const int32* lda, ValueType* b, const int32* ldb, + ValueType* w, ValueType* work, + int32* 
fp_buffer_num_elems, int32* iwork, + int32* int_buffer_num_elems) GKO_NOT_IMPLEMENTED; + +#undef GKO_BIND_SYGVD_BUFFERSIZES + + +#define GKO_BIND_SYGVD(ValueType, LapackName) \ + inline void sygvd(const int32* itype, const char* jobz, const char* uplo, \ + const int32* n, ValueType* a, const int32* lda, \ + ValueType* b, const int32* ldb, ValueType* w, \ + ValueType* work, int32* fp_buffer_num_elems, \ + int32* iwork, int32* int_buffer_num_elems) \ + { \ + int32 info; \ + GKO_ASSERT_NO_LAPACK_ERRORS( \ + LapackName(itype, jobz, uplo, n, a, lda, b, ldb, w, work, \ + fp_buffer_num_elems, iwork, int_buffer_num_elems, \ + &info), \ + info); \ + } \ + static_assert(true, \ + "This assert is used to counter the false positive extra " \ + "semi-colon warnings") + +GKO_BIND_SYGVD(float, ssygvd); +GKO_BIND_SYGVD(double, dsygvd); +template +inline void sygvd(const int32* itype, const char* jobz, const char* uplo, + const int32* n, ValueType* a, const int32* lda, ValueType* b, + const int32* ldb, ValueType* w, ValueType* work, + int32* fp_buffer_num_elems, int32* iwork, + int32* int_buffer_num_elems) GKO_NOT_IMPLEMENTED; + +#undef GKO_BIND_SYGVD + + +#define GKO_BIND_HEGVD_BUFFERSIZES(ValueType, LapackName) \ + inline void hegvd_buffersizes( \ + const int32* itype, const char* jobz, const char* uplo, \ + const int32* n, ValueType* a, const int32* lda, ValueType* b, \ + const int32* ldb, gko::remove_complex* w, ValueType* work, \ + int32* fp_buffer_num_elems, gko::remove_complex* rwork, \ + int32* rfp_buffer_num_elems, int32* iwork, \ + int32* int_buffer_num_elems) \ + { \ + int32 info; \ + *fp_buffer_num_elems = -1; \ + *rfp_buffer_num_elems = -1; \ + *int_buffer_num_elems = -1; \ + GKO_ASSERT_NO_LAPACK_ERRORS( \ + LapackName(itype, jobz, uplo, n, a, lda, b, ldb, w, work, \ + fp_buffer_num_elems, rwork, rfp_buffer_num_elems, \ + iwork, int_buffer_num_elems, &info), \ + info); \ + *fp_buffer_num_elems = static_cast(work[0].real()); \ + *rfp_buffer_num_elems = 
static_cast(rwork[0]); \ + *int_buffer_num_elems = iwork[0]; \ + } \ + static_assert(true, \ + "This assert is used to counter the false positive extra " \ + "semi-colon warnings") + +GKO_BIND_HEGVD_BUFFERSIZES(std::complex, chegvd); +GKO_BIND_HEGVD_BUFFERSIZES(std::complex, zhegvd); +template +inline void hegvd_buffersizes(const int32* itype, const char* jobz, + const char* uplo, const int32* n, ValueType* a, + const int32* lda, ValueType* b, const int32* ldb, + gko::remove_complex* w, + ValueType* work, int32* fp_buffer_num_elems, + gko::remove_complex* rwork, + int32* rfp_buffer_num_elems, int32* iwork, + int32* int_buffer_num_elems) GKO_NOT_IMPLEMENTED; + +#undef GKO_BIND_HEGVD_BUFFERSIZES + + +#define GKO_BIND_HEGVD(ValueType, LapackName) \ + inline void hegvd( \ + const int32* itype, const char* jobz, const char* uplo, \ + const int32* n, ValueType* a, const int32* lda, ValueType* b, \ + const int32* ldb, remove_complex* w, ValueType* work, \ + int32* fp_buffer_num_elems, remove_complex* rwork, \ + int32* rfp_buffer_num_elems, int32* iwork, \ + int32* int_buffer_num_elems) \ + { \ + int32 info; \ + GKO_ASSERT_NO_LAPACK_ERRORS( \ + LapackName(itype, jobz, uplo, n, a, lda, b, ldb, w, work, \ + fp_buffer_num_elems, rwork, rfp_buffer_num_elems, \ + iwork, int_buffer_num_elems, &info), \ + info); \ + } \ + static_assert(true, \ + "This assert is used to counter the false positive extra " \ + "semi-colon warnings") + +GKO_BIND_HEGVD(std::complex, chegvd); +GKO_BIND_HEGVD(std::complex, zhegvd); +template +inline void hegvd(const int32* itype, const char* jobz, const char* uplo, + const int32* n, ValueType* a, const int32* lda, ValueType* b, + const int32* ldb, remove_complex* w, + ValueType* work, int32* fp_buffer_num_elems, + remove_complex* rwork, int32* rfp_buffer_num_elems, + int32* iwork, + int32* int_buffer_num_elems) GKO_NOT_IMPLEMENTED; + +#undef GKO_BIND_HEGVD + + +#define GKO_BIND_POTRF(ValueType, LapackName) \ + inline void potrf(const char* uplo, const 
int32* n, ValueType* a, \ + const int32* lda) \ + { \ + int32 info; \ + GKO_ASSERT_NO_LAPACK_ERRORS(LapackName(uplo, n, a, lda, &info), info); \ + } \ + static_assert(true, \ + "This assert is used to counter the false positive extra " \ + "semi-colon warnings") + +GKO_BIND_POTRF(float, spotrf); +GKO_BIND_POTRF(double, dpotrf); +GKO_BIND_POTRF(std::complex, cpotrf); +GKO_BIND_POTRF(std::complex, zpotrf); +template +inline void potrf(const char* uplo, const int32* n, ValueType* a, + const int32* lda) GKO_NOT_IMPLEMENTED; + +#undef GKO_BIND_POTRF + + +#define GKO_BIND_TRTRI(ValueType, LapackName) \ + inline void trtri(const char* uplo, const char* diag, const int32* n, \ + ValueType* a, const int32* lda) \ + { \ + int32 info; \ + GKO_ASSERT_NO_LAPACK_ERRORS(LapackName(uplo, diag, n, a, lda, &info), \ + info); \ + } \ + static_assert(true, \ + "This assert is used to counter the false positive extra " \ + "semi-colon warnings") + +GKO_BIND_TRTRI(float, strtri); +GKO_BIND_TRTRI(double, dtrtri); +GKO_BIND_TRTRI(std::complex, ctrtri); +GKO_BIND_TRTRI(std::complex, ztrtri); +template +inline void trtri(const char* uplo, const char* diag, const int32* n, + ValueType* a, const int32* lda) GKO_NOT_IMPLEMENTED; + +#undef GKO_BIND_TRTRI + + +#define LAPACK_EIG_VECTOR 'V' +#define LAPACK_EIG_NOVECTOR 'N' + +#define LAPACK_FILL_UPPER 'U' +#define LAPACK_FILL_LOWER 'L' + +#define LAPACK_DIAG_UNIT 'U' +#define LAPACK_DIAG_NONUNIT 'N' + + +} // namespace lapack +} // namespace reference +} // namespace kernels +} // namespace gko + + +#endif // GKO_HAVE_LAPACK + +#endif // GKO_REFERENCE_BASE_LAPACK_BINDINGS_HPP_ diff --git a/reference/eigensolver/lobpcg_kernels.cpp b/reference/eigensolver/lobpcg_kernels.cpp new file mode 100644 index 00000000000..b3c34288ba2 --- /dev/null +++ b/reference/eigensolver/lobpcg_kernels.cpp @@ -0,0 +1,275 @@ +// SPDX-FileCopyrightText: 2025 The Ginkgo authors +// +// SPDX-License-Identifier: BSD-3-Clause + +#include "core/eigensolver/lobpcg_kernels.hpp" 
+ +#include +#include + +#include "reference/base/blas_bindings.hpp" +#include "reference/base/lapack_bindings.hpp" + +#if GKO_HAVE_LAPACK + + +namespace gko { +namespace kernels { +namespace reference { +/** + * @brief The LOBPCG solver namespace. + * + * @ingroup lobpcg + */ +namespace lobpcg { + + +template +void symm_eig(std::shared_ptr exec, + matrix::Dense* a, + array>* e_vals, array* workspace) +{ + constexpr auto max = std::numeric_limits::max(); + if (a->get_size()[1] > max) { + throw OverflowError(__FILE__, __LINE__, + name_demangling::get_type_name(typeid(int32))); + } + if (a->get_stride() > max) { + throw OverflowError(__FILE__, __LINE__, + name_demangling::get_type_name(typeid(int32))); + } + const int32 n = static_cast(a->get_size()[0]); + const int32 lda = static_cast(a->get_stride()); + const char job = LAPACK_EIG_VECTOR; + const char uplo = LAPACK_FILL_LOWER; + + if constexpr (!gko::is_complex_s::value) { + // Even if the workspace is already allocated, we need to know where to + // set the pointers for the individual workspaces of LAPACK + int32 fp_buffer_num_elems; + int32 int_buffer_num_elems; + ValueType* work = reinterpret_cast(workspace->get_data()); + array tmp_iwork(exec, 1); + lapack::syevd_buffersizes( + &job, &uplo, &n, a->get_values(), &lda, e_vals->get_data(), work, + &fp_buffer_num_elems, tmp_iwork.get_data(), &int_buffer_num_elems); + size_type total_bytes = sizeof(ValueType) * fp_buffer_num_elems + + sizeof(int32) * int_buffer_num_elems; + if (workspace->get_size() < total_bytes) { + workspace->resize_and_reset(total_bytes); + } + work = reinterpret_cast(workspace->get_data()); + // Set iwork pointer inside the workspace array + int32* iwork = reinterpret_cast( + workspace->get_data() + sizeof(ValueType) * fp_buffer_num_elems); + lapack::syevd(&job, &uplo, &n, a->get_values(), &lda, + e_vals->get_data(), work, &fp_buffer_num_elems, iwork, + &int_buffer_num_elems); + } else { // Complex data type + + // LAPACK expects column-major 
data, so we need to take the conjugate + // of the input matrix (same as performing A = A^T) + ValueType* data = a->get_values(); + for (int32 row = 0; row < n; ++row) { + for (int32 col = 0; col < n; ++col) { + data[row * lda + col] = conj(data[row * lda + col]); + } + } + + int32 fp_buffer_num_elems; + int32 rfp_buffer_num_elems; + int32 int_buffer_num_elems; + ValueType* work = reinterpret_cast(workspace->get_data()); + array tmp_iwork(exec, 1); + array> tmp_rwork(exec, 1); + lapack::heevd_buffersizes( + &job, &uplo, &n, a->get_values(), &lda, e_vals->get_data(), work, + &fp_buffer_num_elems, tmp_rwork.get_data(), &rfp_buffer_num_elems, + tmp_iwork.get_data(), &int_buffer_num_elems); + size_type total_bytes = + sizeof(ValueType) * fp_buffer_num_elems + + sizeof(remove_complex) * rfp_buffer_num_elems + + sizeof(int32) * int_buffer_num_elems; + if (workspace->get_size() < total_bytes) { + workspace->resize_and_reset(total_bytes); + } + work = reinterpret_cast(workspace->get_data()); + // Set rwork and iwork pointers inside the workspace array + remove_complex* rwork = + reinterpret_cast*>( + workspace->get_data() + + sizeof(ValueType) * fp_buffer_num_elems); + int32* iwork = reinterpret_cast( + workspace->get_data() + sizeof(ValueType) * fp_buffer_num_elems + + sizeof(remove_complex) * rfp_buffer_num_elems); + lapack::heevd(&job, &uplo, &n, a->get_values(), &lda, + e_vals->get_data(), work, &fp_buffer_num_elems, rwork, + &rfp_buffer_num_elems, iwork, &int_buffer_num_elems); + } +} + +GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(GKO_DECLARE_LOBPCG_SYMM_EIG_KERNEL); + + +template +void symm_generalized_eig(std::shared_ptr exec, + matrix::Dense* a, + matrix::Dense* b, + array>* e_vals, + array* workspace) +{ + constexpr auto max = std::numeric_limits::max(); + if (a->get_size()[1] > max) { + throw OverflowError(__FILE__, __LINE__, + name_demangling::get_type_name(typeid(int32))); + } + if (a->get_stride() > max) { + throw OverflowError(__FILE__, __LINE__, + 
name_demangling::get_type_name(typeid(int32))); + } + if (b->get_stride() > max) { + throw OverflowError(__FILE__, __LINE__, + name_demangling::get_type_name(typeid(int32))); + } + const int32 n = static_cast(a->get_size()[0]); + const int32 lda = static_cast(a->get_stride()); + const int32 ldb = static_cast(b->get_stride()); + const int32 itype = 1; + const char job = LAPACK_EIG_VECTOR; + const char uplo = LAPACK_FILL_LOWER; + + if constexpr (!gko::is_complex_s::value) { + // Even if the workspace is already allocated, we need to know where to + // set the pointers for the individual workspaces of LAPACK + int32 fp_buffer_num_elems; + int32 int_buffer_num_elems; + ValueType* work = reinterpret_cast(workspace->get_data()); + array tmp_iwork(exec, 1); + lapack::sygvd_buffersizes( + &itype, &job, &uplo, &n, a->get_values(), &lda, b->get_values(), + &ldb, e_vals->get_data(), work, &fp_buffer_num_elems, + tmp_iwork.get_data(), &int_buffer_num_elems); + size_type total_bytes = sizeof(ValueType) * fp_buffer_num_elems + + sizeof(int32) * int_buffer_num_elems; + if (workspace->get_size() < total_bytes) { + workspace->resize_and_reset(total_bytes); + } + work = reinterpret_cast(workspace->get_data()); + // Set iwork pointer inside the workspace array + int32* iwork = reinterpret_cast( + workspace->get_data() + sizeof(ValueType) * fp_buffer_num_elems); + lapack::sygvd(&itype, &job, &uplo, &n, a->get_values(), &lda, + b->get_values(), &ldb, e_vals->get_data(), work, + &fp_buffer_num_elems, iwork, &int_buffer_num_elems); + } else { // Complex data type + + // LAPACK expects column-major data, so we need to take the conjugate + // of the input matrices (same as performing A = A^T) + ValueType* a_data = a->get_values(); + ValueType* b_data = b->get_values(); + for (int32 row = 0; row < n; ++row) { + for (int32 col = 0; col < n; ++col) { + a_data[row * lda + col] = conj(a_data[row * lda + col]); + b_data[row * lda + col] = conj(b_data[row * lda + col]); + } + } + + int32 
fp_buffer_num_elems; + int32 rfp_buffer_num_elems; + int32 int_buffer_num_elems; + ValueType* work = reinterpret_cast(workspace->get_data()); + array tmp_iwork(exec, 1); + array> tmp_rwork(exec, 1); + lapack::hegvd_buffersizes( + &itype, &job, &uplo, &n, a->get_values(), &lda, b->get_values(), + &ldb, e_vals->get_data(), work, &fp_buffer_num_elems, + tmp_rwork.get_data(), &rfp_buffer_num_elems, tmp_iwork.get_data(), + &int_buffer_num_elems); + size_type total_bytes = + sizeof(ValueType) * fp_buffer_num_elems + + sizeof(remove_complex) * rfp_buffer_num_elems + + sizeof(int32) * int_buffer_num_elems; + if (workspace->get_size() < total_bytes) { + workspace->resize_and_reset(total_bytes); + } + work = reinterpret_cast(workspace->get_data()); + // Set rwork and iwork pointers inside the workspace array + remove_complex* rwork = + reinterpret_cast*>( + workspace->get_data() + + sizeof(ValueType) * fp_buffer_num_elems); + int32* iwork = reinterpret_cast( + workspace->get_data() + sizeof(ValueType) * fp_buffer_num_elems + + sizeof(remove_complex) * rfp_buffer_num_elems); + lapack::hegvd(&itype, &job, &uplo, &n, a->get_values(), &lda, + b->get_values(), &ldb, e_vals->get_data(), work, + &fp_buffer_num_elems, rwork, &rfp_buffer_num_elems, iwork, + &int_buffer_num_elems); + } +} + +GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE( + GKO_DECLARE_LOBPCG_SYMM_GENERALIZED_EIG_KERNEL); + + +template +void b_orthonormalize(std::shared_ptr exec, + matrix::Dense* a, LinOp* b, + array* workspace) // (unused; for [cu/hip]SOLVER) +{ + constexpr auto max = std::numeric_limits::max(); + if (a->get_size()[0] > max) { + throw OverflowError(__FILE__, __LINE__, + name_demangling::get_type_name(typeid(int32))); + } + if (a->get_stride() > max) { + throw OverflowError(__FILE__, __LINE__, + name_demangling::get_type_name(typeid(int32))); + } + const int32 lda = static_cast(a->get_stride()); + + // Compute A^H * B * A + auto b_a = matrix::Dense::create( + exec, gko::dim<2>{b->get_size()[0], 
a->get_size()[1]}); + b->apply(a, b_a); + auto aH_b_a = matrix::Dense::create( + exec, gko::dim<2>{a->get_size()[1], a->get_size()[1]}); + gko::as>(a->conj_transpose())->apply(b_a, aH_b_a); + + const int32 n = static_cast(aH_b_a->get_size()[0]); + const int32 ldaH_b_a = static_cast(aH_b_a->get_stride()); + + // Cholesky + // Since LAPACK expects column-major, on exit, we will have + // L such that LL^H = A^T, i.e., the complex conjugate of the + // lower Cholesky factor, in column-major order. + const char uplo = LAPACK_FILL_LOWER; + lapack::potrf(&uplo, &n, aH_b_a->get_values(), &ldaH_b_a); + + // Invert the Cholesky factor: on exit, have conj(L)^{-1} + const char diag = LAPACK_DIAG_NONUNIT; + lapack::trtri(&uplo, &diag, &n, aH_b_a->get_values(), &ldaH_b_a); + + // A = A * (L^{-1})^H + // Since A is seen by BLAS as column-major, the operation + // A^T_{ij} = M_{ik} A^T_{kj}, with M = conj(L)^{-1} (col-major), + // is equivalent to A_{ji} = A_{jk} M^T_{ki} = + // A = A * L^{-H} (in row-major order). 
+ const char side = BLAS_SIDE_LEFT; + const ValueType alpha = gko::one(); + const char transa = BLAS_OP_N; + const int32 m = static_cast(a->get_size()[0]); + // m & n swapped because of interpreting as col-major + blas::trmm(&side, &uplo, &transa, &diag, &n, &m, &alpha, + aH_b_a->get_const_values(), &ldaH_b_a, a->get_values(), &lda); +} + +GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(GKO_DECLARE_LOBPCG_B_ORTHONORMALIZE_KERNEL); + + +} // namespace lobpcg +} // namespace reference +} // namespace kernels +} // namespace gko + +#endif // GKO_HAVE_LAPACK diff --git a/reference/test/CMakeLists.txt b/reference/test/CMakeLists.txt index b7cb46408b5..9cc4be62b2c 100644 --- a/reference/test/CMakeLists.txt +++ b/reference/test/CMakeLists.txt @@ -3,6 +3,9 @@ include(${PROJECT_SOURCE_DIR}/cmake/create_test.cmake) add_subdirectory(base) add_subdirectory(components) add_subdirectory(distributed) +if(GINKGO_BUILD_LAPACK) + add_subdirectory(eigensolver) +endif() add_subdirectory(factorization) add_subdirectory(log) add_subdirectory(matrix) diff --git a/reference/test/base/CMakeLists.txt b/reference/test/base/CMakeLists.txt index fd3afd45ca8..2370f33ed6f 100644 --- a/reference/test/base/CMakeLists.txt +++ b/reference/test/base/CMakeLists.txt @@ -2,6 +2,7 @@ ginkgo_create_test(array EXECUTABLE_NAME array_test) # array collides with C++ s ginkgo_create_test(batch_multi_vector_kernels) ginkgo_create_test(combination) ginkgo_create_test(composition) +ginkgo_create_test(exception_helpers) ginkgo_create_test(index_set) ginkgo_create_test(perturbation) ginkgo_create_test(utils) diff --git a/reference/test/base/exception_helpers.cpp b/reference/test/base/exception_helpers.cpp new file mode 100644 index 00000000000..5b5260d61cc --- /dev/null +++ b/reference/test/base/exception_helpers.cpp @@ -0,0 +1,28 @@ +// SPDX-FileCopyrightText: 2025 The Ginkgo authors +// +// SPDX-License-Identifier: BSD-3-Clause + +#include + +#include +#include + + +namespace { + + +TEST(AssertNoLapackErrors, ThrowsOnError) 
+{ + gko::int32 info; + ASSERT_THROW(GKO_ASSERT_NO_LAPACK_ERRORS(info = 1, info), gko::LapackError); +} + + +TEST(AssertNoLapackErrors, DoesNotThrowOnSuccess) +{ + gko::int32 info; + ASSERT_NO_THROW(GKO_ASSERT_NO_LAPACK_ERRORS(info = 0, info)); +} + + +} // namespace diff --git a/reference/test/eigensolver/CMakeLists.txt b/reference/test/eigensolver/CMakeLists.txt new file mode 100644 index 00000000000..93f92da1a58 --- /dev/null +++ b/reference/test/eigensolver/CMakeLists.txt @@ -0,0 +1 @@ +ginkgo_create_test(lobpcg_kernels) diff --git a/reference/test/eigensolver/lobpcg_kernels.cpp b/reference/test/eigensolver/lobpcg_kernels.cpp new file mode 100644 index 00000000000..4820e1a3039 --- /dev/null +++ b/reference/test/eigensolver/lobpcg_kernels.cpp @@ -0,0 +1,263 @@ +// SPDX-FileCopyrightText: 2025 The Ginkgo authors +// +// SPDX-License-Identifier: BSD-3-Clause + +#include "core/eigensolver/lobpcg_kernels.hpp" + +#include + +#include +#include +#include +#include +#include +#include +#include + +#include "core/test/utils.hpp" +#include "reference/base/lapack_bindings.hpp" + + +template +class Lobpcg : public ::testing::Test { +protected: + using value_type = T; + using rc_value_type = gko::remove_complex; + using cmplx_value_type = gko::to_complex; + using Mtx_r = gko::matrix::Dense; + using Mtx_c = gko::to_complex; + using Mtx = gko::matrix::Dense; + using Ary_r = gko::array; + using Ary = gko::array; + Lobpcg() : exec(gko::ReferenceExecutor::create()) + { + small_a_r = gko::initialize({{13.2, -4.3, -1.8, 0.12}, + {-4.3, 24.2, -1.7, -2.3}, + {-1.8, -1.7, 18.7, -0.8}, + {0.12, -2.3, -0.8, 10.0}}, + exec); + small_b_r = gko::initialize({{2.0, -1.1, 0.3, 0.01}, + {-1.1, 2.5, -0.8, 0.5}, + {0.3, -0.8, 2.3, -0.2}, + {0.01, 0.5, -0.2, 1.9}}, + exec); + small_a_cmplx = gko::initialize( + {{cmplx_value_type{13.2, 0.0}, cmplx_value_type{-4.3, -0.3}, + cmplx_value_type{-1.8, 1.12}, cmplx_value_type{0.12, 0.6}}, + {cmplx_value_type{-4.3, 0.3}, cmplx_value_type{24.2, 0.0}, + 
cmplx_value_type{-1.7, -2.2}, cmplx_value_type{-2.3, -0.55}}, + {cmplx_value_type{-1.8, -1.12}, cmplx_value_type{-1.7, 2.2}, + cmplx_value_type{18.7, 0.0}, cmplx_value_type{-0.8, -1.18}}, + {cmplx_value_type{0.12, -0.6}, cmplx_value_type{-2.3, 0.55}, + cmplx_value_type{-0.8, 1.18}, cmplx_value_type{10.0, 0.0}}}, + exec); + small_b_cmplx = gko::initialize( + {{cmplx_value_type{2.0, 0.0}, cmplx_value_type{-1.1, -0.1}, + cmplx_value_type{0.3, 0.12}, cmplx_value_type{0.01, 0.4}}, + {cmplx_value_type{-1.1, 0.1}, cmplx_value_type{2.5, 0.0}, + cmplx_value_type{-0.8, -0.18}, cmplx_value_type{0.5, -0.097}}, + {cmplx_value_type{0.3, -0.12}, cmplx_value_type{-0.8, 0.18}, + cmplx_value_type{2.3, 0.0}, cmplx_value_type{-0.2, -0.172}}, + {cmplx_value_type{0.01, -0.4}, cmplx_value_type{0.5, 0.097}, + cmplx_value_type{-0.2, 0.172}, cmplx_value_type{1.9, 0.0}}}, + exec); + small_e_vals = Ary_r(exec, 4); + } + + std::shared_ptr exec; + std::shared_ptr small_a_r; + std::shared_ptr small_b_r; + std::shared_ptr small_a_cmplx; + std::shared_ptr small_b_cmplx; + Ary_r small_e_vals; +}; + +TYPED_TEST_SUITE(Lobpcg, gko::test::ValueTypesBase, TypenameNameGenerator); + + +TYPED_TEST(Lobpcg, KernelSymmEig) +{ + using Mtx_r = typename TestFixture::Mtx_r; + using Mtx_c = typename TestFixture::Mtx_c; + using Mtx = typename TestFixture::Mtx; + using Ary_r = typename TestFixture::Ary_r; + using Ary = typename TestFixture::Ary; + using value_type = typename TestFixture::value_type; + + auto work = gko::array(this->exec, 1); + std::shared_ptr small_a; + std::shared_ptr small_a_copy; + + if constexpr (gko::is_complex_s::value) { + small_a_copy = gko::clone(this->small_a_cmplx); + small_a = this->small_a_cmplx; + } else { + small_a_copy = gko::clone(this->small_a_r); + small_a = this->small_a_r; + } + + gko::kernels::reference::lobpcg::symm_eig(this->exec, small_a.get(), + &(this->small_e_vals), &work); + + // On exit, the eigenvectors will be stored in the rows of the A matrix. 
+ // We create submatrices for the vectors to check that A * x = lambda * x + // for each vector. + for (gko::size_type i = 0; i < this->small_e_vals.get_size(); i++) { + auto evec = gko::share(Mtx::create( + this->exec, gko::dim<2>{this->small_e_vals.get_size(), 1}, + Ary::view( + this->exec, this->small_e_vals.get_size(), + small_a->get_values() + i * this->small_e_vals.get_size()), + 1)); + + auto lambda_r = gko::share(Mtx_r::create( + this->exec, gko::dim<2>{1, 1}, + Ary_r::view(this->exec, 1, this->small_e_vals.get_data() + i), 1)); + std::shared_ptr lambda; + if constexpr (gko::is_complex_s::value) { + lambda = lambda_r->make_complex(); + } else { + lambda = lambda_r; + } + // A*x = lambda * x; + auto a_x = Mtx::create(this->exec, + gko::dim<2>{this->small_e_vals.get_size(), 1}); + // a_x = A * x + small_a_copy->apply(evec, a_x); + // scale x by lambda + evec->scale(lambda); + + GKO_ASSERT_MTX_NEAR(a_x, evec, r::value); + } +} + + +TYPED_TEST(Lobpcg, KernelSymmGeneralizedEig) +{ + using Mtx_r = typename TestFixture::Mtx_r; + using Mtx_c = typename TestFixture::Mtx_c; + using Mtx = typename TestFixture::Mtx; + using Ary_r = typename TestFixture::Ary_r; + using Ary = typename TestFixture::Ary; + using value_type = typename TestFixture::value_type; + + auto work = gko::array(this->exec, 1); + std::shared_ptr small_a; + std::shared_ptr small_b; + // Both A and B will be overwritten by the LAPACK call; store copies for + // the final check. 
+ std::shared_ptr small_a_copy; + std::shared_ptr small_b_copy; + + if constexpr (gko::is_complex_s::value) { + small_a_copy = gko::clone(this->small_a_cmplx); + small_b_copy = gko::clone(this->small_b_cmplx); + small_a = this->small_a_cmplx; + small_b = this->small_b_cmplx; + } else { + small_a_copy = gko::clone(this->small_a_r); + small_b_copy = gko::clone(this->small_b_r); + small_a = this->small_a_r; + small_b = this->small_b_r; + } + + gko::kernels::reference::lobpcg::symm_generalized_eig( + this->exec, small_a.get(), small_b.get(), &(this->small_e_vals), &work); + + // On exit, the eigenvectors will be stored in the rows of the A matrix. + // We create submatrices for the vectors to check that + // A * x = lambda * B * x for each vector. + for (gko::size_type i = 0; i < this->small_e_vals.get_size(); i++) { + auto evec = gko::share(Mtx::create( + this->exec, gko::dim<2>{this->small_e_vals.get_size(), 1}, + Ary::view( + this->exec, this->small_e_vals.get_size(), + small_a->get_values() + i * this->small_e_vals.get_size()), + 1)); + + auto lambda_r = gko::share(Mtx_r::create( + this->exec, gko::dim<2>{1, 1}, + Ary_r::view(this->exec, 1, this->small_e_vals.get_data() + i), 1)); + std::shared_ptr lambda; + if constexpr (gko::is_complex_s::value) { + lambda = lambda_r->make_complex(); + } else { + lambda = lambda_r; + } + // A * x = lambda * B * x; + auto a_x = Mtx::create(this->exec, + gko::dim<2>{this->small_e_vals.get_size(), 1}); + auto lambda_b_x = Mtx::create( + this->exec, gko::dim<2>{this->small_e_vals.get_size(), 1}); + lambda_b_x->fill(gko::zero()); + auto one = gko::initialize({gko::one()}, this->exec); + // a_x = A * x + small_a_copy->apply(evec, a_x); + // lambda_b_x = lambda * B * x + small_b_copy->apply(lambda, evec, one, lambda_b_x); + + GKO_ASSERT_MTX_NEAR(a_x, lambda_b_x, r::value); + } +} + + +TYPED_TEST(Lobpcg, KernelBOrthonormalize) +{ + using Mtx = typename TestFixture::Mtx; + using value_type = typename TestFixture::value_type; + using 
CsrMtx = gko::matrix::Csr; + + auto work = gko::array(this->exec, 1); + std::shared_ptr small_a; + // Test with two kinds of B operator: Identity, and a Csr matrix + auto id = gko::matrix::Identity::create( + this->exec, this->small_a_r->get_size()[0]); + std::shared_ptr small_b_csr = + gko::share(CsrMtx::create(this->exec, this->small_a_r->get_size())); + // Create rectangular submatrix for testing + if constexpr (gko::is_complex_s::value) { + small_a = this->small_a_cmplx->create_submatrix( + gko::span{0, this->small_a_cmplx->get_size()[0]}, + gko::span{0, this->small_a_cmplx->get_size()[0] - 1}); + this->small_b_cmplx->convert_to(small_b_csr); + } else { + small_a = this->small_a_r->create_submatrix( + gko::span{0, this->small_a_r->get_size()[0]}, + gko::span{0, this->small_a_cmplx->get_size()[0] - 1}); + this->small_b_r->convert_to(small_b_csr); + } + auto small_a_copy = gko::clone(small_a); + + // First, test with Identity operator as B + gko::kernels::reference::lobpcg::b_orthonormalize(this->exec, small_a.get(), + id.get(), &work); + // On exit, small_a should now be orthonormalized, + // i.e., small_a^H * small_a = I. + auto aH_a = Mtx::create(this->exec, gko::dim<2>{small_a->get_size()[1], + small_a->get_size()[1]}); + auto after_ortho_H = gko::as(small_a->conj_transpose()); + after_ortho_H->apply(small_a, aH_a); + // Check if applying aH_a to the orthonormalized a^H leaves it unchanged + auto result = Mtx::create(this->exec, after_ortho_H->get_size()); + aH_a->apply(after_ortho_H, result); + GKO_ASSERT_MTX_NEAR(result, after_ortho_H, r::value); + + // Now, test with Csr matrix operator as B + gko::kernels::reference::lobpcg::b_orthonormalize( + this->exec, small_a_copy.get(), small_b_csr.get(), &work); + // On exit, small_a_copy should now be B-orthonormalized, + // i.e., small_a_copy^H * small_b_csr * small_a_copy = I. 
+    auto b_a = Mtx::create(
+        this->exec,
+        gko::dim<2>{small_b_csr->get_size()[0], small_a_copy->get_size()[1]});
+    small_b_csr->apply(small_a_copy, b_a);  // b_a = B * a
+    auto aH_b_a = Mtx::create(
+        this->exec,
+        gko::dim<2>{small_a_copy->get_size()[1], small_a_copy->get_size()[1]});
+    auto after_b_ortho_H = gko::as(small_a_copy->conj_transpose());
+    after_b_ortho_H->apply(b_a, aH_b_a);  // aH_b_a = a^H * B * a
+    // Check if applying aH_b_a to the B-orthonormalized a^H leaves it unchanged
+    result = Mtx::create(this->exec, after_b_ortho_H->get_size());
+    aH_b_a->apply(after_b_ortho_H, result);
+    GKO_ASSERT_MTX_NEAR(result, after_b_ortho_H, r::value);
+}
diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt
index 6e72dbdf0aa..df9c9308dec 100644
--- a/test/CMakeLists.txt
+++ b/test/CMakeLists.txt
@@ -3,6 +3,9 @@ include(${PROJECT_SOURCE_DIR}/cmake/create_test.cmake)
 add_subdirectory(base)
 add_subdirectory(components)
 add_subdirectory(distributed)
+if(GINKGO_BUILD_LAPACK)
+    add_subdirectory(eigensolver)
+endif()
 add_subdirectory(factorization)
 add_subdirectory(log)
 add_subdirectory(matrix)
diff --git a/test/eigensolver/CMakeLists.txt b/test/eigensolver/CMakeLists.txt
new file mode 100644
index 00000000000..4244776380e
--- /dev/null
+++ b/test/eigensolver/CMakeLists.txt
@@ -0,0 +1 @@
+ginkgo_create_common_test(lobpcg_kernels DISABLE_EXECUTORS omp dpcpp)
diff --git a/test/eigensolver/lobpcg_kernels.cpp b/test/eigensolver/lobpcg_kernels.cpp
new file mode 100644
index 00000000000..2da37d8faa4
--- /dev/null
+++ b/test/eigensolver/lobpcg_kernels.cpp
@@ -0,0 +1,274 @@
+// SPDX-FileCopyrightText: 2025 The Ginkgo authors
+//
+// SPDX-License-Identifier: BSD-3-Clause
+
+#include "core/eigensolver/lobpcg_kernels.hpp"
+
+#include
+
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+
+#include "core/test/utils.hpp"
+#include "test/utils/common_fixture.hpp"
+
+
+template
+class Lobpcg : public CommonTestFixture {  // shared data for the tests below
+protected:
+    using value_type = ValueType;
+    using rc_value_type = gko::remove_complex;
+    using cmplx_value_type = gko::to_complex;
+    using Mtx_r = gko::matrix::Dense;
+    using Mtx_c = gko::to_complex;
+    using Mtx = gko::matrix::Dense;
+    using Ary_r = gko::array;
+    using Ary = gko::array;
+    Lobpcg()  // builds the 4x4 inputs on ref, then clones them to exec
+    {
+        small_a_r = gko::initialize({{13.2, -4.3, -1.8, 0.12},
+                                     {-4.3, 24.2, -1.7, -2.3},
+                                     {-1.8, -1.7, 18.7, -0.8},
+                                     {0.12, -2.3, -0.8, 10.0}},
+                                    ref);
+        small_b_r = gko::initialize({{2.0, -1.1, 0.3, 0.01},
+                                     {-1.1, 2.5, -0.8, 0.5},
+                                     {0.3, -0.8, 2.3, -0.2},
+                                     {0.01, 0.5, -0.2, 1.9}},
+                                    ref);
+        small_a_cmplx = gko::initialize(
+            {{cmplx_value_type{13.2, 0.0}, cmplx_value_type{-4.3, -0.3},
+              cmplx_value_type{-1.8, 1.12}, cmplx_value_type{0.12, 0.6}},
+             {cmplx_value_type{-4.3, 0.3}, cmplx_value_type{24.2, 0.0},
+              cmplx_value_type{-1.7, -2.2}, cmplx_value_type{-2.3, -0.55}},
+             {cmplx_value_type{-1.8, -1.12}, cmplx_value_type{-1.7, 2.2},
+              cmplx_value_type{18.7, 0.0}, cmplx_value_type{-0.8, -1.18}},
+             {cmplx_value_type{0.12, -0.6}, cmplx_value_type{-2.3, 0.55},
+              cmplx_value_type{-0.8, 1.18}, cmplx_value_type{10.0, 0.0}}},
+            ref);
+        small_b_cmplx = gko::initialize(
+            {{cmplx_value_type{2.0, 0.0}, cmplx_value_type{-1.1, -0.1},
+              cmplx_value_type{0.3, 0.12}, cmplx_value_type{0.01, 0.4}},
+             {cmplx_value_type{-1.1, 0.1}, cmplx_value_type{2.5, 0.0},
+              cmplx_value_type{-0.8, -0.18}, cmplx_value_type{0.5, -0.097}},
+             {cmplx_value_type{0.3, -0.12}, cmplx_value_type{-0.8, 0.18},
+              cmplx_value_type{2.3, 0.0}, cmplx_value_type{-0.2, -0.172}},
+             {cmplx_value_type{0.01, -0.4}, cmplx_value_type{0.5, 0.097},
+              cmplx_value_type{-0.2, 0.172}, cmplx_value_type{1.9, 0.0}}},
+            ref);
+        small_e_vals = Ary_r(ref, 4);  // one slot per row of the 4x4 inputs
+
+        d_small_a_r = gko::clone(exec, small_a_r);
+        d_small_b_r = gko::clone(exec, small_b_r);
+        d_small_a_cmplx = gko::clone(exec, small_a_cmplx);
+        d_small_b_cmplx = gko::clone(exec, small_b_cmplx);
+        d_small_e_vals = Ary_r(exec, 4);
+    }
+
+    std::shared_ptr small_a_r;      // real symmetric A, on ref
+    std::shared_ptr small_b_r;      // real symmetric B, on ref
+    std::shared_ptr small_a_cmplx;  // complex Hermitian A, on ref
+    std::shared_ptr small_b_cmplx;  // complex Hermitian B, on ref
+    std::shared_ptr small_a;  // set per test to the real or complex A
+    std::shared_ptr small_b;  // set per test to the real or complex B
+    Ary_r small_e_vals;       // eigenvalue output of the reference kernels
+
+    std::shared_ptr d_small_a_r;      // clones of the above on exec
+    std::shared_ptr d_small_b_r;
+    std::shared_ptr d_small_a_cmplx;
+    std::shared_ptr d_small_b_cmplx;
+    std::shared_ptr d_small_a;  // set per test, like small_a
+    std::shared_ptr d_small_b;  // set per test, like small_b
+    Ary_r d_small_e_vals;       // eigenvalue output of the device kernels
+};
+
+TYPED_TEST_SUITE(Lobpcg, gko::test::ValueTypesBase, TypenameNameGenerator);
+
+
+TYPED_TEST(Lobpcg, KernelSymmEigIsEquivalentToRef)
+{
+    using Mtx_r = typename TestFixture::Mtx_r;
+    using Mtx = typename TestFixture::Mtx;
+    using Ary_r = typename TestFixture::Ary_r;
+    using Ary = typename TestFixture::Ary;
+    using value_type = typename TestFixture::value_type;
+
+    auto refwork = gko::array(this->ref, 1);  // workspace for the ref kernel
+    auto d_work = gko::array(this->exec, 1);  // workspace for the exec kernel
+
+    if constexpr (gko::is_complex_s::value) {
+        this->small_a = this->small_a_cmplx;
+        this->d_small_a = this->d_small_a_cmplx;
+    } else {
+        this->small_a = this->small_a_r;
+        this->d_small_a = this->d_small_a_r;
+    }
+
+    gko::kernels::reference::lobpcg::symm_eig(this->ref, this->small_a.get(),
+                                              &(this->small_e_vals), &refwork);
+    gko::kernels::GKO_DEVICE_NAMESPACE::lobpcg::symm_eig(
+        this->exec, this->d_small_a.get(), &(this->d_small_e_vals), &d_work);
+
+    const double tol = 10 * r::value;  // 10x the per-type tolerance
+    GKO_ASSERT_ARRAY_NEAR(this->d_small_e_vals, this->small_e_vals, tol);
+
+    // The eigenvectors may differ by a factor of -1 between libraries.
+    // Check for this and adjust before comparing output matrices.
+ for (gko::size_type i = 0; i < this->small_e_vals.get_size(); i++) { + auto evec = gko::share(Mtx::create( + this->ref, gko::dim<2>{this->small_e_vals.get_size(), 1}, + Ary::view(this->ref, this->small_e_vals.get_size(), + this->small_a->get_values() + + i * this->small_e_vals.get_size()), + 1)); + value_type evec_first_entry = evec->at(0, 0); + + auto d_evec_start = gko::share( + Mtx::create(this->exec, gko::dim<2>{1, 1}, + Ary::view(this->exec, 1, + this->d_small_a->get_values() + + i * this->d_small_e_vals.get_size()), + 1)); + value_type d_evec_first_entry; + this->ref->copy_from(this->exec, 1, d_evec_start->get_values(), + &d_evec_first_entry); + + auto neg_one = + gko::initialize({-gko::one()}, this->exec); + if (gko::abs(evec_first_entry / d_evec_first_entry + + gko::one()) < tol) { + evec->scale(neg_one); + } + } + GKO_ASSERT_MTX_NEAR(this->d_small_a, this->small_a, tol); +} + + +TYPED_TEST(Lobpcg, KernelSymmGeneralizedEigIsEquivalentToRef) +{ + using Mtx_r = typename TestFixture::Mtx_r; + using Mtx = typename TestFixture::Mtx; + using Ary_r = typename TestFixture::Ary_r; + using Ary = typename TestFixture::Ary; + using value_type = typename TestFixture::value_type; + + auto refwork = gko::array(this->ref, 1); + auto d_work = gko::array(this->exec, 1); + + if constexpr (gko::is_complex_s::value) { + this->small_a = this->small_a_cmplx; + this->small_b = this->small_b_cmplx; + this->d_small_a = this->d_small_a_cmplx; + this->d_small_b = this->d_small_b_cmplx; + } else { + this->small_a = this->small_a_r; + this->small_b = this->small_b_r; + this->d_small_a = this->d_small_a_r; + this->d_small_b = this->d_small_b_r; + } + + gko::kernels::reference::lobpcg::symm_generalized_eig( + this->ref, this->small_a.get(), this->small_b.get(), + &(this->small_e_vals), &refwork); + gko::kernels::GKO_DEVICE_NAMESPACE::lobpcg::symm_generalized_eig( + this->exec, this->d_small_a.get(), this->d_small_b.get(), + &(this->d_small_e_vals), &d_work); + + const double tol = 10 * 
r::value; + GKO_ASSERT_ARRAY_NEAR(this->d_small_e_vals, this->small_e_vals, tol); + + // The eigenvectors may differ by a factor of -1 between libraries. + // Check for this and adjust before comparing output matrices. + for (gko::size_type i = 0; i < this->small_e_vals.get_size(); i++) { + auto evec = gko::share(Mtx::create( + this->ref, gko::dim<2>{this->small_e_vals.get_size(), 1}, + Ary::view(this->ref, this->small_e_vals.get_size(), + this->small_a->get_values() + + i * this->small_e_vals.get_size()), + 1)); + value_type evec_first_entry = evec->at(0, 0); + + auto d_evec_start = gko::share( + Mtx::create(this->exec, gko::dim<2>{1, 1}, + Ary::view(this->exec, 1, + this->d_small_a->get_values() + + i * this->d_small_e_vals.get_size()), + 1)); + value_type d_evec_first_entry; + this->ref->copy_from(this->exec, 1, d_evec_start->get_values(), + &d_evec_first_entry); + + auto neg_one = + gko::initialize({-gko::one()}, this->exec); + if (gko::abs(evec_first_entry / d_evec_first_entry + + gko::one()) < tol) { + evec->scale(neg_one); + } + } + GKO_ASSERT_MTX_NEAR(this->d_small_a, this->small_a, tol); +} + + +TYPED_TEST(Lobpcg, KernelBOrthonormalizeIsEquivalentToRef) +{ + using Mtx = typename TestFixture::Mtx; + using value_type = typename TestFixture::value_type; + using CsrMtx = gko::matrix::Csr; + + auto refwork = gko::array(this->ref, 1); + auto d_work = gko::array(this->exec, 1); + + std::shared_ptr small_a; + std::shared_ptr d_small_a; + + // Test with two kinds of B operator: Identity, and a Csr matrix + auto id = gko::matrix::Identity::create( + this->ref, this->small_a_r->get_size()[0]); + auto d_id = gko::matrix::Identity::create( + this->exec, this->small_a_r->get_size()[0]); + std::shared_ptr small_b_csr = + gko::share(CsrMtx::create(this->ref, this->small_a_r->get_size())); + std::shared_ptr d_small_b_csr = + gko::share(CsrMtx::create(this->exec, this->small_a_r->get_size())); + // Create rectangular submatrix for testing + if constexpr 
(gko::is_complex_s::value) { + small_a = this->small_a_cmplx->create_submatrix( + gko::span{0, this->small_a_cmplx->get_size()[0]}, + gko::span{0, this->small_a_cmplx->get_size()[0] - 1}); + this->small_b_cmplx->convert_to(small_b_csr); + d_small_a = this->d_small_a_cmplx->create_submatrix( + gko::span{0, this->d_small_a_cmplx->get_size()[0]}, + gko::span{0, this->small_a_cmplx->get_size()[0] - 1}); + this->d_small_b_cmplx->convert_to(d_small_b_csr); + } else { + small_a = this->small_a_r->create_submatrix( + gko::span{0, this->small_a_r->get_size()[0]}, + gko::span{0, this->small_a_cmplx->get_size()[0] - 1}); + this->small_b_r->convert_to(small_b_csr); + d_small_a = this->d_small_a_r->create_submatrix( + gko::span{0, this->d_small_a_r->get_size()[0]}, + gko::span{0, this->small_a_cmplx->get_size()[0] - 1}); + this->d_small_b_r->convert_to(d_small_b_csr); + } + auto small_a_copy = gko::clone(small_a); + auto d_small_a_copy = gko::clone(d_small_a); + + // First, test with Identity operator as B + gko::kernels::reference::lobpcg::b_orthonormalize(this->ref, small_a.get(), + id.get(), &refwork); + gko::kernels::GKO_DEVICE_NAMESPACE::lobpcg::b_orthonormalize( + this->exec, d_small_a.get(), d_id.get(), &d_work); + GKO_ASSERT_MTX_NEAR(small_a, d_small_a, r::value); + + // Now, test with Csr matrix operator as B + gko::kernels::reference::lobpcg::b_orthonormalize( + this->ref, small_a_copy.get(), small_b_csr.get(), &refwork); + gko::kernels::GKO_DEVICE_NAMESPACE::lobpcg::b_orthonormalize( + this->exec, d_small_a_copy.get(), d_small_b_csr.get(), &d_work); + GKO_ASSERT_MTX_NEAR(small_a_copy, d_small_a_copy, r::value); +}