Skip to content

Commit e251a8c

Browse files
committed
Merge branch 'jgfouca/fix_diag_cmake' into master (PR #7558)
Fix the CreateDiagTest function It was dropping additional arguments Fixes #7554 [BFB]
2 parents 4d8c83c + 75ca319 commit e251a8c

File tree

2 files changed

+22
-17
lines changed

2 files changed

+22
-17
lines changed

components/eamxx/src/diagnostics/tests/CMakeLists.txt

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
include(ScreamUtils)
22

3-
function (createDiagTest test_name test_srcs)
4-
CreateUnitTest(${test_name} "${test_srcs}"
3+
function (CreateDiagTest test_name test_srcs)
4+
CreateUnitTest(${test_name} ${test_srcs} ${ARGN}
55
LIBS diagnostics physics_share
66
LABELS diagnostics)
77
endfunction ()
@@ -79,4 +79,4 @@ CreateDiagTest(horiz_avg "horiz_avg_test.cpp")
7979
CreateDiagTest(vert_contract "vert_contract_test.cpp")
8080

8181
# Test zonal averaging
82-
CreateDiagTest(zonal_avg "zonal_avg_test.cpp" MPI_RANKS 1 ${SCREAM_TEST_MAX_RANKS})
82+
CreateDiagTest(zonal_avg zonal_avg_test.cpp MPI_RANKS 1 ${SCREAM_TEST_MAX_RANKS})

components/eamxx/src/diagnostics/tests/zonal_avg_test.cpp

Lines changed: 19 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -7,16 +7,15 @@
77

88
namespace scream {
99

10-
std::shared_ptr<GridsManager> create_gm(const ekat::Comm &comm, const int ncols, const int nlevs) {
11-
const int num_global_cols = ncols * comm.size();
10+
std::shared_ptr<GridsManager> create_gm(const ekat::Comm &comm, const int ngcols, const int nlevs) {
1211

1312
using vos_t = std::vector<std::string>;
1413
ekat::ParameterList gm_params;
1514
gm_params.set("grids_names", vos_t{"Point Grid"});
1615
auto &pl = gm_params.sublist("Point Grid");
1716
pl.set<std::string>("type", "point_grid");
1817
pl.set("aliases", vos_t{"Physics"});
19-
pl.set<int>("number_of_global_columns", num_global_cols);
18+
pl.set<int>("number_of_global_columns", ngcols);
2019
pl.set<int>("number_of_vertical_levels", nlevs);
2120

2221
auto gm = create_mesh_free_grids_manager(comm, gm_params);
@@ -30,7 +29,7 @@ TEST_CASE("zonal_avg") {
3029
using namespace ekat::units;
3130

3231
// A numerical tolerance
33-
auto tol = std::numeric_limits<Real>::epsilon() * 100;
32+
const auto tol = std::numeric_limits<Real>::epsilon() * 100;
3433

3534
// A world comm
3635
ekat::Comm comm(MPI_COMM_WORLD);
@@ -41,9 +40,10 @@ TEST_CASE("zonal_avg") {
4140
// Create a grids manager - single column for these tests
4241
constexpr int nlevs = 3;
4342
constexpr int dim3 = 4;
44-
const int ngcols = 6 * comm.size();
43+
const int ncols = 6;
4544
const int nlats = 4;
4645

46+
const int ngcols = ncols * comm.size();
4747
auto gm = create_gm(comm, ngcols, nlevs);
4848
auto grid = gm->get_grid("Physics");
4949

@@ -56,16 +56,17 @@ TEST_CASE("zonal_avg") {
5656
auto lat_view_h = lat.get_view<Real *, Host>();
5757
const Real lat_delta = sp(180.0) / nlats;
5858
std::vector<Real> zonal_areas(nlats, 0.0);
59-
for (int i = 0; i < ngcols; i++) {
59+
for (int i = 0; i < ncols; i++) {
6060
lat_view_h(i) = sp(-90.0) + (i % nlats + sp(0.5)) * lat_delta;
6161
zonal_areas[i % nlats] += area_view_h[i];
6262
}
6363
lat.sync_to_dev();
64+
comm.all_reduce(zonal_areas.data(), zonal_areas.size(), MPI_SUM);
6465

6566
// Input (randomized) qc
66-
FieldLayout scalar1d_layout{{COL}, {ngcols}};
67-
FieldLayout scalar2d_layout{{COL, LEV}, {ngcols, nlevs}};
68-
FieldLayout scalar3d_layout{{COL, CMP, LEV}, {ngcols, dim3, nlevs}};
67+
FieldLayout scalar1d_layout{{COL}, {ncols}};
68+
FieldLayout scalar2d_layout{{COL, LEV}, {ncols, nlevs}};
69+
FieldLayout scalar3d_layout{{COL, CMP, LEV}, {ncols, dim3, nlevs}};
6970

7071
FieldIdentifier qc1_id("qc", scalar1d_layout, kg / kg, grid->name());
7172
FieldIdentifier qc2_fid("qc", scalar2d_layout, kg / kg, grid->name());
@@ -135,10 +136,12 @@ TEST_CASE("zonal_avg") {
135136
// calculate the zonal average
136137
auto qc1_view_h = qc1.get_view<const Real *, Host>();
137138
auto diag0_view_h = diag0_field.get_view<Real *, Host>();
138-
for (int i = 0; i < ngcols; i++) {
139+
for (int i = 0; i < ncols; i++) {
139140
const int nlat = i % nlats;
140141
diag0_view_h(nlat) += area_view_h(i) / zonal_areas[nlat] * qc1_view_h(i);
141142
}
143+
comm.all_reduce(diag0_field.template get_internal_view_data<Real, Host>(),
144+
diag0_layout.size(), MPI_SUM);
142145
diag0_field.sync_to_dev();
143146

144147
// Compare
@@ -149,7 +152,7 @@ TEST_CASE("zonal_avg") {
149152
const Real zavg1 = sp(1.0);
150153
qc1.deep_copy(zavg1);
151154
diag1->compute_diagnostic();
152-
auto diag1_view_host = diag1_field.get_view<Real *, Host>();
155+
auto diag1_view_host = diag1_field.get_view<const Real *, Host>();
153156
for (int nlat = 0; nlat < nlats; nlat++) {
154157
REQUIRE_THAT(diag1_view_host(nlat), Catch::Matchers::WithinRel(zavg1, tol));
155158
}
@@ -163,7 +166,7 @@ TEST_CASE("zonal_avg") {
163166
diag2->compute_diagnostic();
164167
auto diag2_field = diag2->get_diagnostic();
165168

166-
auto diag2_view_host = diag2_field.get_view<Real **, Host>();
169+
auto diag2_view_host = diag2_field.get_view<const Real **, Host>();
167170
for (int i = 0; i < nlevs; ++i) {
168171
for (int nlat = 0; nlat < nlats; nlat++) {
169172
REQUIRE_THAT(diag2_view_host(nlat, i), Catch::Matchers::WithinRel(zavg2, tol));
@@ -176,16 +179,18 @@ TEST_CASE("zonal_avg") {
176179
FieldIdentifier diag3m_id("qc_zonal_avg_manual", diag3m_layout, kg / kg, grid->name());
177180
Field diag3m_field(diag3m_id);
178181
diag3m_field.allocate_view();
179-
auto qc3_view_h = qc3.get_view<Real ***, Host>();
182+
auto qc3_view_h = qc3.get_view<const Real ***, Host>();
180183
auto diag3m_view_h = diag3m_field.get_view<Real ***, Host>();
181-
for (int i = 0; i < ngcols; i++) {
184+
for (int i = 0; i < ncols; i++) {
182185
const int nlat = i % nlats;
183186
for (int j = 0; j < dim3; j++) {
184187
for (int k = 0; k < nlevs; k++) {
185188
diag3m_view_h(nlat, j, k) += area_view_h(i) / zonal_areas[nlat] * qc3_view_h(i, j, k);
186189
}
187190
}
188191
}
192+
comm.all_reduce(diag3m_field.template get_internal_view_data<Real, Host>(),
193+
diag3m_layout.size(), MPI_SUM);
189194
diag3m_field.sync_to_dev();
190195
diag3->set_required_field(qc3);
191196
diag3->initialize(t0, RunType::Initial);

0 commit comments

Comments
 (0)