Skip to content

Commit a0bcf72

Browse files
authored
Fix mxr, use estimate for every reference (#7036)
1 parent 3f9e462 commit a0bcf72

File tree

5 files changed

+31
-28
lines changed

5 files changed

+31
-28
lines changed

source/source_hamilt/module_ewald/H_Ewald_pw.cpp

Lines changed: 16 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,18 @@ int H_Ewald_pw::mxr = 200;
1212
H_Ewald_pw::H_Ewald_pw(){};
1313
H_Ewald_pw::~H_Ewald_pw(){};
1414

15+
// Estimate how many r-vectors may be needed for a given cutoff radius.
//
// rmax : cutoff radius, multiplied by the length of each bg triple below.
// bg   : 3x3 matrix whose (e11,e12,e13), (e21,e22,e23) and (e31,e32,e33)
//        triples are each reduced to a single length via dnrm2.
//        NOTE(review): presumably the reciprocal lattice vectors -- confirm.
//
// Returns (2*nm1 + 1) * (2*nm2 + 1) * (2*nm3 + 1), i.e. the number of integer
// triples (i, j, k) with |i| <= nm1, |j| <= nm2, |k| <= nm3.
int H_Ewald_pw::estimate_mxr(const double &rmax, const ModuleBase::Matrix3 &bg)
16+
{
17+
// Scratch buffer reused for each of the three bg triples.
double bg1[3];
18+
bg1[0] = bg.e11; bg1[1] = bg.e12; bg1[2] = bg.e13;
19+
// dnrm2(3, bg1, 1) is presumably the BLAS-style Euclidean norm of the
// 3-element vector with stride 1 -- TODO confirm. The product with rmax is
// offset by 2 and then truncated to int by the cast.
const int nm1 = (int)(dnrm2(3, bg1, 1) * rmax + 2);
20+
bg1[0] = bg.e21; bg1[1] = bg.e22; bg1[2] = bg.e23;
21+
const int nm2 = (int)(dnrm2(3, bg1, 1) * rmax + 2);
22+
bg1[0] = bg.e31; bg1[1] = bg.e32; bg1[2] = bg.e33;
23+
const int nm3 = (int)(dnrm2(3, bg1, 1) * rmax + 2);
24+
// Each axis contributes the 2*nm + 1 integers in [-nm, nm]; the product is
// computed in plain int arithmetic.
return (2 * nm1 + 1) * (2 * nm2 + 1) * (2 * nm3 + 1);
25+
}
26+
1527
double H_Ewald_pw::compute_ewald(const UnitCell& cell,
1628
const ModulePW::PW_Basis* rho_basis,
1729
const ModuleBase::ComplexMatrix& strucFac)
@@ -150,16 +162,7 @@ double H_Ewald_pw::compute_ewald(const UnitCell& cell,
150162
// Compute rmax and dynamically determine mxr (maximum number of r-vectors)
151163
// to avoid buffer overflow for very small unit cells or high cutoff energies.
152164
rmax = 4.0 / sqrt(alpha) / cell.lat0;
153-
{
154-
double bg1[3];
155-
bg1[0] = cell.G.e11; bg1[1] = cell.G.e12; bg1[2] = cell.G.e13;
156-
int nm1 = (int)(dnrm2(3, bg1, 1) * rmax + 2);
157-
bg1[0] = cell.G.e21; bg1[1] = cell.G.e22; bg1[2] = cell.G.e23;
158-
int nm2 = (int)(dnrm2(3, bg1, 1) * rmax + 2);
159-
bg1[0] = cell.G.e31; bg1[1] = cell.G.e32; bg1[2] = cell.G.e33;
160-
int nm3 = (int)(dnrm2(3, bg1, 1) * rmax + 2);
161-
mxr = (2 * nm1 + 1) * (2 * nm2 + 1) * (2 * nm3 + 1);
162-
}
165+
mxr = H_Ewald_pw::estimate_mxr(rmax, cell.G);
163166

164167
if(PARAM.inp.test_energy)
165168
{
@@ -205,7 +208,7 @@ double H_Ewald_pw::compute_ewald(const UnitCell& cell,
205208
// calculate tau[na1]-tau[na2]
206209
dtau = cell.atoms[it1].tau[ia1] - cell.atoms[it2].tau[ia2];
207210
// generates nearest-neighbors shells
208-
H_Ewald_pw::rgen(dtau, rmax, irr, cell.latvec, cell.G, r, r2, nrm);
211+
H_Ewald_pw::rgen(dtau, rmax, irr, cell.latvec, cell.G, r, r2, mxr, nrm);
209212
// at-->cell.latvec, bg-->G
210213
// and sum to the real space part
211214

@@ -249,7 +252,7 @@ double H_Ewald_pw::compute_ewald(const UnitCell& cell,
249252
//calculate tau[na]-tau[nb]
250253
dtau = cell.atoms[nt1].tau[na] - cell.atoms[nt2].tau[nb];
251254
//generates nearest-neighbors shells
252-
H_Ewald_pw::rgen(dtau, rmax, irr, cell.latvec, cell.G, r, r2, nrm);
255+
H_Ewald_pw::rgen(dtau, rmax, irr, cell.latvec, cell.G, r, r2, mxr, nrm);
253256
// at-->cell.latvec, bg-->G
254257
// and sum to the real space part
255258

@@ -301,6 +304,7 @@ void H_Ewald_pw::rgen(
301304
const ModuleBase::Matrix3 &G,
302305
ModuleBase::Vector3<double> *r,
303306
double *r2,
307+
const int mxr,
304308
int &nrm)
305309
{
306310
//-------------------------------------------------------------------

source/source_hamilt/module_ewald/H_Ewald_pw.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@ class H_Ewald_pw
2020
const ModuleBase::ComplexMatrix& strucFac);
2121

2222
public:
23+
static int estimate_mxr(const double &rmax, const ModuleBase::Matrix3 &bg);
24+
2325
static void rgen(
2426
const ModuleBase::Vector3<double> &dtau,
2527
const double &rmax,
@@ -28,6 +30,7 @@ class H_Ewald_pw
2830
const ModuleBase::Matrix3 &bg,
2931
ModuleBase::Vector3<double> *r,
3032
double *r2,
33+
const int mxr,
3134
int &nrm
3235
);
3336

source/source_hamilt/test/rgen_test.cpp

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -43,13 +43,12 @@ TEST_F(RgenTest, ZeroRmax)
4343
{
4444
// When rmax==0 the function should return immediately with nrm=0
4545
const int mxr_test = 10;
46-
H_Ewald_pw::mxr = mxr_test;
4746
std::vector<ModuleBase::Vector3<double>> r(mxr_test);
4847
std::vector<double> r2(mxr_test);
4948
std::vector<int> irr(mxr_test);
5049
int nrm = 0;
5150

52-
H_Ewald_pw::rgen(dtau, 0.0, irr.data(), latvec, G, r.data(), r2.data(), nrm);
51+
H_Ewald_pw::rgen(dtau, 0.0, irr.data(), latvec, G, r.data(), r2.data(), mxr_test, nrm);
5352

5453
EXPECT_EQ(nrm, 0);
5554
}
@@ -60,13 +59,12 @@ TEST_F(RgenTest, SimpleCubicNearestNeighbors)
6059
// neighbors: 6 + 12 = 18 vectors total.
6160
const double rmax = 1.5;
6261
const int mxr_test = 50;
63-
H_Ewald_pw::mxr = mxr_test;
6462
std::vector<ModuleBase::Vector3<double>> r(mxr_test);
6563
std::vector<double> r2(mxr_test);
6664
std::vector<int> irr(mxr_test);
6765
int nrm = 0;
6866

69-
H_Ewald_pw::rgen(dtau, rmax, irr.data(), latvec, G, r.data(), r2.data(), nrm);
67+
H_Ewald_pw::rgen(dtau, rmax, irr.data(), latvec, G, r.data(), r2.data(), mxr_test, nrm);
7068

7169
EXPECT_EQ(nrm, 18);
7270

@@ -94,14 +92,13 @@ TEST_F(RgenTest, SimpleCubicNonZeroDtau)
9492
// No lattice point coincides with dtau, so neither is excluded.
9593
const double rmax = 0.6;
9694
const int mxr_test = 10;
97-
H_Ewald_pw::mxr = mxr_test;
9895
dtau = ModuleBase::Vector3<double>(0.5, 0.0, 0.0);
9996
std::vector<ModuleBase::Vector3<double>> r(mxr_test);
10097
std::vector<double> r2(mxr_test);
10198
std::vector<int> irr(mxr_test);
10299
int nrm = 0;
103100

104-
H_Ewald_pw::rgen(dtau, rmax, irr.data(), latvec, G, r.data(), r2.data(), nrm);
101+
H_Ewald_pw::rgen(dtau, rmax, irr.data(), latvec, G, r.data(), r2.data(), mxr_test, nrm);
105102

106103
EXPECT_EQ(nrm, 2);
107104
for (int i = 0; i < nrm; ++i)
@@ -128,13 +125,12 @@ TEST_F(RgenTest, LargeRmaxExceedsOriginalLimit)
128125
int nm3 = (int)(dnrm2(3, bg1, 1) * rmax + 2);
129126
const int mxr_test = (2 * nm1 + 1) * (2 * nm2 + 1) * (2 * nm3 + 1);
130127

131-
H_Ewald_pw::mxr = mxr_test;
132128
std::vector<ModuleBase::Vector3<double>> r(mxr_test);
133129
std::vector<double> r2(mxr_test);
134130
std::vector<int> irr(mxr_test);
135131
int nrm = 0;
136132

137-
H_Ewald_pw::rgen(dtau, rmax, irr.data(), latvec, G, r.data(), r2.data(), nrm);
133+
H_Ewald_pw::rgen(dtau, rmax, irr.data(), latvec, G, r.data(), r2.data(), mxr_test, nrm);
138134

139135
// Must exceed the old hard-coded limit that caused the crash
140136
EXPECT_GT(nrm, 200);

source/source_pw/module_pwdft/forces.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -657,7 +657,7 @@ void Forces<FPTYPE, Device>::cal_force_ew(const UnitCell& ucell,
657657
int nrm = 0;
658658

659659
// output of rgen: the number of vectors in the sphere
660-
const int mxr = 200;
660+
const int mxr = H_Ewald_pw::estimate_mxr(rmax, ucell.G);
661661
// the maximum number of R vectors included in r
662662
std::vector<ModuleBase::Vector3<double>> r(mxr);
663663
std::vector<double> r2(mxr);
@@ -681,7 +681,7 @@ void Forces<FPTYPE, Device>::cal_force_ew(const UnitCell& ucell,
681681
{
682682
ModuleBase::Vector3<double> d_tau
683683
= ucell.atoms[T1].tau[I1] - ucell.atoms[T2].tau[I2];
684-
H_Ewald_pw::rgen(d_tau, rmax, irr.data(), ucell.latvec, ucell.G, r.data(), r2.data(), nrm);
684+
H_Ewald_pw::rgen(d_tau, rmax, irr.data(), ucell.latvec, ucell.G, r.data(), r2.data(), mxr, nrm);
685685

686686
for (int n = 0; n < nrm; n++)
687687
{

source/source_pw/module_pwdft/stress_ewa.cpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,6 @@ void Stress_Func<FPTYPE, Device>::stress_ewa(const UnitCell& ucell,
108108
}
109109

110110
//R-space sum here (only for the processor that contains G=0)
111-
int mxr = 200;
112111
int *irr=nullptr;
113112
ModuleBase::Vector3<FPTYPE> *r;
114113
FPTYPE *r2=nullptr;
@@ -121,13 +120,14 @@ void Stress_Func<FPTYPE, Device>::stress_ewa(const UnitCell& ucell,
121120

122121
if(ig0 >= 0)
123122
{
124-
std::vector<ModuleBase::Vector3<FPTYPE>> r(mxr);
125-
std::vector<FPTYPE> r2(mxr);
126-
std::vector<int> irr(mxr);
127-
128123
FPTYPE sqa = sqrt(alpha);
129124
FPTYPE sq8a_2pi = sqrt(8 * alpha / (ModuleBase::TWO_PI));
130125
rmax = 4.0/sqa/ucell.lat0;
126+
const int mxr = H_Ewald_pw::estimate_mxr(rmax, ucell.G);
127+
128+
std::vector<ModuleBase::Vector3<FPTYPE>> r(mxr);
129+
std::vector<FPTYPE> r2(mxr);
130+
std::vector<int> irr(mxr);
131131

132132
#pragma omp for
133133
for(long long ijat = 0; ijat < ucell.nat * ucell.nat; ijat++)
@@ -142,7 +142,7 @@ void Stress_Func<FPTYPE, Device>::stress_ewa(const UnitCell& ucell,
142142
//calculate tau[na]-tau[nb]
143143
d_tau = ucell.atoms[it].tau[i] - ucell.atoms[jt].tau[j];
144144
//generates nearest-neighbors shells
145-
H_Ewald_pw::rgen(d_tau, rmax, irr.data(), ucell.latvec, ucell.G, r.data(), r2.data(), nrm);
145+
H_Ewald_pw::rgen(d_tau, rmax, irr.data(), ucell.latvec, ucell.G, r.data(), r2.data(), mxr, nrm);
146146
for(int nr=0; nr<nrm; nr++)
147147
{
148148
rr=sqrt(r2[nr]) * ucell.lat0;

0 commit comments

Comments
 (0)