Skip to content

Commit c6fa3aa

Browse files
committed
EAMxx: improve horiz_avg testing
1 parent c243f18 commit c6fa3aa

File tree

1 file changed

+22
-18
lines changed

1 file changed

+22
-18
lines changed

components/eamxx/src/diagnostics/tests/horiz_avg_test.cpp

Lines changed: 22 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,10 @@ TEST_CASE("horiz_avg") {
3333
using TeamMember = typename TeamPolicy::member_type;
3434
using KT = ekat::KokkosTypes<DefaultDevice>;
3535
using ESU = ekat::ExeSpaceUtils<typename KT::ExeSpace>;
36+
37+
// A numerical tolerance
38+
auto tol = std::numeric_limits<Real>::epsilon() * 100;
39+
3640
// A world comm
3741
ekat::Comm comm(MPI_COMM_WORLD);
3842

@@ -44,19 +48,17 @@ TEST_CASE("horiz_avg") {
4448
constexpr int dim3 = 4;
4549
const int ngcols = 6 * comm.size();
4650

47-
auto gm1 = create_gm(comm, ngcols, 1);
48-
auto gm2 = create_gm(comm, ngcols, nlevs);
49-
auto grid1 = gm1->get_grid("Physics");
50-
auto grid2 = gm2->get_grid("Physics");
51+
auto gm = create_gm(comm, ngcols, nlevs);
52+
auto grid = gm->get_grid("Physics");
5153

5254
// Input (randomized) qc
5355
FieldLayout scalar1d_layout{{COL}, {ngcols}};
5456
FieldLayout scalar2d_layout{{COL, LEV}, {ngcols, nlevs}};
5557
FieldLayout scalar3d_layout{{COL, CMP, LEV}, {ngcols, dim3, nlevs}};
5658

57-
FieldIdentifier qc1_fid("qc", scalar1d_layout, kg / kg, grid1->name());
58-
FieldIdentifier qc2_fid("qc", scalar2d_layout, kg / kg, grid2->name());
59-
FieldIdentifier qc3_fid("qc", scalar3d_layout, kg / kg, grid2->name());
59+
FieldIdentifier qc1_fid("qc", scalar1d_layout, kg / kg, grid->name());
60+
FieldIdentifier qc2_fid("qc", scalar2d_layout, kg / kg, grid->name());
61+
FieldIdentifier qc3_fid("qc", scalar3d_layout, kg / kg, grid->name());
6062

6163
Field qc1(qc1_fid);
6264
Field qc2(qc2_fid);
@@ -78,8 +80,8 @@ TEST_CASE("horiz_avg") {
7880
register_diagnostics();
7981

8082
ekat::ParameterList params;
81-
// REQUIRE_THROWS(diag_factory.create("HorizAvgDiag", comm,
82-
// params)); // No 'field_name' parameter
83+
REQUIRE_THROWS(diag_factory.create("HorizAvgDiag", comm,
84+
params)); // No 'field_name' parameter
8385

8486
// Set time for qc and randomize its values
8587
qc1.get_header().get_tracking().update_time_stamp(t0);
@@ -90,16 +92,16 @@ TEST_CASE("horiz_avg") {
9092
randomize(qc3, engine, pdf);
9193

9294
// Create and set up the diagnostic
93-
params.set("grid_name", grid1->name());
95+
params.set("grid_name", grid->name());
9496
params.set<std::string>("field_name", "qc");
9597
auto diag1 = diag_factory.create("HorizAvgDiag", comm, params);
9698
auto diag2 = diag_factory.create("HorizAvgDiag", comm, params);
9799
auto diag3 = diag_factory.create("HorizAvgDiag", comm, params);
98-
diag1->set_grids(gm1);
99-
diag2->set_grids(gm2);
100-
diag3->set_grids(gm2);
100+
diag1->set_grids(gm);
101+
diag2->set_grids(gm);
102+
diag3->set_grids(gm);
101103

102-
auto area = grid1->get_geometry_data("area");
104+
auto area = grid->get_geometry_data("area");
103105

104106
diag1->set_required_field(qc1);
105107
diag1->initialize(t0, RunType::Initial);
@@ -109,7 +111,7 @@ TEST_CASE("horiz_avg") {
109111

110112
FieldIdentifier diag0_fid("qc_horiz_avg_manual",
111113
scalar1d_layout.clone().strip_dim(COL), kg / kg,
112-
grid1->name());
114+
grid->name());
113115
Field diag0(diag0_fid);
114116
diag0.allocate_view();
115117
auto diag0_v = diag0.get_view<Real>();
@@ -139,7 +141,9 @@ TEST_CASE("horiz_avg") {
139141
Kokkos::deep_copy(qc1_v, wavg);
140142
diag1->compute_diagnostic();
141143
auto diag1_v2_host = diag1_f.get_view<Real, Host>();
142-
REQUIRE(std::abs(diag1_v2_host() - wavg) < sp(1e-6));
144+
REQUIRE_THAT(diag1_v2_host(),
145+
Catch::Matchers::WithinRel(
146+
wavg, tol)); // Catch2's floating point comparison
143147

144148
// other diags
145149
// Set qc2_v to 5.0 to get weighted average of 5.0
@@ -155,13 +159,13 @@ TEST_CASE("horiz_avg") {
155159
auto diag2_v_host = diag2_f.get_view<Real *, Host>();
156160

157161
for(int i = 0; i < nlevs; ++i) {
158-
REQUIRE(std::abs(diag2_v_host(i) - wavg) < sp(1e-6));
162+
REQUIRE_THAT(diag2_v_host(i), Catch::Matchers::WithinRel(wavg, tol));
159163
}
160164

161165
auto qc3_v = qc3.get_view<Real ***>();
162166
FieldIdentifier diag3_manual_fid("qc_horiz_avg_manual",
163167
scalar3d_layout.clone().strip_dim(COL),
164-
kg / kg, grid2->name());
168+
kg / kg, grid->name());
165169
Field diag3_manual(diag3_manual_fid);
166170
diag3_manual.allocate_view();
167171
auto diag3_manual_v = diag3_manual.get_view<Real **>();

0 commit comments

Comments
 (0)