Skip to content

Commit 1373e9c

Browse files
committed
EAMxx: add column reduction utility to fields
1 parent 49fdbe3 commit 1373e9c

File tree

3 files changed

+223
-0
lines changed

3 files changed

+223
-0
lines changed

components/eamxx/src/share/field/field_utils.hpp

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
#define SCREAM_FIELD_UTILS_HPP
33

44
#include "share/field/field_utils_impl.hpp"
5+
#include "share/field/field_utils_impl_colred.hpp"
56

67
namespace scream {
78

@@ -111,6 +112,17 @@ void perturb (const Field& f,
111112
impl::perturb<ST>(f, engine, pdf, base_seed, level_mask, dof_gids);
112113
}
113114

115+
template <typename ST>
116+
Field column_reduction(const Field &f1, const Field &f2,
117+
const ekat::Comm *comm = nullptr) {
118+
EKAT_REQUIRE_MSG(f1.is_allocated() && f2.is_allocated(),
119+
"Error! Input fields must be allocated.");
120+
EKAT_REQUIRE_MSG(f1.data_type() == f2.data_type(),
121+
"Error! Input fields must have matching data types.");
122+
123+
return impl::column_reduction<ST>(f1, f2, comm);
124+
}
125+
114126
template<typename ST>
115127
ST frobenius_norm(const Field& f, const ekat::Comm* comm = nullptr)
116128
{
Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,123 @@
1+
#ifndef SCREAM_FIELD_UTILS_IMPL_COLRED_HPP
2+
#define SCREAM_FIELD_UTILS_IMPL_COLRED_HPP
3+
4+
#include "ekat/kokkos/ekat_kokkos_utils.hpp"
5+
#include "ekat/mpi/ekat_comm.hpp"
6+
#include "share/field/field.hpp"
7+
8+
namespace scream {
9+
namespace impl {
10+
11+
// Utility to compute the reduction of a field along its column dimension.
12+
// This is equivalent to einsum('i,i...k->...k', f1, f2); i is the column.
13+
// The layouts are such that:
14+
// - The first dimension is for the columns (COL)
15+
// - There can be only up to 3 dimensions
16+
17+
template <typename ST>
18+
Field column_reduction(const Field &f1, const Field &f2, const ekat::Comm *co) {
19+
using KT = ekat::KokkosTypes<DefaultDevice>;
20+
using RangePolicy = Kokkos::RangePolicy<Field::device_t::execution_space>;
21+
using TeamPolicy = Kokkos::TeamPolicy<Field::device_t::execution_space>;
22+
using TeamMember = typename TeamPolicy::member_type;
23+
using ESU = ekat::ExeSpaceUtils<typename KT::ExeSpace>;
24+
using namespace ShortFieldTagsNames;
25+
26+
const auto &l1 = f1.get_header().get_identifier().get_layout();
27+
28+
EKAT_REQUIRE_MSG(l1.rank() == 1,
29+
"Error! First field f1 must be rank-1.\n"
30+
"The input has rank "
31+
<< l1.rank() << ".\n");
32+
EKAT_REQUIRE_MSG(l1.tags() == std::vector<FieldTag>({COL}),
33+
"Error! First field f1 must have a column dimension.\n"
34+
"The input f1 layout is "
35+
<< l1.tags() << ".\n");
36+
37+
const auto &n2 = f2.get_header().get_identifier().name();
38+
const auto &l2 = f2.get_header().get_identifier().get_layout();
39+
const auto &u2 = f2.get_header().get_identifier().get_units();
40+
const auto &g2 = f2.get_header().get_identifier().get_grid_name();
41+
42+
EKAT_REQUIRE_MSG(l2.rank() <= 3,
43+
"Error! Second field f2 must be at most rank-3.\n"
44+
"The input f2 rank is "
45+
<< l2.rank() << ".\n");
46+
EKAT_REQUIRE_MSG(l2.tags()[0] == COL,
47+
"Error! Second field f2 must have a column dimension.\n"
48+
"The input f2 layout is "
49+
<< l2.tags() << ".\n");
50+
EKAT_REQUIRE_MSG(
51+
l1.dim(0) == l2.dim(0),
52+
"Error! The two input fields must have the same dimension along "
53+
"which we are taking the reducing the field.\n"
54+
"The first field f1 has dimension "
55+
<< l1.dim(0)
56+
<< " while "
57+
"the second field f2 has dimension "
58+
<< l2.dim(0) << ".\n");
59+
60+
auto v1 = f1.get_view<const ST *>();
61+
62+
FieldIdentifier fo_id(n2 + "_colred", l2.clone().strip_dim(0), u2, g2);
63+
Field fo(fo_id);
64+
fo.allocate_view();
65+
fo.deep_copy(0);
66+
67+
const int d0 = l2.dim(0);
68+
69+
switch(l2.rank()) {
70+
case 1: {
71+
auto v2 = f2.get_view<const ST *>();
72+
auto vo = fo.get_view<ST>();
73+
Kokkos::parallel_reduce(
74+
fo.name(), RangePolicy(0, d0),
75+
KOKKOS_LAMBDA(const int i, ST &ls) { ls += v1(i) * v2(i); }, vo);
76+
} break;
77+
case 2: {
78+
auto v2 = f2.get_view<const ST **>();
79+
auto vo = fo.get_view<ST *>();
80+
const int d1 = l2.dim(1);
81+
auto p = ESU::get_default_team_policy(d1, d0);
82+
Kokkos::parallel_for(
83+
fo.name(), p, KOKKOS_LAMBDA(const TeamMember &tm) {
84+
const int j = tm.league_rank();
85+
Kokkos::parallel_reduce(
86+
Kokkos::TeamVectorRange(tm, d0),
87+
[&](int i, ST &ac) { ac += v1(i) * v2(i, j); }, vo(j));
88+
});
89+
} break;
90+
case 3: {
91+
auto v2 = f2.get_view<const ST ***>();
92+
auto vo = fo.get_view<ST **>();
93+
const int d1 = l2.dim(1);
94+
const int d2 = l2.dim(2);
95+
auto p = ESU::get_default_team_policy(d1 * d2, d0);
96+
Kokkos::parallel_for(
97+
fo.name(), p, KOKKOS_LAMBDA(const TeamMember &tm) {
98+
const int idx = tm.league_rank();
99+
const int j = idx / d2;
100+
const int k = idx % d2;
101+
Kokkos::parallel_reduce(
102+
Kokkos::TeamVectorRange(tm, d0),
103+
[&](int i, ST &ac) { ac += v1(i) * v2(i, j, k); }, vo(j, k));
104+
});
105+
} break;
106+
default:
107+
EKAT_ERROR_MSG("Error! Unsupported field rank.\n");
108+
}
109+
110+
if(co) {
111+
Kokkos::fence();
112+
fo.sync_to_host();
113+
co->all_reduce(fo.template get_internal_view_data<ST, Host>(),
114+
l2.size() / l2.dim(0), MPI_SUM);
115+
fo.sync_to_dev();
116+
}
117+
return fo;
118+
}
119+
120+
} // namespace impl
121+
} // namespace scream
122+
123+
#endif // SCREAM_FIELD_UTILS_IMPL_COLRED_HPP

components/eamxx/src/share/tests/field_utils.cpp

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,94 @@ TEST_CASE("utils") {
126126
REQUIRE(field_sum<Real>(f1,&comm)==gsum);
127127
}
128128

129+
SECTION("column_reduction") {
130+
using RPDF = std::uniform_real_distribution<Real>;
131+
auto engine = setup_random_test();
132+
RPDF pdf(0, 1);
133+
134+
int dim0 = 3;
135+
int dim1 = 9;
136+
int dim2 = 2;
137+
FieldIdentifier f00("f", {{COL}, {dim0}}, m / s, "g");
138+
Field field00(f00);
139+
field00.allocate_view();
140+
field00.sync_to_host();
141+
auto v00 = field00.get_strided_view<Real *, Host>();
142+
for(int i = 0; i < dim0; ++i) {
143+
v00(i) = (i + 1) / sp(6);
144+
}
145+
field00.sync_to_dev();
146+
147+
FieldIdentifier f10("f", {{COL, CMP}, {dim0, dim1}}, m / s, "g");
148+
FieldIdentifier f11("f", {{COL, LEV}, {dim0, dim2}}, m / s, "g");
149+
FieldIdentifier f20("f", {{COL, CMP, LEV}, {dim0, dim1, dim2}}, m / s, "g");
150+
151+
Field field10(f10);
152+
Field field11(f11);
153+
Field field20(f20);
154+
field10.allocate_view();
155+
field11.allocate_view();
156+
field20.allocate_view();
157+
158+
randomize(field10, engine, pdf);
159+
randomize(field11, engine, pdf);
160+
randomize(field20, engine, pdf);
161+
162+
FieldIdentifier F_x("fx", {{COL}, {dim1}}, m/s, "g");
163+
FieldIdentifier F_y("fy", {{LEV}, {dim2}}, m/s, "g");
164+
165+
Field field_x(F_x);
166+
Field field_y(F_y);
167+
168+
REQUIRE_THROWS(column_reduction<Real>(field00, field_x)); // x not allocated
169+
170+
field_x.allocate_view();
171+
field_y.allocate_view();
172+
173+
REQUIRE_THROWS(column_reduction<Real>(field_x, field_y)); // unmatching layout
174+
REQUIRE_THROWS(column_reduction<Real>(field11, field11)); // wrong f1 layout
175+
176+
Field result;
177+
178+
result = column_reduction<Real>(field00, field00);
179+
result.sync_to_host();
180+
auto v = result.get_view<Real, Host>();
181+
REQUIRE(v() == (1 / sp(36) + 4 / sp(36) + 9 / sp(36)));
182+
183+
result = column_reduction<Real>(field00, field10);
184+
REQUIRE(result.get_header().get_identifier().get_layout().tags() ==
185+
std::vector<FieldTag>({CMP}));
186+
REQUIRE(result.get_header().get_identifier().get_layout().dim(0) == dim1);
187+
188+
result = column_reduction<Real>(field00, field11);
189+
REQUIRE(result.get_header().get_identifier().get_layout().tags() ==
190+
std::vector<FieldTag>({LEV}));
191+
REQUIRE(result.get_header().get_identifier().get_layout().dim(0) == dim2);
192+
193+
result = column_reduction<Real>(field00, field20);
194+
REQUIRE(result.get_header().get_identifier().get_layout().tags() ==
195+
std::vector<FieldTag>({CMP, LEV}));
196+
REQUIRE(result.get_header().get_identifier().get_layout().dim(0) == dim1);
197+
REQUIRE(result.get_header().get_identifier().get_layout().dim(1) == dim2);
198+
199+
field20.sync_to_host();
200+
auto manual_result = result.clone();
201+
manual_result.deep_copy(0);
202+
manual_result.sync_to_host();
203+
auto v2 = field20.get_strided_view<Real ***, Host>();
204+
auto mr = manual_result.get_strided_view<Real **, Host>();
205+
for(int i = 0; i < dim0; ++i) {
206+
for(int j = 0; j < dim1; ++j) {
207+
for(int k = 0; k < dim2; ++k) {
208+
mr(j, k) += v00(i) * v2(i, j, k);
209+
}
210+
}
211+
}
212+
field20.sync_to_dev();
213+
manual_result.sync_to_dev();
214+
REQUIRE(views_are_equal(result, manual_result));
215+
}
216+
129217
SECTION ("frobenius") {
130218

131219
auto v1 = f1.get_strided_view<Real**>();

0 commit comments

Comments
 (0)