Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 67 additions & 0 deletions components/eamxx/docs/user/diags/binary_ops.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
# Binary operations diagnostics

In EAMxx, we can perform binary arithmetic operations on pairs of fields
to create new diagnostic outputs. The binary operations diagnostic allows
you to compute element-wise arithmetic operations between two fields.

## Supported operations

The binary operations diagnostic supports four basic arithmetic operations:

| Operator | Symbol | Description |
| -------- | ------ | ----------- |
| Addition | `plus` | Element-wise addition of two fields |
| Subtraction | `minus` | Element-wise subtraction of two fields |
| Multiplication | `times` | Element-wise multiplication of two fields |
| Division | `over` | Element-wise division of two fields |

## Requirements

For two fields to be compatible for binary operations, they must satisfy:

1. **Same layout**
2. **Same data type**
3. **Same grid**
4. **Compatible units** for addition and subtraction

## Unit handling

The resulting diagnostic field will have units determined by the operation.

## Configuration
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Now that it is available, should we remind the user about aliasing, and "encourage" to use it for binary-op diags?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

no, I would keep the docs for each feature separate. I will write an overarching docs page for tips to use them together, but I think each page should be isolated to what it is offering as a standalone feature


To use the binary operations diagnostic, one can request an output
using the general syntax of `<field_1>_<binary_op>_<field_2>`:

- `field_1` is the name of the first input field
- `binary_op` is the operator: `plus`, `minus`, `times`, or `over`
- `field_2` is the name of the second input field

## Example

```yaml
field_names:
# T_mid has units K
- T_mid_times_p_mid # K*Pa
- T_mid_over_p_mid # K/Pa
- T_mid_plus_T_mid # K
- T_mid_minus_T_mid # K
```

## Caveats

- As the name suggests, these diagnostics were written for a single
operation connecting two fields; doing multiple binary ops in
succession carries risks; read on.
- Strictly speaking, multiple operations will take place in order.
Copy link
Contributor

@AaronDonahue AaronDonahue Aug 11, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we note that we don't support parentheses? So no _(_qc_+_qv_)_*_p_mid

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Second question, do we have a test for this case? I may have missed it, but I think the test just does binary ops for 2 vars. It seems like 3 vars is a good edge case, and sense you mention it as an example it would be good to test it.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, both are good suggestions. I can include them or in a follow-up PR (that will combine docs for all features).

I will also need to clarify very clearly the order of ops by giving a few examples...

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@AaronDonahue, I will add the composability test later. The composability comes from the parsing (in eamxx IO utils), not the diag impl itself. This PR is mostly about the diag impl. I added some notes, but I will strengthen the testing later.

For example, `qc_plus_qv_times_p_mid` will first compute `qc + qv`,
then multiply the result by `p_mid`. (Beware, we do NOT support
mathematical operation precedence; it is simply about the parser
which processes the expression left-to-right.)
- In fact, `p_mid_times_qc_plus_qv` will fail because of units!
- We only support existing fields in the Field Manager, e.g.,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe

We only support existing fields in the Field Manager (and other diagnostics). e.g. ...

grid information like lat, lon, and area are not available.
- We do not support dimensionality broadcasting, so both fields
must have the same shape.

Contact developers on github if you have additional questions.
5 changes: 5 additions & 0 deletions components/eamxx/docs/user/diags/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,9 @@ EAMxx has facilities to output optional diagnostics
that are computed during runtime. These diagnostics
are designed generically and composably, and are requestable by users.

## Available diagnostics

- [Field contraction](field_contraction.md)
- [Binary arithmetics](binary_ops.md)

More details to follow.
3 changes: 2 additions & 1 deletion components/eamxx/mkdocs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,9 @@ nav:
- 'EAMxx runtime parameters': 'user/eamxx_params.md'
- 'Diagnostics':
- 'Overview': 'user/diags/index.md'
- 'Field contraction diagnostics': 'user/diags/field_contraction.md'
- 'Field contraction': 'user/diags/field_contraction.md'
- 'Conditional sampling': 'user/diags/conditional_sampling.md'
- 'Binary arithmetics': 'user/diags/binary_ops.md'
- 'Presentations': 'user/presentations.md'
- 'IO Aliases': 'user/io_aliases.md'
- 'Developer Guide':
Expand Down
1 change: 1 addition & 0 deletions components/eamxx/src/diagnostics/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ set(DIAGNOSTIC_SRCS
aodvis.cpp
atm_backtend.cpp
atm_density.cpp
binary_ops.cpp
dry_static_energy.cpp
exner.cpp
field_at_height.cpp
Expand Down
112 changes: 112 additions & 0 deletions components/eamxx/src/diagnostics/binary_ops.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
#include "diagnostics/binary_ops.hpp"

namespace scream {

// parse string to get the operator code
int get_binary_operator_code(const std::string& op) {
if (op == "plus") return 0; // addition
if (op == "minus") return 1; // subtraction
if (op == "times") return 2; // multiplication
if (op == "over") return 3; // division
return -1; // invalid operator
}
// apply binary operation on two input units
ekat::units::Units apply_binary_op(const ekat::units::Units& a, const ekat::units::Units& b, const int op_code) {
switch (op_code) {
case 0: // addition - units must be compatible
case 1: // subtraction - units must be compatible
EKAT_REQUIRE_MSG(a == b, "Error! Addition/subtraction requires compatible units.\n");
return a;
case 2: return a * b; // multiplication
case 3: return a / b; // division
default: return a; // no operation, just return a
}
}
// apply binary operation on two input fields
void apply_binary_op(Field& a_clone, const Field &b, const int& op_code){
switch (op_code) {
case 0: return a_clone.update(b, 1, 1); // addition
case 1: return a_clone.update(b, -1, 1); // subtraction
case 2: return a_clone.scale(b); // multiplication
case 3: return a_clone.scale_inv(b); // division
default: return;
}
}

BinaryOpsDiag::
BinaryOpsDiag(const ekat::Comm &comm,
const ekat::ParameterList &params)
: AtmosphereDiagnostic(comm, params)
{
m_field_1 = m_params.get<std::string>("field_1");
m_field_2 = m_params.get<std::string>("field_2");
m_binary_op = m_params.get<std::string>("binary_op");

// Validate operator
EKAT_REQUIRE_MSG(get_binary_operator_code(m_binary_op) >= 0,
"Error! Invalid binary operator: '" + m_binary_op + "'\n"
"Valid operators are: plus, minus, times, over\n");
}

void BinaryOpsDiag::
set_grids(const std::shared_ptr<const GridsManager> grids_manager)
{
const auto &gname = m_params.get<std::string>("grid_name");
add_field<Required>(m_field_1, gname);
add_field<Required>(m_field_2, gname);
}

void BinaryOpsDiag::initialize_impl(const RunType /*run_type*/) {
// get the input fields
const auto &f1 = get_field_in(m_field_1);
const auto &f2 = get_field_in(m_field_2);

const auto &l1 = f1.get_header().get_identifier().get_layout();
const auto &l2 = f2.get_header().get_identifier().get_layout();
// Must be on same layout, same datatype
EKAT_REQUIRE_MSG(
l1 == l2,
"Error! BinaryOpsDiag requires both input fields to have the same layout.\n"
" - field 1 name: " + f1.get_header().get_identifier().name() + "\n"
" - field 1 layout: " + l1.to_string() + "\n"
" - field 2 name: " + f2.get_header().get_identifier().name() + "\n"
" - field 2 layout: " + l2.to_string() + "\n");
EKAT_REQUIRE_MSG(
f1.data_type() == f2.data_type(),
"Error! BinaryOpsDiag requires both input fields to have the same data type.\n"
" - field 1 name: " + f1.get_header().get_identifier().name() + "\n"
" - field 1 data type: " + e2str(f1.data_type()) + "\n"
" - field 2 name: " + f2.get_header().get_identifier().name() + "\n"
" - field 2 data type: " + e2str(f2.data_type()) + "\n");

const auto &gn1 = f1.get_header().get_identifier().get_grid_name();
const auto &gn2 = f2.get_header().get_identifier().get_grid_name();
// Must be on same grid too
EKAT_REQUIRE_MSG(
gn1 == gn2,
"Error! BinaryOpsDiag requires both input fields to be on the same grid.\n"
" - field 1 name: " + f1.get_header().get_identifier().name() + "\n"
" - field 1 grid name: " + gn1 + "\n"
" - field 2 name: " + f2.get_header().get_identifier().name() + "\n"
" - field 2 grid name: " + gn2 + "\n");

const auto &u1 = f1.get_header().get_identifier().get_units();
const auto &u2 = f2.get_header().get_identifier().get_units();

// All good, create the diag output
auto diag_units = apply_binary_op(u1, u2, get_binary_operator_code(m_binary_op));
auto diag_name = m_field_1 + "_" + m_binary_op + "_" + m_field_2;
FieldIdentifier d_fid(diag_name, l1.clone(), diag_units, gn1);
m_diagnostic_output = Field(d_fid);
m_diagnostic_output.allocate_view();
}

void BinaryOpsDiag::compute_diagnostic_impl() {

const auto &f1 = get_field_in(m_field_1);
const auto &f2 = get_field_in(m_field_2);
m_diagnostic_output.deep_copy(f1);
apply_binary_op(m_diagnostic_output, f2, get_binary_operator_code(m_binary_op));
}

} // namespace scream
40 changes: 40 additions & 0 deletions components/eamxx/src/diagnostics/binary_ops.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
#ifndef EAMXX_BINARY_OPS_DIAG_HPP
#define EAMXX_BINARY_OPS_DIAG_HPP

#include "share/atm_process/atmosphere_diagnostic.hpp"

namespace scream {

/*
* This diagnostic will perform binary ops
* like +, -, *, ÷ on two input fields.
*/

class BinaryOpsDiag : public AtmosphereDiagnostic {
public:
// Constructors
BinaryOpsDiag(const ekat::Comm &comm, const ekat::ParameterList &params);

// The name of the diagnostic CLASS (not the computed field)
std::string name() const { return "BinaryOpsDiag"; }

// Set the grid
void set_grids(const std::shared_ptr<const GridsManager> grids_manager);

protected:
#ifdef KOKKOS_ENABLE_CUDA
public:
#endif
void compute_diagnostic_impl();

void initialize_impl(const RunType /*run_type*/) override;

std::string m_field_1;
std::string m_field_2;
std::string m_binary_op;

}; // class BinaryOpsDiag

} // namespace scream

#endif // EAMXX_BINARY_OPS_DIAG_HPP
2 changes: 2 additions & 0 deletions components/eamxx/src/diagnostics/register_diagnostics.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
#include "diagnostics/vert_contract.hpp"
#include "diagnostics/zonal_avg.hpp"
#include "diagnostics/conditional_sampling.hpp"
#include "diagnostics/binary_ops.hpp"

namespace scream {

Expand Down Expand Up @@ -59,6 +60,7 @@ inline void register_diagnostics () {
diag_factory.register_product("VertContractDiag",&create_atmosphere_diagnostic<VertContractDiag>);
diag_factory.register_product("ZonalAvgDiag",&create_atmosphere_diagnostic<ZonalAvgDiag>);
diag_factory.register_product("ConditionalSampling",&create_atmosphere_diagnostic<ConditionalSampling>);
diag_factory.register_product("BinaryOpsDiag", &create_atmosphere_diagnostic<BinaryOpsDiag>);
}

} // namespace scream
Expand Down
3 changes: 3 additions & 0 deletions components/eamxx/src/diagnostics/tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -83,3 +83,6 @@ CreateDiagTest(zonal_avg zonal_avg_test.cpp MPI_RANKS 1 ${SCREAM_TEST_MAX_RANKS}

# Test conditional sampling
CreateDiagTest(conditional_sampling "conditional_sampling_test.cpp")

# Test binary ops
CreateDiagTest(binary_ops "binary_ops_test.cpp")
119 changes: 119 additions & 0 deletions components/eamxx/src/diagnostics/tests/binary_ops_test.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
#include "catch2/catch.hpp"
#include "diagnostics/register_diagnostics.hpp"
#include "share/field/field_utils.hpp"
#include "share/grid/mesh_free_grids_manager.hpp"
#include "share/util/eamxx_setup_random_test.hpp"
#include "share/util/eamxx_universal_constants.hpp"

namespace scream {

std::shared_ptr<GridsManager> create_gm(const ekat::Comm &comm, const int ncols,
const int nlevs) {
const int num_global_cols = ncols * comm.size();

using vos_t = std::vector<std::string>;
ekat::ParameterList gm_params;
gm_params.set("grids_names", vos_t{"point_grid"});
auto &pl = gm_params.sublist("point_grid");
pl.set<std::string>("type", "point_grid");
pl.set("aliases", vos_t{"physics"});
pl.set<int>("number_of_global_columns", num_global_cols);
pl.set<int>("number_of_vertical_levels", nlevs);

auto gm = create_mesh_free_grids_manager(comm, gm_params);
gm->build_grids();

return gm;
}

TEST_CASE("binary_ops") {
using namespace ShortFieldTagsNames;
using namespace ekat::units;

// A world comm
ekat::Comm comm(MPI_COMM_WORLD);

// A time stamp
util::TimeStamp t0({2024, 1, 1}, {0, 0, 0});

// Create a grids manager - single column for these tests
constexpr int nlevs = 201;
const int ngcols = 260 * comm.size();

auto gm = create_gm(comm, ngcols, nlevs);
auto grid = gm->get_grid("physics");

// Input (randomized) qc, qv
FieldLayout scalar2d_layout{{COL, LEV}, {ngcols, nlevs}};
FieldIdentifier qc_fid("qc", scalar2d_layout, kg / kg, grid->name());
FieldIdentifier qv_fid("qv", scalar2d_layout, kg / kg, grid->name());

Field qc(qc_fid);
Field qv(qv_fid);
qc.allocate_view();
qv.allocate_view();

// Construct random number generator stuff
using RPDF = std::uniform_real_distribution<Real>;
RPDF pdf(0.0, 200.0);

auto engine = scream::setup_random_test();

// Construct the Diagnostics
std::map<std::string, std::shared_ptr<AtmosphereDiagnostic>> diags;
auto &diag_factory = AtmosphereDiagnosticFactory::instance();
register_diagnostics();

ekat::ParameterList params;
REQUIRE_THROWS(diag_factory.create("BinaryOpsDiag", comm,
params)); // No 'field_1', 'field_2', or 'binary_op'

// Set time for qc and randomize its values
qc.get_header().get_tracking().update_time_stamp(t0);
qv.get_header().get_tracking().update_time_stamp(t0);
randomize(qc, engine, pdf); qc.sync_to_dev();
randomize(qv, engine, pdf); qv.sync_to_dev();

// Create and set up the diagnostic
params.set("grid_name", grid->name());
params.set<std::string>("field_1", "qc");
params.set<std::string>("field_2", "qv");
params.set<std::string>("binary_op", "plus");
auto plus_diag = diag_factory.create("BinaryOpsDiag", comm, params);
params.set<std::string>("binary_op", "times");
auto prod_diag = diag_factory.create("BinaryOpsDiag", comm, params);
plus_diag->set_grids(gm);
prod_diag->set_grids(gm);
plus_diag->set_required_field(qc);
prod_diag->set_required_field(qc);
plus_diag->set_required_field(qv);
prod_diag->set_required_field(qv);
plus_diag->initialize(t0, RunType::Initial);
prod_diag->initialize(t0, RunType::Initial);

// Run diag
plus_diag->compute_diagnostic();
auto plus_diag_f = plus_diag->get_diagnostic(); plus_diag_f.sync_to_host();
prod_diag->compute_diagnostic();
auto prod_diag_f = prod_diag->get_diagnostic(); prod_diag_f.sync_to_host();

// Check that the output fields have the right values
const auto &plus_v = plus_diag_f.get_view<Real**, Host>();
const auto &prod_v = prod_diag_f.get_view<Real**, Host>();
const auto &qc_v = qc.get_view<Real**, Host>();
const auto &qv_v = qv.get_view<Real**, Host>();
for (int icol = 0; icol < ngcols; ++icol) {
for (int ilev = 0; ilev < nlevs; ++ilev) {
// Check plus
REQUIRE(plus_v(icol, ilev) == qc_v(icol, ilev) + qv_v(icol, ilev));
// Check product
REQUIRE(prod_v(icol, ilev) == qc_v(icol, ilev) * qv_v(icol, ilev));
}
}

// redundant, why not
qc.update(qv, 1, 1);
views_are_equal(qc, plus_diag_f);
}

} // namespace scream
Loading