Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 68 additions & 0 deletions components/eamxx/src/share/field/field_utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,74 @@ void perturb (const Field& f,
impl::perturb<ST>(f, engine, pdf, base_seed, level_mask, dof_gids);
}

// Utility to compute the contraction of a field along its column dimension.
// This is equivalent to f_out = einsum('i,i...k->...k', weight, f_in).
// The impl is such that:
// - f_out, f_in, and weight must be provided and allocated
// - The first dimension is for the columns (COL)
// - There can be only up to 3 dimensions of f_in
template <typename ST>
void horiz_contraction(const Field &f_out, const Field &f_in,
const Field &weight, const ekat::Comm *comm = nullptr) {
using namespace ShortFieldTagsNames;

const auto &l_out = f_out.get_header().get_identifier().get_layout();

const auto &l_in = f_in.get_header().get_identifier().get_layout();

const auto &l_w = weight.get_header().get_identifier().get_layout();

// Sanity checks before handing off to the implementation
EKAT_REQUIRE_MSG(l_w.rank() == 1,
"Error! The weight field must be rank-1.\n"
"The input has rank "
<< l_w.rank() << ".\n");
EKAT_REQUIRE_MSG(l_w.tags() == std::vector<FieldTag>({COL}),
"Error! The weight field must have a column dimension.\n"
"The input f1 layout is "
<< l_w.tags() << ".\n");
EKAT_REQUIRE_MSG(l_in.rank() <= 3,
"Error! The input field must be at most rank-3.\n"
"The input f_in rank is "
<< l_in.rank() << ".\n");
EKAT_REQUIRE_MSG(l_in.tags()[0] == COL,
"Error! The input field must have a column dimension.\n"
"The input f_in layout is "
<< l_in.to_string() << ".\n");
EKAT_REQUIRE_MSG(
l_w.dim(0) == l_in.dim(0),
"Error! input and weight fields must have the same dimension along "
"which we are taking the reducing the field.\n"
"The weight field has dimension "
<< l_w.dim(0)
<< " while "
"the input field has dimension "
<< l_in.dim(0) << ".\n");
EKAT_REQUIRE_MSG(
l_in.dim(0) > 0,
"Error! The input field must have a non-zero column dimension.\n"
"The input f_in layout is "
<< l_in.to_string() << ".\n");
EKAT_REQUIRE_MSG(
l_out == l_in.clone().strip_dim(0),
"Error! The output field must have the same layout as the input field "
"without the column dimension.\n"
"The input f_in layout is "
<< l_in.to_string() << " and the output f_out layout is "
<< l_out.to_string() << ".\n");
EKAT_REQUIRE_MSG(
f_out.is_allocated() && f_in.is_allocated() && weight.is_allocated(),
"Error! All fields must be allocated.");
EKAT_REQUIRE_MSG(f_out.data_type() == f_in.data_type(),
"Error! In/out Fields have matching data types.");
EKAT_REQUIRE_MSG(
f_out.data_type() == weight.data_type(),
"Error! Weight field must have the same data type as input fields.");

// All good, call the implementation
impl::horiz_contraction<ST>(f_out, f_in, weight, comm);
}

template<typename ST>
ST frobenius_norm(const Field& f, const ekat::Comm* comm = nullptr)
{
Expand Down
73 changes: 73 additions & 0 deletions components/eamxx/src/share/field/field_utils_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@

#include "ekat/mpi/ekat_comm.hpp"

#include "ekat/kokkos/ekat_kokkos_utils.hpp"

#include <limits>
#include <type_traits>

Expand Down Expand Up @@ -293,6 +295,77 @@ void perturb (const Field& f,
}
}

template <typename ST>
void horiz_contraction(const Field &f_out, const Field &f_in,
const Field &weight, const ekat::Comm *comm) {
using KT = ekat::KokkosTypes<DefaultDevice>;
using RangePolicy = Kokkos::RangePolicy<Field::device_t::execution_space>;
using TeamPolicy = Kokkos::TeamPolicy<Field::device_t::execution_space>;
using TeamMember = typename TeamPolicy::member_type;
using ESU = ekat::ExeSpaceUtils<typename KT::ExeSpace>;

auto l_out = f_out.get_header().get_identifier().get_layout();
auto l_in = f_in.get_header().get_identifier().get_layout();

auto v_w = weight.get_view<const ST *>();

const int ncols = l_in.dim(0);

switch(l_in.rank()) {
case 1: {
auto v_in = f_in.get_view<const ST *>();
auto v_out = f_out.get_view<ST>();
Kokkos::parallel_reduce(
f_out.name(), RangePolicy(0, ncols),
KOKKOS_LAMBDA(const int i, ST &ls) { ls += v_w(i) * v_in(i); },
v_out);
} break;
case 2: {
auto v_in = f_in.get_view<const ST **>();
auto v_out = f_out.get_view<ST *>();
const int d1 = l_in.dim(1);
auto p = ESU::get_default_team_policy(d1, ncols);
Kokkos::parallel_for(
f_out.name(), p, KOKKOS_LAMBDA(const TeamMember &tm) {
const int j = tm.league_rank();
Kokkos::parallel_reduce(
Kokkos::TeamVectorRange(tm, ncols),
[&](int i, ST &ac) { ac += v_w(i) * v_in(i, j); }, v_out(j));
});
} break;
case 3: {
auto v_in = f_in.get_view<const ST ***>();
auto v_out = f_out.get_view<ST **>();
const int d1 = l_in.dim(1);
const int d2 = l_in.dim(2);
auto p = ESU::get_default_team_policy(d1 * d2, ncols);
Kokkos::parallel_for(
f_out.name(), p, KOKKOS_LAMBDA(const TeamMember &tm) {
const int idx = tm.league_rank();
const int j = idx / d2;
const int k = idx % d2;
Kokkos::parallel_reduce(
Kokkos::TeamVectorRange(tm, ncols),
[&](int i, ST &ac) { ac += v_w(i) * v_in(i, j, k); },
v_out(j, k));
});
} break;
default:
EKAT_ERROR_MSG("Error! Unsupported field rank.\n");
}

if(comm) {
// TODO: use device-side MPI calls
// TODO: the dev ptr causes problems; revisit this later
// TODO: doing cuda-aware MPI allreduce would be ~10% faster
Kokkos::fence();
f_out.sync_to_host();
comm->all_reduce(f_out.template get_internal_view_data<ST, Host>(),
l_out.size(), MPI_SUM);
f_out.sync_to_dev();
}
}

template<typename ST>
ST frobenius_norm(const Field& f, const ekat::Comm* comm)
{
Expand Down
110 changes: 110 additions & 0 deletions components/eamxx/src/share/tests/field_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,116 @@ TEST_CASE("utils") {
REQUIRE(field_sum<Real>(f1,&comm)==gsum);
}

SECTION("horiz_contraction") {
using RPDF = std::uniform_real_distribution<Real>;
auto engine = setup_random_test();
RPDF pdf(0, 1);

int dim0 = 3;
int dim1 = 9;
int dim2 = 2;

// Set a weight field
FieldIdentifier f00("f", {{COL}, {dim0}}, m / s, "g");
Field field00(f00);
field00.allocate_view();
field00.sync_to_host();
auto v00 = field00.get_strided_view<Real *, Host>();
for(int i = 0; i < dim0; ++i) {
v00(i) = (i + 1) / sp(6);
}
field00.sync_to_dev();

// Create (random) sample fields
FieldIdentifier fsc("f", {{}, {}}, m / s, "g"); // scalar
FieldIdentifier f10("f", {{COL, CMP}, {dim0, dim1}}, m / s, "g");
FieldIdentifier f11("f", {{COL, LEV}, {dim0, dim2}}, m / s, "g");
FieldIdentifier f20("f", {{COL, CMP, LEV}, {dim0, dim1, dim2}}, m / s, "g");
Field fieldsc(fsc);
Field field10(f10);
Field field11(f11);
Field field20(f20);
fieldsc.allocate_view();
field10.allocate_view();
field11.allocate_view();
field20.allocate_view();
randomize(fieldsc, engine, pdf);
randomize(field10, engine, pdf);
randomize(field11, engine, pdf);
randomize(field20, engine, pdf);

FieldIdentifier F_x("fx", {{COL}, {dim0}}, m / s, "g");
FieldIdentifier F_y("fy", {{LEV}, {dim2}}, m / s, "g");
FieldIdentifier F_z("fz", {{CMP}, {dim1}}, m / s, "g");
FieldIdentifier F_w("fyz", {{CMP, LEV}, {dim1, dim2}}, m / s, "g");

Field field_x(F_x);
Field field_y(F_y);
Field field_z(F_z);
Field field_w(F_w);

// Test invalid inputs
REQUIRE_THROWS(horiz_contraction<Real>(fieldsc, field_x,
field00)); // x not allocated yet

field_x.allocate_view();
field_y.allocate_view();
field_z.allocate_view();
field_w.allocate_view();

REQUIRE_THROWS(horiz_contraction<Real>(fieldsc, field_y,
field_x)); // unmatching layout
REQUIRE_THROWS(horiz_contraction<Real>(field_z, field11,
field11)); // wrong weight layout

Field result;

// Ensure a scalar case works
result = fieldsc.clone();
horiz_contraction<Real>(result, field00, field00);
result.sync_to_host();
auto v = result.get_view<Real, Host>();
REQUIRE(v() == (1 / sp(36) + 4 / sp(36) + 9 / sp(36)));

// Test higher-order cases
result = field_z.clone();
horiz_contraction<Real>(result, field10, field00);
REQUIRE(result.get_header().get_identifier().get_layout().tags() ==
std::vector<FieldTag>({CMP}));
REQUIRE(result.get_header().get_identifier().get_layout().dim(0) == dim1);

result = field_y.clone();
horiz_contraction<Real>(result, field11, field00);
REQUIRE(result.get_header().get_identifier().get_layout().tags() ==
std::vector<FieldTag>({LEV}));
REQUIRE(result.get_header().get_identifier().get_layout().dim(0) == dim2);

result = field_w.clone();
horiz_contraction<Real>(result, field20, field00);
REQUIRE(result.get_header().get_identifier().get_layout().tags() ==
std::vector<FieldTag>({CMP, LEV}));
REQUIRE(result.get_header().get_identifier().get_layout().dim(0) == dim1);
REQUIRE(result.get_header().get_identifier().get_layout().dim(1) == dim2);

// Check a 3D case
field20.sync_to_host();
auto manual_result = result.clone();
manual_result.deep_copy(0);
manual_result.sync_to_host();
auto v2 = field20.get_strided_view<Real ***, Host>();
auto mr = manual_result.get_strided_view<Real **, Host>();
for(int i = 0; i < dim0; ++i) {
for(int j = 0; j < dim1; ++j) {
for(int k = 0; k < dim2; ++k) {
mr(j, k) += v00(i) * v2(i, j, k);
}
}
}
field20.sync_to_dev();
manual_result.sync_to_dev();
REQUIRE(views_are_equal(result, manual_result));
}

SECTION ("frobenius") {

auto v1 = f1.get_strided_view<Real**>();
Expand Down
Loading