Skip to content

Commit 516b710

Browse files
committed
EAMxx: add binary ops diags
1 parent 063b02e commit 516b710

File tree

10 files changed

+360
-9
lines changed

10 files changed

+360
-9
lines changed
Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
# Binary operations diagnostics
2+
3+
In EAMxx, we can perform binary arithmetic operations on pairs of fields
4+
to create new diagnostic outputs. The binary operations diagnostic allows
5+
you to compute element-wise arithmetic operations between two fields.
6+
7+
## Supported operations
8+
9+
The binary operations diagnostic supports four basic arithmetic operations:
10+
11+
| Operator | Symbol | Description |
12+
| -------- | ------ | ----------- |
13+
| Addition | `+` | Element-wise addition of two fields |
14+
| Subtraction | `-` | Element-wise subtraction of two fields |
15+
| Multiplication | `*` | Element-wise multiplication of two fields |
16+
| Division | `÷` | Element-wise division of two fields |
17+
18+
## Requirements
19+
20+
For two fields to be compatible for binary operations, they must satisfy:
21+
22+
1. **Same layout**
23+
2. **Same data type**
24+
3. **Same grid**
25+
4. **Compatible units** for addition and subtraction
26+
27+
## Unit handling
28+
29+
The resulting diagnostic field will have units determined by the operation.
30+
31+
## Configuration
32+
33+
To use the binary operations diagnostic, one can request an output
34+
using the general syntax of `<field_1>_<binary_op>_<field_2>`:
35+
36+
- `field_1` is the name of the first input field
37+
- `binary_op` is the operator: `+`, `-`, `*`, or `÷`
38+
- `field_2` is the name of the second input field
39+
40+
## Example
41+
42+
```yaml
43+
field_names:
44+
- T_mid_*_p_mid
45+
```
46+
47+
## Caveats
48+
49+
- Strictly speaking, multiple operations will take place in order.
50+
For example, `qc_+_qv_*_p_mid` will first compute `qc + qv`,
51+
then multiply the result by `p_mid`.
52+
- We only support existing fields in the Field Manager, e.g.,
53+
grid information like lat, lon, and area are not available.
54+
- We do not support dimensionality broadcasting, so both fields
55+
must have the same shape.
56+
57+
Contact developers on github if you have additional questions.

components/eamxx/docs/user/diags/index.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,4 +4,9 @@ EAMxx has facilities to output optional diagnostics
44
that are computed during runtime. These diagnostics
55
are designed generically and composably, and are requestable by users.
66

7+
## Available diagnostics
8+
9+
- [Field contraction](field_contraction.md)
10+
- [Binary arithmetics](binary_ops.md)
11+
712
More details to follow.

components/eamxx/mkdocs.yml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,9 @@ nav:
1717
- 'EAMxx runtime parameters': 'user/eamxx_params.md'
1818
- 'Diagnostics':
1919
- 'Overview': 'user/diags/index.md'
20-
- 'Field contraction diagnostics': 'user/diags/field_contraction.md'
20+
- 'Field contraction': 'user/diags/field_contraction.md'
2121
- 'Conditional sampling': 'user/diags/conditional_sampling.md'
22+
- 'Binary arithmetics': 'user/diags/binary_ops.md'
2223
- 'Presentations': 'user/presentations.md'
2324
- 'IO Aliases': 'user/io_aliases.md'
2425
- 'Developer Guide':

components/eamxx/src/diagnostics/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ set(DIAGNOSTIC_SRCS
33
aodvis.cpp
44
atm_backtend.cpp
55
atm_density.cpp
6+
binary_ops.cpp
67
dry_static_energy.cpp
78
exner.cpp
89
field_at_height.cpp
Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
1+
#include "diagnostics/binary_ops.hpp"
2+
3+
namespace scream {
4+
5+
// parse string to get the operator code
6+
int get_binary_operator_code(const std::string& op) {
7+
if (op == "+") return 0; // addition
8+
if (op == "-") return 1; // subtraction
9+
if (op == "*") return 2; // multiplication
10+
if (op == "÷") return 3; // division
11+
return -1; // invalid operator
12+
}
13+
// apply binary operation on two input units
14+
ekat::units::Units apply_binary_op(const ekat::units::Units& a, const ekat::units::Units& b, const int op_code) {
15+
switch (op_code) {
16+
case 0: // addition - units must be compatible
17+
case 1: // subtraction - units must be compatible
18+
EKAT_REQUIRE_MSG(a == b, "Error! Addition/subtraction requires compatible units.\n");
19+
return a;
20+
case 2: return a * b; // multiplication
21+
case 3: return a / b; // division
22+
default: return a; // no operation, just return a
23+
}
24+
}
25+
// apply binary operation on two input fields
26+
void apply_binary_op(Field& a_clone, const Field &b, const int& op_code){
27+
switch (op_code) {
28+
case 0: return a_clone.update(b, 1, 1); // addition
29+
case 1: return a_clone.update(b, -1, 1); // subtraction
30+
case 2: return a_clone.scale(b); // multiplication
31+
case 3: return a_clone.scale_inv(b); // division
32+
default: return;
33+
}
34+
}
35+
36+
BinaryOpsDiag::
37+
BinaryOpsDiag(const ekat::Comm &comm,
38+
const ekat::ParameterList &params)
39+
: AtmosphereDiagnostic(comm, params)
40+
{
41+
m_field_1 = m_params.get<std::string>("field_1");
42+
m_field_2 = m_params.get<std::string>("field_2");
43+
m_binary_op = m_params.get<std::string>("binary_op");
44+
45+
// Validate operator
46+
EKAT_REQUIRE_MSG(get_binary_operator_code(m_binary_op) >= 0,
47+
"Error! Invalid binary operator: '" + m_binary_op + "'\n"
48+
"Valid operators are: +, -, *, ÷\n");
49+
}
50+
51+
void BinaryOpsDiag::
52+
set_grids(const std::shared_ptr<const GridsManager> grids_manager)
53+
{
54+
const auto &gname = m_params.get<std::string>("grid_name");
55+
add_field<Required>(m_field_1, gname);
56+
add_field<Required>(m_field_2, gname);
57+
}
58+
59+
void BinaryOpsDiag::initialize_impl(const RunType /*run_type*/) {
60+
// get the input fields
61+
const auto &f1 = get_field_in(m_field_1);
62+
const auto &f2 = get_field_in(m_field_2);
63+
64+
const auto &l1 = f1.get_header().get_identifier().get_layout();
65+
const auto &l2 = f2.get_header().get_identifier().get_layout();
66+
// Must be on same layout, same datatype
67+
EKAT_REQUIRE_MSG(
68+
l1 == l2,
69+
"Error! BinaryOpsDiag requires both input fields to have the same layout.\n"
70+
" - field 1 name: " + f1.get_header().get_identifier().name() + "\n"
71+
" - field 1 layout: " + l1.to_string() + "\n"
72+
" - field 2 name: " + f2.get_header().get_identifier().name() + "\n"
73+
" - field 2 layout: " + l2.to_string() + "\n");
74+
EKAT_REQUIRE_MSG(
75+
f1.data_type() == f2.data_type(),
76+
"Error! BinaryOpsDiag requires both input fields to have the same data type.\n"
77+
" - field 1 name: " + f1.get_header().get_identifier().name() + "\n"
78+
" - field 1 data type: " + e2str(f1.data_type()) + "\n"
79+
" - field 2 name: " + f2.get_header().get_identifier().name() + "\n"
80+
" - field 2 data type: " + e2str(f2.data_type()) + "\n");
81+
82+
const auto &gn1 = f1.get_header().get_identifier().get_grid_name();
83+
const auto &gn2 = f2.get_header().get_identifier().get_grid_name();
84+
// Must be on same grid too
85+
EKAT_REQUIRE_MSG(
86+
gn1 == gn2,
87+
"Error! BinaryOpsDiag requires both input fields to be on the same grid.\n"
88+
" - field 1 name: " + f1.get_header().get_identifier().name() + "\n"
89+
" - field 1 grid name: " + gn1 + "\n"
90+
" - field 2 name: " + f2.get_header().get_identifier().name() + "\n"
91+
" - field 2 grid name: " + gn2 + "\n");
92+
93+
const auto &u1 = f1.get_header().get_identifier().get_units();
94+
const auto &u2 = f2.get_header().get_identifier().get_units();
95+
96+
// All good, create the diag output
97+
auto diag_units = apply_binary_op(u1, u2, get_binary_operator_code(m_binary_op));
98+
auto diag_name = m_field_1 + "_" + m_binary_op + "_" + m_field_2;
99+
FieldIdentifier d_fid(diag_name, l1.clone(), diag_units, gn1);
100+
m_diagnostic_output = Field(d_fid);
101+
m_diagnostic_output.allocate_view();
102+
}
103+
104+
void BinaryOpsDiag::compute_diagnostic_impl() {
105+
106+
const auto &f1 = get_field_in(m_field_1);
107+
const auto &f2 = get_field_in(m_field_2);
108+
m_diagnostic_output.deep_copy(f1);
109+
apply_binary_op(m_diagnostic_output, f2, get_binary_operator_code(m_binary_op));
110+
}
111+
112+
} // namespace scream
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
#ifndef EAMXX_BINARY_OPS_DIAG_HPP
2+
#define EAMXX_BINARY_OPS_DIAG_HPP
3+
4+
#include "share/atm_process/atmosphere_diagnostic.hpp"
5+
6+
namespace scream {
7+
8+
/*
9+
* This diagnostic will perform binary ops
10+
* like +, -, *, ÷ on two input fields.
11+
*/
12+
13+
class BinaryOpsDiag : public AtmosphereDiagnostic {
14+
public:
15+
// Constructors
16+
BinaryOpsDiag(const ekat::Comm &comm, const ekat::ParameterList &params);
17+
18+
// The name of the diagnostic CLASS (not the computed field)
19+
std::string name() const { return "BinaryOpsDiag"; }
20+
21+
// Set the grid
22+
void set_grids(const std::shared_ptr<const GridsManager> grids_manager);
23+
24+
protected:
25+
#ifdef KOKKOS_ENABLE_CUDA
26+
public:
27+
#endif
28+
void compute_diagnostic_impl();
29+
30+
void initialize_impl(const RunType /*run_type*/) override;
31+
32+
std::string m_field_1;
33+
std::string m_field_2;
34+
std::string m_binary_op;
35+
36+
}; // class BinaryOpsDiag
37+
38+
} // namespace scream
39+
40+
#endif // EAMXX_BINARY_OPS_DIAG_HPP

components/eamxx/src/diagnostics/register_diagnostics.hpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
#include "diagnostics/vert_contract.hpp"
2929
#include "diagnostics/zonal_avg.hpp"
3030
#include "diagnostics/conditional_sampling.hpp"
31+
#include "diagnostics/binary_ops.hpp"
3132

3233
namespace scream {
3334

@@ -59,6 +60,7 @@ inline void register_diagnostics () {
5960
diag_factory.register_product("VertContractDiag",&create_atmosphere_diagnostic<VertContractDiag>);
6061
diag_factory.register_product("ZonalAvgDiag",&create_atmosphere_diagnostic<ZonalAvgDiag>);
6162
diag_factory.register_product("ConditionalSampling",&create_atmosphere_diagnostic<ConditionalSampling>);
63+
diag_factory.register_product("BinaryOpsDiag", &create_atmosphere_diagnostic<BinaryOpsDiag>);
6264
}
6365

6466
} // namespace scream

components/eamxx/src/diagnostics/tests/CMakeLists.txt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,3 +83,6 @@ CreateDiagTest(zonal_avg zonal_avg_test.cpp MPI_RANKS 1 ${SCREAM_TEST_MAX_RANKS}
8383

8484
# Test conditional sampling
8585
CreateDiagTest(conditional_sampling "conditional_sampling_test.cpp")
86+
87+
# Test binary ops
88+
CreateDiagTest(binary_ops "binary_ops_test.cpp")
Lines changed: 119 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,119 @@
1+
#include "catch2/catch.hpp"
2+
#include "diagnostics/register_diagnostics.hpp"
3+
#include "share/field/field_utils.hpp"
4+
#include "share/grid/mesh_free_grids_manager.hpp"
5+
#include "share/util/eamxx_setup_random_test.hpp"
6+
#include "share/util/eamxx_universal_constants.hpp"
7+
8+
namespace scream {
9+
10+
std::shared_ptr<GridsManager> create_gm(const ekat::Comm &comm, const int ncols,
11+
const int nlevs) {
12+
const int num_global_cols = ncols * comm.size();
13+
14+
using vos_t = std::vector<std::string>;
15+
ekat::ParameterList gm_params;
16+
gm_params.set("grids_names", vos_t{"point_grid"});
17+
auto &pl = gm_params.sublist("point_grid");
18+
pl.set<std::string>("type", "point_grid");
19+
pl.set("aliases", vos_t{"physics"});
20+
pl.set<int>("number_of_global_columns", num_global_cols);
21+
pl.set<int>("number_of_vertical_levels", nlevs);
22+
23+
auto gm = create_mesh_free_grids_manager(comm, gm_params);
24+
gm->build_grids();
25+
26+
return gm;
27+
}
28+
29+
TEST_CASE("binary_ops") {
30+
using namespace ShortFieldTagsNames;
31+
using namespace ekat::units;
32+
33+
// A world comm
34+
ekat::Comm comm(MPI_COMM_WORLD);
35+
36+
// A time stamp
37+
util::TimeStamp t0({2024, 1, 1}, {0, 0, 0});
38+
39+
// Create a grids manager - single column for these tests
40+
constexpr int nlevs = 201;
41+
const int ngcols = 260 * comm.size();
42+
43+
auto gm = create_gm(comm, ngcols, nlevs);
44+
auto grid = gm->get_grid("physics");
45+
46+
// Input (randomized) qc, qv
47+
FieldLayout scalar2d_layout{{COL, LEV}, {ngcols, nlevs}};
48+
FieldIdentifier qc_fid("qc", scalar2d_layout, kg / kg, grid->name());
49+
FieldIdentifier qv_fid("qv", scalar2d_layout, kg / kg, grid->name());
50+
51+
Field qc(qc_fid);
52+
Field qv(qv_fid);
53+
qc.allocate_view();
54+
qv.allocate_view();
55+
56+
// Construct random number generator stuff
57+
using RPDF = std::uniform_real_distribution<Real>;
58+
RPDF pdf(0.0, 200.0);
59+
60+
auto engine = scream::setup_random_test();
61+
62+
// Construct the Diagnostics
63+
std::map<std::string, std::shared_ptr<AtmosphereDiagnostic>> diags;
64+
auto &diag_factory = AtmosphereDiagnosticFactory::instance();
65+
register_diagnostics();
66+
67+
ekat::ParameterList params;
68+
REQUIRE_THROWS(diag_factory.create("BinaryOpsDiag", comm,
69+
params)); // No 'field_1', 'field_2', or 'binary_op'
70+
71+
// Set time for qc and randomize its values
72+
qc.get_header().get_tracking().update_time_stamp(t0);
73+
qv.get_header().get_tracking().update_time_stamp(t0);
74+
randomize(qc, engine, pdf); qc.sync_to_dev();
75+
randomize(qv, engine, pdf); qv.sync_to_dev();
76+
77+
// Create and set up the diagnostic
78+
params.set("grid_name", grid->name());
79+
params.set<std::string>("field_1", "qc");
80+
params.set<std::string>("field_2", "qv");
81+
params.set<std::string>("binary_op", "+");
82+
auto plus_diag = diag_factory.create("BinaryOpsDiag", comm, params);
83+
params.set<std::string>("binary_op", "*");
84+
auto prod_diag = diag_factory.create("BinaryOpsDiag", comm, params);
85+
plus_diag->set_grids(gm);
86+
prod_diag->set_grids(gm);
87+
plus_diag->set_required_field(qc);
88+
prod_diag->set_required_field(qc);
89+
plus_diag->set_required_field(qv);
90+
prod_diag->set_required_field(qv);
91+
plus_diag->initialize(t0, RunType::Initial);
92+
prod_diag->initialize(t0, RunType::Initial);
93+
94+
// Run diag
95+
plus_diag->compute_diagnostic();
96+
auto plus_diag_f = plus_diag->get_diagnostic(); plus_diag_f.sync_to_host();
97+
prod_diag->compute_diagnostic();
98+
auto prod_diag_f = prod_diag->get_diagnostic(); prod_diag_f.sync_to_host();
99+
100+
// Check that the output fields have the right values
101+
const auto &plus_v = plus_diag_f.get_view<Real**, Host>();
102+
const auto &prod_v = prod_diag_f.get_view<Real**, Host>();
103+
const auto &qc_v = qc.get_view<Real**, Host>();
104+
const auto &qv_v = qv.get_view<Real**, Host>();
105+
for (int icol = 0; icol < ngcols; ++icol) {
106+
for (int ilev = 0; ilev < nlevs; ++ilev) {
107+
// Check plus
108+
REQUIRE(plus_v(icol, ilev) == qc_v(icol, ilev) + qv_v(icol, ilev));
109+
// Check product
110+
REQUIRE(prod_v(icol, ilev) == qc_v(icol, ilev) * qv_v(icol, ilev));
111+
}
112+
}
113+
114+
// redundant, why not
115+
qc.update(qv, 1, 1);
116+
views_are_equal(qc, plus_diag_f);
117+
}
118+
119+
} // namespace scream

0 commit comments

Comments
 (0)