7
7
8
8
namespace scream {
9
9
10
- std::shared_ptr<GridsManager> create_gm (const ekat::Comm &comm, const int ncols, const int nlevs) {
11
- const int num_global_cols = ncols * comm.size ();
10
+ std::shared_ptr<GridsManager> create_gm (const ekat::Comm &comm, const int ngcols, const int nlevs) {
12
11
13
12
using vos_t = std::vector<std::string>;
14
13
ekat::ParameterList gm_params;
15
14
gm_params.set (" grids_names" , vos_t {" Point Grid" });
16
15
auto &pl = gm_params.sublist (" Point Grid" );
17
16
pl.set <std::string>(" type" , " point_grid" );
18
17
pl.set (" aliases" , vos_t {" Physics" });
19
- pl.set <int >(" number_of_global_columns" , num_global_cols );
18
+ pl.set <int >(" number_of_global_columns" , ngcols );
20
19
pl.set <int >(" number_of_vertical_levels" , nlevs);
21
20
22
21
auto gm = create_mesh_free_grids_manager (comm, gm_params);
@@ -30,7 +29,7 @@ TEST_CASE("zonal_avg") {
30
29
using namespace ekat ::units;
31
30
32
31
// A numerical tolerance
33
- auto tol = std::numeric_limits<Real>::epsilon () * 100 ;
32
+ const auto tol = std::numeric_limits<Real>::epsilon () * 100 ;
34
33
35
34
// A world comm
36
35
ekat::Comm comm (MPI_COMM_WORLD);
@@ -41,9 +40,10 @@ TEST_CASE("zonal_avg") {
41
40
// Create a grids manager - single column for these tests
42
41
constexpr int nlevs = 3 ;
43
42
constexpr int dim3 = 4 ;
44
- const int ngcols = 6 * comm. size () ;
43
+ const int ncols = 6 ;
45
44
const int nlats = 4 ;
46
45
46
+ const int ngcols = ncols * comm.size ();
47
47
auto gm = create_gm (comm, ngcols, nlevs);
48
48
auto grid = gm->get_grid (" Physics" );
49
49
@@ -56,16 +56,17 @@ TEST_CASE("zonal_avg") {
56
56
auto lat_view_h = lat.get_view <Real *, Host>();
57
57
const Real lat_delta = sp (180.0 ) / nlats;
58
58
std::vector<Real> zonal_areas (nlats, 0.0 );
59
- for (int i = 0 ; i < ngcols ; i++) {
59
+ for (int i = 0 ; i < ncols ; i++) {
60
60
lat_view_h (i) = sp (-90.0 ) + (i % nlats + sp (0.5 )) * lat_delta;
61
61
zonal_areas[i % nlats] += area_view_h[i];
62
62
}
63
63
lat.sync_to_dev ();
64
+ comm.all_reduce (zonal_areas.data (), zonal_areas.size (), MPI_SUM);
64
65
65
66
// Input (randomized) qc
66
- FieldLayout scalar1d_layout{{COL}, {ngcols }};
67
- FieldLayout scalar2d_layout{{COL, LEV}, {ngcols , nlevs}};
68
- FieldLayout scalar3d_layout{{COL, CMP, LEV}, {ngcols , dim3, nlevs}};
67
+ FieldLayout scalar1d_layout{{COL}, {ncols }};
68
+ FieldLayout scalar2d_layout{{COL, LEV}, {ncols , nlevs}};
69
+ FieldLayout scalar3d_layout{{COL, CMP, LEV}, {ncols , dim3, nlevs}};
69
70
70
71
FieldIdentifier qc1_id (" qc" , scalar1d_layout, kg / kg, grid->name ());
71
72
FieldIdentifier qc2_fid (" qc" , scalar2d_layout, kg / kg, grid->name ());
@@ -135,10 +136,12 @@ TEST_CASE("zonal_avg") {
135
136
// calculate the zonal average
136
137
auto qc1_view_h = qc1.get_view <const Real *, Host>();
137
138
auto diag0_view_h = diag0_field.get_view <Real *, Host>();
138
- for (int i = 0 ; i < ngcols ; i++) {
139
+ for (int i = 0 ; i < ncols ; i++) {
139
140
const int nlat = i % nlats;
140
141
diag0_view_h (nlat) += area_view_h (i) / zonal_areas[nlat] * qc1_view_h (i);
141
142
}
143
+ comm.all_reduce (diag0_field.template get_internal_view_data <Real, Host>(),
144
+ diag0_layout.size (), MPI_SUM);
142
145
diag0_field.sync_to_dev ();
143
146
144
147
// Compare
@@ -149,7 +152,7 @@ TEST_CASE("zonal_avg") {
149
152
const Real zavg1 = sp (1.0 );
150
153
qc1.deep_copy (zavg1);
151
154
diag1->compute_diagnostic ();
152
- auto diag1_view_host = diag1_field.get_view <Real *, Host>();
155
+ auto diag1_view_host = diag1_field.get_view <const Real *, Host>();
153
156
for (int nlat = 0 ; nlat < nlats; nlat++) {
154
157
REQUIRE_THAT (diag1_view_host (nlat), Catch::Matchers::WithinRel (zavg1, tol));
155
158
}
@@ -163,7 +166,7 @@ TEST_CASE("zonal_avg") {
163
166
diag2->compute_diagnostic ();
164
167
auto diag2_field = diag2->get_diagnostic ();
165
168
166
- auto diag2_view_host = diag2_field.get_view <Real **, Host>();
169
+ auto diag2_view_host = diag2_field.get_view <const Real **, Host>();
167
170
for (int i = 0 ; i < nlevs; ++i) {
168
171
for (int nlat = 0 ; nlat < nlats; nlat++) {
169
172
REQUIRE_THAT (diag2_view_host (nlat, i), Catch::Matchers::WithinRel (zavg2, tol));
@@ -176,16 +179,18 @@ TEST_CASE("zonal_avg") {
176
179
FieldIdentifier diag3m_id (" qc_zonal_avg_manual" , diag3m_layout, kg / kg, grid->name ());
177
180
Field diag3m_field (diag3m_id);
178
181
diag3m_field.allocate_view ();
179
- auto qc3_view_h = qc3.get_view <Real ***, Host>();
182
+ auto qc3_view_h = qc3.get_view <const Real ***, Host>();
180
183
auto diag3m_view_h = diag3m_field.get_view <Real ***, Host>();
181
- for (int i = 0 ; i < ngcols ; i++) {
184
+ for (int i = 0 ; i < ncols ; i++) {
182
185
const int nlat = i % nlats;
183
186
for (int j = 0 ; j < dim3; j++) {
184
187
for (int k = 0 ; k < nlevs; k++) {
185
188
diag3m_view_h (nlat, j, k) += area_view_h (i) / zonal_areas[nlat] * qc3_view_h (i, j, k);
186
189
}
187
190
}
188
191
}
192
+ comm.all_reduce (diag3m_field.template get_internal_view_data <Real, Host>(),
193
+ diag3m_layout.size (), MPI_SUM);
189
194
diag3m_field.sync_to_dev ();
190
195
diag3->set_required_field (qc3);
191
196
diag3->initialize (t0, RunType::Initial);
0 commit comments