Skip to content

Commit f1dd1b0

Browse files
committed
EAMxx: make contractions take masked fields
1 parent 5cb56bb commit f1dd1b0

File tree

5 files changed

+153
-21
lines changed

5 files changed

+153
-21
lines changed

components/eamxx/src/diagnostics/vert_contract.cpp

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -216,7 +216,12 @@ void VertContractDiag::compute_diagnostic_impl() {
216216
}
217217

218218
// call the vert_contraction impl that will take care of everything
219-
vert_contraction<Real>(d, f, m_weighting);
219+
// If f carries mask data and we are averaging, use the AVG=1 instantiation, which divides the weighted sum by the sum of the masked weights
220+
if (m_contract_method == "avg" && f.get_header().has_extra_data("mask_data")) {
221+
vert_contraction<Real,1>(d, f, m_weighting);
222+
} else {
223+
vert_contraction<Real,0>(d, f, m_weighting);
224+
}
220225
}
221226

222227
} // namespace scream

components/eamxx/src/share/field/field_utils.hpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -178,7 +178,7 @@ void horiz_contraction(const Field &f_out, const Field &f_in,
178178
"Error! Weight field must have the same data type as input fields.");
179179

180180
// All good, call the implementation
181-
impl::horiz_contraction<ST>(f_out, f_in, weight, comm);
181+
impl::horiz_contraction<ST,1>(f_out, f_in, weight, comm);
182182
}
183183

184184
// Utility to compute the contraction of a field along its level dimension.
@@ -191,7 +191,7 @@ void horiz_contraction(const Field &f_out, const Field &f_in,
191191
// - rank-1, with only LEV/ILEV dimension
192192
// - rank-2, with only COL and LEV/ILEV dimensions
193193
// NOTE: we assume the LEV/ILEV dimension is NOT partitioned.
194-
template <typename ST>
194+
template <typename ST, int AVG = 0>
195195
void vert_contraction(const Field &f_out, const Field &f_in, const Field &weight) {
196196
using namespace ShortFieldTagsNames;
197197

@@ -266,7 +266,7 @@ void vert_contraction(const Field &f_out, const Field &f_in, const Field &weight
266266
"Error! Weight field must have the same data type as input field.");
267267

268268
// All good, call the implementation
269-
impl::vert_contraction<ST>(f_out, f_in, weight);
269+
impl::vert_contraction<ST, AVG>(f_out, f_in, weight);
270270
}
271271

272272
template<typename ST>

components/eamxx/src/share/field/field_utils_impl.hpp

Lines changed: 91 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -306,7 +306,7 @@ void perturb (Field& f,
306306
}
307307
}
308308

309-
template <typename ST>
309+
template <typename ST, int AVG>
310310
void horiz_contraction(const Field &f_out, const Field &f_in,
311311
const Field &weight, const ekat::Comm *comm) {
312312
using KT = ekat::KokkosTypes<DefaultDevice>;
@@ -318,34 +318,63 @@ void horiz_contraction(const Field &f_out, const Field &f_in,
318318
auto l_out = f_out.get_header().get_identifier().get_layout();
319319
auto l_in = f_in.get_header().get_identifier().get_layout();
320320

321+
auto is_masked = f_in.get_header().has_extra_data("mask_data");
322+
321323
auto v_w = weight.get_view<const ST *>();
322324

323325
const int ncols = l_in.dim(0);
324326

327+
bool is_avg = AVG; // 1 is avg; 0 is sum
328+
325329
switch(l_in.rank()) {
326330
case 1: {
327331
auto v_in = f_in.get_view<const ST *>();
332+
auto v_m = is_masked ? f_in.get_header().get_extra_data<Field>("mask_data").get_view<const ST *>() : v_in;
328333
auto v_out = f_out.get_view<ST>();
334+
ST n = 0, d = 0;
329335
Kokkos::parallel_reduce(
330336
f_out.name(), RangePolicy(0, ncols),
331-
KOKKOS_LAMBDA(const int i, ST &ls) { ls += v_w(i) * v_in(i); },
332-
v_out);
337+
KOKKOS_LAMBDA(const int i, ST &n_acc, ST &d_acc) {
338+
auto mask = is_masked ? v_m(i) : ST(1.0);
339+
n_acc += v_w(i) * v_in(i) * mask;
340+
d_acc += v_w(i) * mask;
341+
},
342+
Kokkos::Sum<ST>(n), Kokkos::Sum<ST>(d));
343+
if (is_avg) {
344+
ST tmp = d != 0 ? n / d : 0;
345+
Kokkos::deep_copy(v_out, tmp);
346+
} else {
347+
Kokkos::deep_copy(v_out, n);
348+
}
333349
} break;
334350
case 2: {
335351
auto v_in = f_in.get_view<const ST **>();
352+
auto v_m = is_masked ? f_in.get_header().get_extra_data<Field>("mask_data").get_view<const ST **>() : v_in;
336353
auto v_out = f_out.get_view<ST *>();
337354
const int d1 = l_in.dim(1);
338355
auto p = ESU::get_default_team_policy(d1, ncols);
339356
Kokkos::parallel_for(
340357
f_out.name(), p, KOKKOS_LAMBDA(const TeamMember &tm) {
341358
const int j = tm.league_rank();
359+
ST n = 0, d = 0;
342360
Kokkos::parallel_reduce(
343361
Kokkos::TeamVectorRange(tm, ncols),
344-
[&](int i, ST &ac) { ac += v_w(i) * v_in(i, j); }, v_out(j));
362+
[&](int i, ST &n_acc, ST &d_acc) {
363+
auto mask = is_masked ? v_m(i, j) : ST(1.0);
364+
n_acc += v_w(i) * v_in(i, j) * mask;
365+
d_acc += v_w(i) * mask;
366+
},
367+
Kokkos::Sum<ST>(n), Kokkos::Sum<ST>(d));
368+
if (is_avg) {
369+
v_out(j) = d != 0 ? n / d : 0;
370+
} else {
371+
v_out(j) = n;
372+
}
345373
});
346374
} break;
347375
case 3: {
348376
auto v_in = f_in.get_view<const ST ***>();
377+
auto v_m = is_masked ? f_in.get_header().get_extra_data<Field>("mask_data").get_view<const ST ***>() : v_in;
349378
auto v_out = f_out.get_view<ST **>();
350379
const int d1 = l_in.dim(1);
351380
const int d2 = l_in.dim(2);
@@ -355,10 +384,20 @@ void horiz_contraction(const Field &f_out, const Field &f_in,
355384
const int idx = tm.league_rank();
356385
const int j = idx / d2;
357386
const int k = idx % d2;
387+
ST n = 0, d = 0;
358388
Kokkos::parallel_reduce(
359389
Kokkos::TeamVectorRange(tm, ncols),
360-
[&](int i, ST &ac) { ac += v_w(i) * v_in(i, j, k); },
361-
v_out(j, k));
390+
[&](int i, ST &n_acc, ST &d_acc) {
391+
auto mask = is_masked ? v_m(i, j, k) : ST(1.0);
392+
n_acc += v_w(i) * v_in(i, j, k) * mask;
393+
d_acc += v_w(i) * mask;
394+
},
395+
Kokkos::Sum<ST>(n), Kokkos::Sum<ST>(d));
396+
if (is_avg) {
397+
v_out(j, k) = d != 0 ? n / d : 0;
398+
} else {
399+
v_out(j, k) = n;
400+
}
362401
});
363402
} break;
364403
default:
@@ -377,7 +416,7 @@ void horiz_contraction(const Field &f_out, const Field &f_in,
377416
}
378417
}
379418

380-
template <typename ST>
419+
template <typename ST, int AVG>
381420
void vert_contraction(const Field &f_out, const Field &f_in, const Field &weight) {
382421
using KT = ekat::KokkosTypes<DefaultDevice>;
383422
using RangePolicy = Kokkos::RangePolicy<Field::device_t::execution_space>;
@@ -389,6 +428,10 @@ void vert_contraction(const Field &f_out, const Field &f_in, const Field &weight
389428
auto l_in = f_in.get_header().get_identifier().get_layout();
390429
auto l_w = weight.get_header().get_identifier().get_layout();
391430

431+
auto is_masked = f_in.get_header().has_extra_data("mask_data");
432+
433+
bool is_avg = AVG; // 1 is avg, 0 is sum
434+
392435
const int nlevs = l_in.dim(l_in.rank() - 1);
393436

394437
// To avoid duplicating code for the 1d and 2d weight cases,
@@ -404,32 +447,55 @@ void vert_contraction(const Field &f_out, const Field &f_in, const Field &weight
404447

405448
switch(l_in.rank()) {
406449
case 1: {
407-
auto v_w = weight.get_view<const ST *>();
408450
auto v_in = f_in.get_view<const ST *>();
451+
auto v_m = is_masked ? f_in.get_header().get_extra_data<Field>("mask_data").get_view<const ST *>() : v_in;
409452
auto v_out = f_out.get_view<ST>();
453+
ST n = 0, d = 0;
410454
Kokkos::parallel_reduce(
411455
f_out.name(), RangePolicy(0, nlevs),
412-
KOKKOS_LAMBDA(const int i, ST &ls) { ls += v_w(i) * v_in(i); },
413-
v_out);
456+
KOKKOS_LAMBDA(const int i, ST &n_acc, ST &d_acc) {
457+
auto mask = is_masked ? v_m(i) : ST(1.0);
458+
auto w = w1d(i);
459+
n_acc += w * v_in(i) * mask;
460+
d_acc += w * mask;
461+
},
462+
Kokkos::Sum<ST>(n), Kokkos::Sum<ST>(d));
463+
if (is_avg) {
464+
ST tmp = d != 0 ? n / d : 0;
465+
Kokkos::deep_copy(v_out, tmp);
466+
} else {
467+
Kokkos::deep_copy(v_out, n);
468+
}
414469
} break;
415470
case 2: {
416471
auto v_in = f_in.get_view<const ST **>();
472+
auto v_m = is_masked ? f_in.get_header().get_extra_data<Field>("mask_data").get_view<const ST **>() : v_in;
417473
auto v_out = f_out.get_view<ST *>();
418474
const int d0 = l_in.dim(0);
419475
auto p = ESU::get_default_team_policy(d0, nlevs);
420476
Kokkos::parallel_for(
421477
f_out.name(), p, KOKKOS_LAMBDA(const TeamMember &tm) {
422478
const int i = tm.league_rank();
479+
ST n = 0, d = 0;
423480
Kokkos::parallel_reduce(
424481
Kokkos::TeamVectorRange(tm, nlevs),
425-
[&](int j, ST &ac) {
426-
ac += w_is_1d ? w1d(j) * v_in(i, j) : w2d(i, j) * v_in(i, j);
482+
[&](int j, ST &n_acc, ST &d_acc) {
483+
auto mask = is_masked ? v_m(i, j) : ST(1.0);
484+
auto w = w_is_1d ? w1d(j) : w2d(i, j);
485+
n_acc += w * v_in(i, j) * mask;
486+
d_acc += w * mask;
427487
},
428-
v_out(i));
488+
Kokkos::Sum<ST>(n), Kokkos::Sum<ST>(d));
489+
if (is_avg) {
490+
v_out(i) = d != 0 ? n / d : 0;
491+
} else {
492+
v_out(i) = n;
493+
}
429494
});
430495
} break;
431496
case 3: {
432497
auto v_in = f_in.get_view<const ST ***>();
498+
auto v_m = is_masked ? f_in.get_header().get_extra_data<Field>("mask_data").get_view<const ST ***>() : v_in;
433499
auto v_out = f_out.get_view<ST **>();
434500
const int d0 = l_in.dim(0);
435501
const int d1 = l_in.dim(1);
@@ -439,13 +505,21 @@ void vert_contraction(const Field &f_out, const Field &f_in, const Field &weight
439505
const int idx = tm.league_rank();
440506
const int i = idx / d1;
441507
const int j = idx % d1;
508+
ST n = 0, d = 0;
442509
Kokkos::parallel_reduce(
443510
Kokkos::TeamVectorRange(tm, nlevs),
444-
[&](int k, ST &ac) {
445-
ac += w_is_1d ? w1d(k) * v_in(i, j, k)
446-
: w2d(i, k) * v_in(i, j, k);
511+
[&](int k, ST &n_acc, ST &d_acc) {
512+
auto mask = is_masked ? v_m(i, j, k) : ST(1.0);
513+
auto w = w_is_1d ? w1d(k) : w2d(i, k);
514+
n_acc += w * v_in(i, j, k) * mask;
515+
d_acc += w * mask;
447516
},
448-
v_out(i, j));
517+
Kokkos::Sum<ST>(n), Kokkos::Sum<ST>(d));
518+
if (is_avg) {
519+
v_out(i, j) = d != 0 ? n / d : 0;
520+
} else {
521+
v_out(i, j) = n;
522+
}
449523
});
450524
} break;
451525
default:

components/eamxx/src/share/io/scorpio_output.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -908,6 +908,11 @@ init_diagnostics ()
908908
m_track_avg_cnt = m_track_avg_cnt || m_avg_type!=OutputAvgType::Instant;
909909
diag_avg_cnt_name = "_" + diag->name();
910910
}
911+
else if (diag_field.get_header().has_extra_data("mask_data")) {
912+
params.set<double>("mask_value", m_fill_value);
913+
m_track_avg_cnt = m_track_avg_cnt || m_avg_type!=OutputAvgType::Instant;
914+
diag_avg_cnt_name = "_" + diag_field.name();
915+
}
911916

912917
// If specified, set avg_cnt tracking for this diagnostic.
913918
if (m_track_avg_cnt) {

components/eamxx/src/share/tests/field_utils.cpp

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -212,6 +212,30 @@ TEST_CASE("utils") {
212212
Real wavg = sp(sum_n_sq(dim0)) / sp(sum_n(dim0) * sum_n(dim0));
213213
REQUIRE_THAT(v(), Catch::Matchers::WithinRel(wavg, tol));
214214

215+
// Repeat but with masked values
216+
result = fieldsc.clone();
217+
// mask out the last entry (index dim0 - 1) by setting its mask value to zero
218+
auto field00_masked = field00.clone();
219+
auto mask_of_field00 = field00_masked.clone();
220+
mask_of_field00.deep_copy(sp(1.0));
221+
mask_of_field00.sync_to_host();
222+
auto mask = mask_of_field00.get_view<Real *, Host>();
223+
mask(dim0 - 1) = sp(0.0);
224+
mask_of_field00.sync_to_dev();
225+
field00_masked.get_header().set_extra_data("mask_data", mask_of_field00);
226+
field00_masked.sync_to_dev();
227+
horiz_contraction<Real>(result, field00_masked, field00);
228+
result.sync_to_host();
229+
v = result.get_view<Real, Host>();
230+
Real wavg_sum1 = 0;
231+
Real wavg_sum2 = 0;
232+
auto wavg_v00 = field00.get_view<const Real *, Host>();
233+
for(int i = 0; i < dim0; ++i) {
234+
wavg_sum1 += mask(i) * wavg_v00(i) * wavg_v00(i);
235+
wavg_sum2 += mask(i) * wavg_v00(i);
236+
}
237+
REQUIRE_THAT(v(), Catch::Matchers::WithinRel(wavg_sum1/wavg_sum2, tol));
238+
215239
// Test higher-order cases
216240
result = field_z.clone();
217241
horiz_contraction<Real>(result, field10, field00);
@@ -351,6 +375,30 @@ TEST_CASE("utils") {
351375
Real havg = sp(sum_n_sq(dim2)) / sp(sum_n(dim2) * sum_n(dim2));
352376
REQUIRE_THAT(v(), Catch::Matchers::WithinRel(havg, tol));
353377

378+
// Repeat but with masked values
379+
result = fieldsc.clone();
380+
// mask out a single entry (index dim0 - 1) by setting its mask value to zero
381+
auto field00_masked = field00.clone();
382+
auto mask_of_field00 = field00_masked.clone();
383+
mask_of_field00.deep_copy(sp(1.0));
384+
mask_of_field00.sync_to_host();
385+
auto mask = mask_of_field00.get_view<Real *, Host>();
386+
mask(dim0 - 1) = sp(0.0);
387+
mask_of_field00.sync_to_dev();
388+
field00_masked.get_header().set_extra_data("mask_data", mask_of_field00);
389+
field00_masked.sync_to_dev();
390+
vert_contraction<Real,1>(result, field00_masked, field00);
391+
result.sync_to_host();
392+
v = result.get_view<Real, Host>();
393+
Real wavg_sum1 = sp(0.0);
394+
Real wavg_sum2 = sp(0.0);
395+
auto wavg_v00 = field00.get_view<const Real *, Host>();
396+
for(int i = 0; i < dim2; ++i) {
397+
wavg_sum1 += mask(i) * wavg_v00(i) * wavg_v00(i);
398+
wavg_sum2 += mask(i) * wavg_v00(i);
399+
}
400+
REQUIRE_THAT(v(), Catch::Matchers::WithinRel(wavg_sum1/wavg_sum2, tol));
401+
354402
// Test higher-order cases
355403
result = field_x.clone();
356404
vert_contraction<Real>(result, field10, field00);

0 commit comments

Comments
 (0)