Skip to content

Commit 5217d60

Browse files
committed
Add BLAS bindings to Reference
1 parent b0f9764 commit 5217d60

File tree

1 file changed

+109
-0
lines changed

1 file changed

+109
-0
lines changed

reference/base/blas_bindings.hpp

Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,109 @@
1+
// SPDX-FileCopyrightText: 2025 The Ginkgo authors
2+
//
3+
// SPDX-License-Identifier: BSD-3-Clause
4+
5+
#ifndef GKO_REFERENCE_BASE_BLAS_BINDINGS_HPP_
6+
#define GKO_REFERENCE_BASE_BLAS_BINDINGS_HPP_
7+
8+
#include <ginkgo/core/base/types.hpp>
9+
10+
11+
#if GKO_HAVE_LAPACK
12+
13+
14+
extern "C" {
15+
16+
17+
// Triangular matrix-matrix multiplication
18+
void strmm(const char* side, const char* uplo, const char* transa,
19+
const char* diag, const std::int32_t* m, const std::int32_t* n,
20+
const float* alpha, const float* A, const std::int32_t* lda,
21+
float* B, const std::int32_t* ldb);
22+
23+
void dtrmm(const char* side, const char* uplo, const char* transa,
24+
const char* diag, const std::int32_t* m, const std::int32_t* n,
25+
const double* alpha, const double* A, const std::int32_t* lda,
26+
double* B, const std::int32_t* ldb);
27+
28+
void ctrmm(const char* side, const char* uplo, const char* transa,
29+
const char* diag, const std::int32_t* m, const std::int32_t* n,
30+
const std::complex<float>* alpha, const std::complex<float>* A,
31+
const std::int32_t* lda, std::complex<float>* B,
32+
const std::int32_t* ldb);
33+
34+
void ztrmm(const char* side, const char* uplo, const char* transa,
35+
const char* diag, const std::int32_t* m, const std::int32_t* n,
36+
const std::complex<double>* alpha, const std::complex<double>* A,
37+
const std::int32_t* lda, std::complex<double>* B,
38+
const std::int32_t* ldb);
39+
}
40+
41+
42+
namespace gko {
43+
namespace kernels {
44+
namespace reference {
45+
/**
46+
* @brief The BLAS namespace.
47+
*
48+
* @ingroup lapack
49+
*/
50+
namespace blas {
51+
52+
53+
template <typename ValueType>
54+
struct is_supported : std::false_type {};
55+
56+
template <>
57+
struct is_supported<float> : std::true_type {};
58+
59+
template <>
60+
struct is_supported<double> : std::true_type {};
61+
62+
template <>
63+
struct is_supported<std::complex<float>> : std::true_type {};
64+
65+
template <>
66+
struct is_supported<std::complex<double>> : std::true_type {};
67+
68+
69+
#define GKO_BIND_TRMM(ValueType, BlasName) \
70+
inline void trmm(const char* side, const char* uplo, const char* transa, \
71+
const char* diag, const int32* m, const int32* n, \
72+
const ValueType* alpha, const ValueType* a, \
73+
const int32* lda, ValueType* b, const int32* ldb) \
74+
{ \
75+
BlasName(side, uplo, transa, diag, m, n, alpha, a, lda, b, ldb); \
76+
} \
77+
static_assert(true, \
78+
"This assert is used to counter the false positive extra " \
79+
"semi-colon warnings")
80+
81+
GKO_BIND_TRMM(float, strmm);
82+
GKO_BIND_TRMM(double, dtrmm);
83+
GKO_BIND_TRMM(std::complex<float>, ctrmm);
84+
GKO_BIND_TRMM(std::complex<double>, ztrmm);
85+
template <typename ValueType>
86+
inline void trmm(const char* side, const char* uplo, const char* transa,
87+
const char* diag, const int32* m, const int32* n,
88+
const ValueType* alpha, const ValueType* a, const int32* lda,
89+
ValueType* b, const int32* ldb) GKO_NOT_IMPLEMENTED;
90+
91+
#undef GKO_BIND_TRMM
92+
93+
94+
#define BLAS_OP_N 'N'
95+
#define BLAS_OP_T 'T'
96+
#define BLAS_OP_C 'C'
97+
98+
#define BLAS_SIDE_LEFT 'L'
99+
#define BLAS_SIDE_RIGHT 'R'
100+
101+
102+
} // namespace blas
103+
} // namespace reference
104+
} // namespace kernels
105+
} // namespace gko
106+
107+
#endif // GKO_HAVE_LAPACK
108+
109+
#endif // GKO_REFERENCE_BASE_BLAS_BINDINGS_HPP_

0 commit comments

Comments
 (0)