Skip to content

Commit 2f0babf

Browse files
committed
Merge branch 'ChASE-v1.3-newIO' into 'master'
Update the estimation of bounds for QR See merge request SLai/ChASE!28
2 parents 670dad3 + ee7b8e3 commit 2f0babf

File tree

8 files changed

+462
-16
lines changed

8 files changed

+462
-16
lines changed

ChASE-MPI/blas_fortran.hpp

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -352,6 +352,28 @@ extern "C"
352352
const BlasInt* lda, const dcomplex* b,
353353
const BlasInt* ldb);
354354

355+
/**
 * Prototype for the Fortran single-precision real SVD driver (SGESVD).
 *
 * Every argument is handed over by address. Per the LAPACK documentation,
 * jobu/jobvt select how much of U and V^T is computed, A is the m-by-n input
 * matrix (overwritten), S receives the singular values, and a call with
 * *lwork == -1 is a workspace query that stores the recommended size in
 * work[0].
 *
 * NOTE(review): the reference LAPACK SGESVD argument list ends with
 * WORK, LWORK, INFO and has no RWORK; RWORK exists only in CGESVD/ZGESVD.
 * With this 15-argument prototype, a library that follows the reference
 * interface would receive the `rwork` pointer in its INFO slot. Confirm
 * against the LAPACK in use. Removing the parameter requires changing every
 * call site at the same time.
 */
void FC_GLOBAL(sgesvd, SGESVD)(const char* jobu, const char* jobvt,
                               const BlasInt* m, const BlasInt* n,
                               float* A, const BlasInt* lda,
                               float* S,
                               float* U, const BlasInt* ldu,
                               float* Vt, const BlasInt* ldvt,
                               float* work, const BlasInt* lwork,
                               float* rwork, BlasInt* info);
361+
/**
 * Prototype for the Fortran double-precision real SVD driver (DGESVD).
 *
 * Every argument is handed over by address. Per the LAPACK documentation,
 * jobu/jobvt select how much of U and V^T is computed, A is the m-by-n input
 * matrix (overwritten), S receives the singular values, and a call with
 * *lwork == -1 is a workspace query that stores the recommended size in
 * work[0].
 *
 * NOTE(review): the reference LAPACK DGESVD argument list ends with
 * WORK, LWORK, INFO and has no RWORK; RWORK exists only in CGESVD/ZGESVD.
 * With this 15-argument prototype, a library that follows the reference
 * interface would receive the `rwork` pointer in its INFO slot. Confirm
 * against the LAPACK in use. Removing the parameter requires changing every
 * call site at the same time.
 */
void FC_GLOBAL(dgesvd, DGESVD)(const char* jobu, const char* jobvt,
                               const BlasInt* m, const BlasInt* n,
                               double* A, const BlasInt* lda,
                               double* S,
                               double* U, const BlasInt* ldu,
                               double* Vt, const BlasInt* ldvt,
                               double* work, const BlasInt* lwork,
                               double* rwork, BlasInt* info);
367+
/**
 * Prototype for the Fortran single-precision complex SVD driver (CGESVD).
 *
 * Every argument is handed over by address. A, U, Vt and work hold complex
 * values, while the singular values S and the scratch array rwork are real.
 * Per the LAPACK documentation, a call with *lwork == -1 is a workspace
 * query that stores the recommended size in work[0].
 */
void FC_GLOBAL(cgesvd, CGESVD)(const char* jobu, const char* jobvt,
                               const BlasInt* m, const BlasInt* n,
                               scomplex* A, const BlasInt* lda,
                               float* S,
                               scomplex* U, const BlasInt* ldu,
                               scomplex* Vt, const BlasInt* ldvt,
                               scomplex* work, const BlasInt* lwork,
                               float* rwork, BlasInt* info);
372+
/**
 * Prototype for the Fortran double-precision complex SVD driver (ZGESVD).
 *
 * Every argument is handed over by address. A, U, Vt and work hold complex
 * values, while the singular values S and the scratch array rwork are real.
 * Per the LAPACK documentation, a call with *lwork == -1 is a workspace
 * query that stores the recommended size in work[0].
 */
void FC_GLOBAL(zgesvd, ZGESVD)(const char* jobu, const char* jobvt,
                               const BlasInt* m, const BlasInt* n,
                               dcomplex* A, const BlasInt* lda,
                               double* S,
                               dcomplex* U, const BlasInt* ldu,
                               dcomplex* Vt, const BlasInt* ldvt,
                               dcomplex* work, const BlasInt* lwork,
                               double* rwork, BlasInt* info);
355377
} // extern "C"
356378
} // namespace mpi
357379
} // namespace chase

ChASE-MPI/blas_templates.hpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,10 @@ void t_trsm(const char side, const char uplo, const char trans, const char diag,
108108
const T* a, const std::size_t lda, const T* b,
109109
const std::size_t ldb);
110110

111+
template<typename T>
112+
void t_gesvd(const char jobu, const char jobvt, const std::size_t m, const std::size_t n,
113+
T *A, const std::size_t lda, Base<T> *S, T *U, const std::size_t ldu, T *Vt,
114+
const std::size_t ldvt);
111115
// scalapack
112116
// BLACS
113117
void t_descinit(std::size_t* desc, std::size_t* m, std::size_t* n,

ChASE-MPI/blas_templates.inc

Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1349,6 +1349,123 @@ void t_trsm(const char side, const char uplo, const char trans, const char diag,
13491349
(&side, &uplo, &trans, &diag, &m_, &n_, alpha, a, &lda_, b, &ldb_);
13501350
}
13511351

1352+
template<>
1353+
void t_gesvd(const char jobu, const char jobvt, const std::size_t m, const std::size_t n, float *A,
1354+
const std::size_t lda, float *S, float *U, const std::size_t ldu, float *Vt, const std::size_t ldvt){
1355+
using T = std::remove_reference<decltype((A[0]))>::type;
1356+
BlasInt m_ = m;
1357+
BlasInt n_ = n;
1358+
BlasInt lda_ = lda;
1359+
BlasInt ldu_ = ldu;
1360+
BlasInt ldvt_ = ldvt;
1361+
1362+
T* work;
1363+
Base<T> *rwork = new Base<T>[5 * std::min(m, n)];
1364+
T numwork;
1365+
BlasInt lwork, info;
1366+
1367+
lwork = -1;
1368+
FC_GLOBAL(sgesvd, SGESVD)(&jobu, &jobvt, &m_, &n_, A, &lda_, S, U, &ldu_, Vt, &ldvt_, &numwork, &lwork, rwork, &info);
1369+
assert(info == 0);
1370+
1371+
1372+
lwork = static_cast<std::size_t>((numwork));
1373+
auto ptr = std::unique_ptr<T[]>{new T[lwork]};
1374+
work = ptr.get();
1375+
1376+
FC_GLOBAL(sgesvd, SGESVD)(&jobu, &jobvt, &m_, &n_, A, &lda_, S, U, &ldu_, Vt, &ldvt_, work, &lwork, rwork, &info);
1377+
assert(info == 0);
1378+
}
1379+
1380+
template<>
1381+
void t_gesvd(const char jobu, const char jobvt, const std::size_t m, const std::size_t n, double *A,
1382+
const std::size_t lda, double *S, double *U, const std::size_t ldu, double *Vt, const std::size_t ldvt){
1383+
using T = std::remove_reference<decltype((A[0]))>::type;
1384+
BlasInt m_ = m;
1385+
BlasInt n_ = n;
1386+
BlasInt lda_ = lda;
1387+
BlasInt ldu_ = ldu;
1388+
BlasInt ldvt_ = ldvt;
1389+
1390+
T* work;
1391+
Base<T> *rwork = new Base<T>[5 * std::min(m, n)];
1392+
T numwork;
1393+
BlasInt lwork, info;
1394+
1395+
lwork = -1;
1396+
FC_GLOBAL(dgesvd, DGESVD)(&jobu, &jobvt, &m_, &n_, A, &lda_, S, U, &ldu_, Vt, &ldvt_, &numwork, &lwork, rwork, &info);
1397+
assert(info == 0);
1398+
1399+
1400+
lwork = static_cast<std::size_t>((numwork));
1401+
auto ptr = std::unique_ptr<T[]>{new T[lwork]};
1402+
work = ptr.get();
1403+
1404+
FC_GLOBAL(dgesvd, DGESVD)(&jobu, &jobvt, &m_, &n_, A, &lda_, S, U, &ldu_, Vt, &ldvt_, work, &lwork, rwork, &info);
1405+
assert(info == 0);
1406+
1407+
}
1408+
1409+
template<>
1410+
void t_gesvd(const char jobu, const char jobvt, const std::size_t m, const std::size_t n, std::complex<float> *A,
1411+
const std::size_t lda, float *S, std::complex<float> *U, const std::size_t ldu, std::complex<float> *Vt, const std::size_t ldvt){
1412+
using T = std::remove_reference<decltype((A[0]))>::type;
1413+
BlasInt m_ = m;
1414+
BlasInt n_ = n;
1415+
BlasInt lda_ = lda;
1416+
BlasInt ldu_ = ldu;
1417+
BlasInt ldvt_ = ldvt;
1418+
1419+
T* work;
1420+
Base<T> *rwork = new Base<T>[5 * std::min(m, n)];
1421+
T numwork;
1422+
BlasInt lwork, info;
1423+
1424+
lwork = -1;
1425+
FC_GLOBAL(cgesvd, CGESVD)(&jobu, &jobvt, &m_, &n_, A, &lda_, S, U, &ldu_, Vt, &ldvt_, &numwork, &lwork, rwork, &info);
1426+
assert(info == 0);
1427+
1428+
1429+
lwork = static_cast<std::size_t>(real(numwork));
1430+
auto ptr = std::unique_ptr<T[]>{new T[lwork]};
1431+
work = ptr.get();
1432+
1433+
FC_GLOBAL(cgesvd, CGESVD)(&jobu, &jobvt, &m_, &n_, A, &lda_, S, U, &ldu_, Vt, &ldvt_, work, &lwork, rwork, &info);
1434+
assert(info == 0);
1435+
}
1436+
1437+
template<>
1438+
void t_gesvd(const char jobu, const char jobvt, const std::size_t m, const std::size_t n, std::complex<double> *A,
1439+
const std::size_t lda, double *S, std::complex<double> *U, const std::size_t ldu, std::complex<double> *Vt, const std::size_t ldvt){
1440+
using T = std::remove_reference<decltype((A[0]))>::type;
1441+
BlasInt m_ = m;
1442+
BlasInt n_ = n;
1443+
BlasInt lda_ = lda;
1444+
BlasInt ldu_ = ldu;
1445+
BlasInt ldvt_ = ldvt;
1446+
1447+
T* work;
1448+
Base<T> *rwork = new Base<T>[5 * std::min(m, n)];
1449+
T numwork;
1450+
BlasInt lwork, info;
1451+
1452+
lwork = -1;
1453+
FC_GLOBAL(zgesvd, ZGESVD)(&jobu, &jobvt, &m_, &n_, A, &lda_, S, U, &ldu_, Vt, &ldvt_, &numwork, &lwork, rwork, &info);
1454+
assert(info == 0);
1455+
1456+
1457+
lwork = static_cast<std::size_t>(real(numwork));
1458+
auto ptr = std::unique_ptr<T[]>{new T[lwork]};
1459+
work = ptr.get();
1460+
1461+
FC_GLOBAL(zgesvd, ZGESVD)(&jobu, &jobvt, &m_, &n_, A, &lda_, S, U, &ldu_, Vt, &ldvt_, work, &lwork, rwork, &info);
1462+
assert(info == 0);
1463+
1464+
delete[] rwork;
1465+
1466+
}
1467+
1468+
13521469
#if defined(HAS_SCALAPACK)
13531470
// SCALAPACK
13541471
void t_descinit(std::size_t* desc, std::size_t* m, std::size_t* n,

0 commit comments

Comments
 (0)