@@ -33,6 +33,10 @@ TEST_CASE("horiz_avg") {
33
33
using TeamMember = typename TeamPolicy::member_type;
34
34
using KT = ekat::KokkosTypes<DefaultDevice>;
35
35
using ESU = ekat::ExeSpaceUtils<typename KT::ExeSpace>;
36
+
37
+ // A numerical tolerance
38
+ auto tol = std::numeric_limits<Real>::epsilon () * 100 ;
39
+
36
40
// A world comm
37
41
ekat::Comm comm (MPI_COMM_WORLD);
38
42
@@ -44,19 +48,17 @@ TEST_CASE("horiz_avg") {
44
48
constexpr int dim3 = 4 ;
45
49
const int ngcols = 6 * comm.size ();
46
50
47
- auto gm1 = create_gm (comm, ngcols, 1 );
48
- auto gm2 = create_gm (comm, ngcols, nlevs);
49
- auto grid1 = gm1->get_grid (" Physics" );
50
- auto grid2 = gm2->get_grid (" Physics" );
51
+ auto gm = create_gm (comm, ngcols, nlevs);
52
+ auto grid = gm->get_grid (" Physics" );
51
53
52
54
// Input (randomized) qc
53
55
FieldLayout scalar1d_layout{{COL}, {ngcols}};
54
56
FieldLayout scalar2d_layout{{COL, LEV}, {ngcols, nlevs}};
55
57
FieldLayout scalar3d_layout{{COL, CMP, LEV}, {ngcols, dim3, nlevs}};
56
58
57
- FieldIdentifier qc1_fid (" qc" , scalar1d_layout, kg / kg, grid1 ->name ());
58
- FieldIdentifier qc2_fid (" qc" , scalar2d_layout, kg / kg, grid2 ->name ());
59
- FieldIdentifier qc3_fid (" qc" , scalar3d_layout, kg / kg, grid2 ->name ());
59
+ FieldIdentifier qc1_fid (" qc" , scalar1d_layout, kg / kg, grid ->name ());
60
+ FieldIdentifier qc2_fid (" qc" , scalar2d_layout, kg / kg, grid ->name ());
61
+ FieldIdentifier qc3_fid (" qc" , scalar3d_layout, kg / kg, grid ->name ());
60
62
61
63
Field qc1 (qc1_fid);
62
64
Field qc2 (qc2_fid);
@@ -78,8 +80,8 @@ TEST_CASE("horiz_avg") {
78
80
register_diagnostics ();
79
81
80
82
ekat::ParameterList params;
81
- // REQUIRE_THROWS(diag_factory.create("HorizAvgDiag", comm,
82
- // params)); // No 'field_name' parameter
83
+ REQUIRE_THROWS (diag_factory.create (" HorizAvgDiag" , comm,
84
+ params)); // No 'field_name' parameter
83
85
84
86
// Set time for qc and randomize its values
85
87
qc1.get_header ().get_tracking ().update_time_stamp (t0);
@@ -90,16 +92,16 @@ TEST_CASE("horiz_avg") {
90
92
randomize (qc3, engine, pdf);
91
93
92
94
// Create and set up the diagnostic
93
- params.set (" grid_name" , grid1 ->name ());
95
+ params.set (" grid_name" , grid ->name ());
94
96
params.set <std::string>(" field_name" , " qc" );
95
97
auto diag1 = diag_factory.create (" HorizAvgDiag" , comm, params);
96
98
auto diag2 = diag_factory.create (" HorizAvgDiag" , comm, params);
97
99
auto diag3 = diag_factory.create (" HorizAvgDiag" , comm, params);
98
- diag1->set_grids (gm1 );
99
- diag2->set_grids (gm2 );
100
- diag3->set_grids (gm2 );
100
+ diag1->set_grids (gm );
101
+ diag2->set_grids (gm );
102
+ diag3->set_grids (gm );
101
103
102
- auto area = grid1 ->get_geometry_data (" area" );
104
+ auto area = grid ->get_geometry_data (" area" );
103
105
104
106
diag1->set_required_field (qc1);
105
107
diag1->initialize (t0, RunType::Initial);
@@ -109,7 +111,7 @@ TEST_CASE("horiz_avg") {
109
111
110
112
FieldIdentifier diag0_fid (" qc_horiz_avg_manual" ,
111
113
scalar1d_layout.clone ().strip_dim (COL), kg / kg,
112
- grid1 ->name ());
114
+ grid ->name ());
113
115
Field diag0 (diag0_fid);
114
116
diag0.allocate_view ();
115
117
auto diag0_v = diag0.get_view <Real>();
@@ -139,7 +141,9 @@ TEST_CASE("horiz_avg") {
139
141
Kokkos::deep_copy (qc1_v, wavg);
140
142
diag1->compute_diagnostic ();
141
143
auto diag1_v2_host = diag1_f.get_view <Real, Host>();
142
- REQUIRE (std::abs (diag1_v2_host () - wavg) < sp (1e-6 ));
144
+ REQUIRE_THAT (diag1_v2_host (),
145
+ Catch::Matchers::WithinRel (
146
+ wavg, tol)); // Catch2's floating point comparison
143
147
144
148
// other diags
145
149
// Set qc2_v to 5.0 to get weighted average of 5.0
@@ -155,13 +159,13 @@ TEST_CASE("horiz_avg") {
155
159
auto diag2_v_host = diag2_f.get_view <Real *, Host>();
156
160
157
161
for (int i = 0 ; i < nlevs; ++i) {
158
- REQUIRE ( std::abs ( diag2_v_host (i) - wavg) < sp ( 1e-6 ));
162
+ REQUIRE_THAT ( diag2_v_host (i), Catch::Matchers::WithinRel ( wavg, tol ));
159
163
}
160
164
161
165
auto qc3_v = qc3.get_view <Real ***>();
162
166
FieldIdentifier diag3_manual_fid (" qc_horiz_avg_manual" ,
163
167
scalar3d_layout.clone ().strip_dim (COL),
164
- kg / kg, grid2 ->name ());
168
+ kg / kg, grid ->name ());
165
169
Field diag3_manual (diag3_manual_fid);
166
170
diag3_manual.allocate_view ();
167
171
auto diag3_manual_v = diag3_manual.get_view <Real **>();
0 commit comments