EAMxx: add binary ops diags #7573
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,67 @@ | ||
# Binary operations diagnostics | ||
|
||
In EAMxx, the binary operations diagnostic computes element-wise arithmetic | ||
between a pair of fields and exposes the result as a new diagnostic output. | ||
|
||
## Supported operations | ||
|
||
The binary operations diagnostic supports four basic arithmetic operations: | ||
|
||
| Operator | Symbol | Description | | ||
| -------- | ------ | ----------- | | ||
| Addition | `plus` | Element-wise addition of two fields | | ||
| Subtraction | `minus` | Element-wise subtraction of two fields | | ||
| Multiplication | `times` | Element-wise multiplication of two fields | | ||
| Division | `over` | Element-wise division of two fields | | ||
|
||
## Requirements | ||
|
||
For two fields to be compatible for binary operations, they must satisfy: | ||
|
||
1. **Same layout** | ||
2. **Same data type** | ||
3. **Same grid** | ||
4. **Identical units** (required for `plus` and `minus` only) | ||
|
||
## Unit handling | ||
|
||
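The units of the diagnostic field follow from the operation: `plus` and | ||
`minus` keep the units shared by both inputs, `times` yields the product | ||
of the input units, and `over` yields their quotient. | ||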
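The unit rules can be sketched in a few lines of standalone C++. This is only an illustration under stated assumptions: `Dims` and `combine_units` are hypothetical names, and units are modeled as integer exponents of (kg, m, s, K) rather than with the `ekat::units::Units` type used in the PR.

```cpp
#include <array>
#include <cstddef>
#include <stdexcept>
#include <string>

// Hypothetical stand-in for a units type (NOT ekat::units::Units):
// integer exponents of the base dimensions (kg, m, s, K).
using Dims = std::array<int, 4>;

// Units of "<a> <op> <b>" for op in {plus, minus, times, over}.
Dims combine_units(const Dims& a, const Dims& b, const std::string& op) {
  if (op == "plus" || op == "minus") {
    // Sums and differences only make sense for matching units.
    if (a != b) throw std::invalid_argument("plus/minus require identical units");
    return a;
  }
  if (op != "times" && op != "over") {
    throw std::invalid_argument("unknown operator: " + op);
  }
  // Multiplication adds exponents; division subtracts them.
  const int sign = (op == "times") ? 1 : -1;
  Dims out = a;
  for (std::size_t i = 0; i < out.size(); ++i) out[i] += sign * b[i];
  return out;
}
```

With `K = {0, 0, 0, 1}` and `Pa = {1, -1, -2, 0}`, `combine_units(K, Pa, "over")` yields the exponents of K/Pa, while `combine_units(K, Pa, "plus")` throws.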
|
||
## Configuration | ||
|
||
To use the binary operations diagnostic, one can request an output | ||
using the general syntax of `<field_1>_<binary_op>_<field_2>`: | ||
|
||
- `field_1` is the name of the first input field | ||
- `binary_op` is the operator: `plus`, `minus`, `times`, or `over` | ||
- `field_2` is the name of the second input field | ||
|
||
## Example | ||
|
||
```yaml | ||
field_names: | ||
# T_mid has units K | ||
- T_mid_times_p_mid # K*Pa | ||
- T_mid_over_p_mid # K/Pa | ||
- T_mid_plus_T_mid # K | ||
- T_mid_minus_T_mid # K | ||
``` | ||
|
||
## Caveats | ||
|
||
- As the name suggests, these diagnostics were written for a single | ||
operation connecting two fields; doing multiple binary ops in | ||
succession carries risks; read on. | ||
- Strictly speaking, multiple operations will take place in order. | ||
Should we note that we don't support parentheses? So no
Second question, do we have a test for this case? I may have missed it, but I think the test just does binary ops for 2 vars. It seems like 3 vars is a good edge case, and since you mention it as an example it would be good to test it.
Yes, both are good suggestions. I can include them here or in a follow-up PR (that will combine docs for all features). I will also need to clarify the order of ops by giving a few examples...
@AaronDonahue, I will add the composability test later. The composability comes from the parsing (in eamxx IO utils), not the diag impl itself. This PR is mostly about the diag impl. I added some notes, but I will strengthen the testing later.
||
For example, `qc_plus_qv_times_p_mid` will first compute `qc + qv`, | ||
then multiply the result by `p_mid`. (Beware, we do NOT support | ||
mathematical operation precedence; it is simply about the parser | ||
which processes the expression left-to-right.) | ||
- In fact, `p_mid_times_qc_plus_qv` will fail: `p_mid * qc` is computed | ||
first, and its units do not match those of `qv`, so the addition is rejected. | ||
- We only support existing fields in the Field Manager, e.g., | ||
Maybe: "We only support existing fields in the Field Manager (and other diagnostics), e.g. ..."
||
grid information like lat, lon, and area are not available. | ||
- We do not support dimensionality broadcasting, so both fields | ||
must have the same shape. | ||
|
||
Contact the developers on GitHub if you have additional questions. |
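The left-to-right behavior described in the caveats can be illustrated with a small standalone sketch. It is not the EAMxx parser (which, per the review discussion, lives in the IO utils). `eval_left_to_right` is a hypothetical helper that uses scalars in place of fields and assumes operators appear as `_plus_`, `_minus_`, `_times_`, or `_over_` between field names.

```cpp
#include <cstddef>
#include <map>
#include <stdexcept>
#include <string>

// Hypothetical helper (not the EAMxx parser): evaluates a chained request
// strictly left-to-right, with no operator precedence and no parentheses.
double eval_left_to_right(const std::string& request,
                          const std::map<std::string, double>& fields) {
  static const std::string ops[] = {"plus", "minus", "times", "over"};
  std::string rest = request;
  std::string pending_op;  // operator waiting for its right-hand operand
  double acc = 0.0;
  while (true) {
    // Locate the earliest "_<op>_" token in the remaining text.
    std::size_t pos = std::string::npos;
    std::string next_op;
    for (const auto& op : ops) {
      const std::size_t p = rest.find("_" + op + "_");
      if (p < pos) { pos = p; next_op = op; }
    }
    // Everything before the token (or the whole rest) is a field name.
    const double v = fields.at(rest.substr(0, pos));  // throws if unknown
    if (pending_op.empty())          acc = v;
    else if (pending_op == "plus")   acc += v;
    else if (pending_op == "minus")  acc -= v;
    else if (pending_op == "times")  acc *= v;
    else                             acc /= v;
    if (pos == std::string::npos) break;
    rest = rest.substr(pos + next_op.size() + 2);  // skip "_<op>_"
    pending_op = next_op;
  }
  return acc;
}
```

With `qc = 1`, `qv = 2`, and `p_mid = 10`, the request `qc_plus_qv_times_p_mid` evaluates as `(1 + 2) * 10`, not `1 + 2 * 10`. The sketch ignores units, so it would not reproduce the unit failure of `p_mid_times_qc_plus_qv`.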
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,112 @@ | ||
#include "diagnostics/binary_ops.hpp" | ||
|
||
namespace scream { | ||
|
||
// parse string to get the operator code | ||
int get_binary_operator_code(const std::string& op) { | ||
if (op == "plus") return 0; // addition | ||
if (op == "minus") return 1; // subtraction | ||
if (op == "times") return 2; // multiplication | ||
if (op == "over") return 3; // division | ||
return -1; // invalid operator | ||
} | ||
// apply binary operation on two input units | ||
ekat::units::Units apply_binary_op(const ekat::units::Units& a, const ekat::units::Units& b, const int op_code) { | ||
switch (op_code) { | ||
case 0: // addition - units must be compatible | ||
case 1: // subtraction - units must be compatible | ||
EKAT_REQUIRE_MSG(a == b, "Error! Addition/subtraction requires compatible units.\n"); | ||
return a; | ||
case 2: return a * b; // multiplication | ||
case 3: return a / b; // division | ||
default: return a; // no operation, just return a | ||
} | ||
} | ||
// apply binary operation on two input fields | ||
void apply_binary_op(Field& a_clone, const Field &b, const int& op_code){ | ||
switch (op_code) { | ||
case 0: return a_clone.update(b, 1, 1); // addition | ||
case 1: return a_clone.update(b, -1, 1); // subtraction | ||
case 2: return a_clone.scale(b); // multiplication | ||
case 3: return a_clone.scale_inv(b); // division | ||
default: return; | ||
} | ||
} | ||
|
||
BinaryOpsDiag:: | ||
BinaryOpsDiag(const ekat::Comm &comm, | ||
const ekat::ParameterList &params) | ||
: AtmosphereDiagnostic(comm, params) | ||
{ | ||
m_field_1 = m_params.get<std::string>("field_1"); | ||
m_field_2 = m_params.get<std::string>("field_2"); | ||
m_binary_op = m_params.get<std::string>("binary_op"); | ||
|
||
// Validate operator | ||
EKAT_REQUIRE_MSG(get_binary_operator_code(m_binary_op) >= 0, | ||
"Error! Invalid binary operator: '" + m_binary_op + "'\n" | ||
"Valid operators are: plus, minus, times, over\n"); | ||
} | ||
|
||
void BinaryOpsDiag:: | ||
set_grids(const std::shared_ptr<const GridsManager> grids_manager) | ||
{ | ||
const auto &gname = m_params.get<std::string>("grid_name"); | ||
add_field<Required>(m_field_1, gname); | ||
add_field<Required>(m_field_2, gname); | ||
} | ||
|
||
void BinaryOpsDiag::initialize_impl(const RunType /*run_type*/) { | ||
// get the input fields | ||
const auto &f1 = get_field_in(m_field_1); | ||
const auto &f2 = get_field_in(m_field_2); | ||
|
||
const auto &l1 = f1.get_header().get_identifier().get_layout(); | ||
const auto &l2 = f2.get_header().get_identifier().get_layout(); | ||
// Must be on same layout, same datatype | ||
EKAT_REQUIRE_MSG( | ||
l1 == l2, | ||
"Error! BinaryOpsDiag requires both input fields to have the same layout.\n" | ||
" - field 1 name: " + f1.get_header().get_identifier().name() + "\n" | ||
" - field 1 layout: " + l1.to_string() + "\n" | ||
" - field 2 name: " + f2.get_header().get_identifier().name() + "\n" | ||
" - field 2 layout: " + l2.to_string() + "\n"); | ||
EKAT_REQUIRE_MSG( | ||
f1.data_type() == f2.data_type(), | ||
"Error! BinaryOpsDiag requires both input fields to have the same data type.\n" | ||
" - field 1 name: " + f1.get_header().get_identifier().name() + "\n" | ||
" - field 1 data type: " + e2str(f1.data_type()) + "\n" | ||
" - field 2 name: " + f2.get_header().get_identifier().name() + "\n" | ||
" - field 2 data type: " + e2str(f2.data_type()) + "\n"); | ||
|
||
const auto &gn1 = f1.get_header().get_identifier().get_grid_name(); | ||
const auto &gn2 = f2.get_header().get_identifier().get_grid_name(); | ||
// Must be on same grid too | ||
EKAT_REQUIRE_MSG( | ||
gn1 == gn2, | ||
"Error! BinaryOpsDiag requires both input fields to be on the same grid.\n" | ||
" - field 1 name: " + f1.get_header().get_identifier().name() + "\n" | ||
" - field 1 grid name: " + gn1 + "\n" | ||
" - field 2 name: " + f2.get_header().get_identifier().name() + "\n" | ||
" - field 2 grid name: " + gn2 + "\n"); | ||
|
||
const auto &u1 = f1.get_header().get_identifier().get_units(); | ||
const auto &u2 = f2.get_header().get_identifier().get_units(); | ||
|
||
// All good, create the diag output | ||
auto diag_units = apply_binary_op(u1, u2, get_binary_operator_code(m_binary_op)); | ||
auto diag_name = m_field_1 + "_" + m_binary_op + "_" + m_field_2; | ||
FieldIdentifier d_fid(diag_name, l1.clone(), diag_units, gn1); | ||
m_diagnostic_output = Field(d_fid); | ||
m_diagnostic_output.allocate_view(); | ||
} | ||
|
||
void BinaryOpsDiag::compute_diagnostic_impl() { | ||
|
||
const auto &f1 = get_field_in(m_field_1); | ||
const auto &f2 = get_field_in(m_field_2); | ||
m_diagnostic_output.deep_copy(f1); | ||
apply_binary_op(m_diagnostic_output, f2, get_binary_operator_code(m_binary_op)); | ||
} | ||
|
||
} // namespace scream |
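As a reading aid, the element-wise semantics of `compute_diagnostic_impl` can be mimicked on plain vectors. This sketch assumes, as the inline comments in the diff suggest, that `update(x, alpha, beta)` computes `y = beta*y + alpha*x` and that `scale` and `scale_inv` multiply and divide element by element. `Data`, `update`, and `binary_op` here are hypothetical stand-ins, not the EAMxx `Field` API.

```cpp
#include <cstddef>
#include <stdexcept>
#include <string>
#include <vector>

using Data = std::vector<double>;  // hypothetical stand-in for a field's values

// y = beta*y + alpha*x (the update convention assumed in this sketch)
void update(Data& y, const Data& x, double alpha, double beta) {
  for (std::size_t i = 0; i < y.size(); ++i) y[i] = beta * y[i] + alpha * x[i];
}

Data binary_op(const Data& f1, const Data& f2, const std::string& op) {
  // No broadcasting: both inputs must have the same shape.
  if (f1.size() != f2.size()) throw std::invalid_argument("shape mismatch");
  Data out = f1;  // mirrors deep-copying field 1 into the output
  if (op == "plus") {
    update(out, f2, 1.0, 1.0);
  } else if (op == "minus") {
    update(out, f2, -1.0, 1.0);
  } else if (op == "times") {
    for (std::size_t i = 0; i < out.size(); ++i) out[i] *= f2[i];
  } else if (op == "over") {
    for (std::size_t i = 0; i < out.size(); ++i) out[i] /= f2[i];
  } else {
    throw std::invalid_argument("unknown operator: " + op);
  }
  return out;
}
```

For inputs `{6, 8}` and `{2, 4}`, the four operators give `{8, 12}`, `{4, 4}`, `{12, 32}`, and `{3, 2}`.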
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
#ifndef EAMXX_BINARY_OPS_DIAG_HPP | ||
#define EAMXX_BINARY_OPS_DIAG_HPP | ||
|
||
#include "share/atm_process/atmosphere_diagnostic.hpp" | ||
|
||
namespace scream { | ||
|
||
/* | ||
* This diagnostic will perform binary ops | ||
* like +, -, *, ÷ on two input fields. | ||
*/ | ||
|
||
class BinaryOpsDiag : public AtmosphereDiagnostic { | ||
public: | ||
// Constructors | ||
BinaryOpsDiag(const ekat::Comm &comm, const ekat::ParameterList &params); | ||
|
||
// The name of the diagnostic CLASS (not the computed field) | ||
std::string name() const { return "BinaryOpsDiag"; } | ||
|
||
// Set the grid | ||
void set_grids(const std::shared_ptr<const GridsManager> grids_manager); | ||
|
||
protected: | ||
#ifdef KOKKOS_ENABLE_CUDA | ||
public: | ||
#endif | ||
void compute_diagnostic_impl(); | ||
|
||
void initialize_impl(const RunType /*run_type*/) override; | ||
|
||
std::string m_field_1; | ||
std::string m_field_2; | ||
std::string m_binary_op; | ||
|
||
}; // class BinaryOpsDiag | ||
|
||
} // namespace scream | ||
|
||
#endif // EAMXX_BINARY_OPS_DIAG_HPP |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,119 @@ | ||
#include "catch2/catch.hpp" | ||
#include "diagnostics/register_diagnostics.hpp" | ||
#include "share/field/field_utils.hpp" | ||
#include "share/grid/mesh_free_grids_manager.hpp" | ||
#include "share/util/eamxx_setup_random_test.hpp" | ||
#include "share/util/eamxx_universal_constants.hpp" | ||
|
||
namespace scream { | ||
|
||
std::shared_ptr<GridsManager> create_gm(const ekat::Comm &comm, const int ncols, | ||
const int nlevs) { | ||
const int num_global_cols = ncols * comm.size(); | ||
|
||
using vos_t = std::vector<std::string>; | ||
ekat::ParameterList gm_params; | ||
gm_params.set("grids_names", vos_t{"point_grid"}); | ||
auto &pl = gm_params.sublist("point_grid"); | ||
pl.set<std::string>("type", "point_grid"); | ||
pl.set("aliases", vos_t{"physics"}); | ||
pl.set<int>("number_of_global_columns", num_global_cols); | ||
pl.set<int>("number_of_vertical_levels", nlevs); | ||
|
||
auto gm = create_mesh_free_grids_manager(comm, gm_params); | ||
gm->build_grids(); | ||
|
||
return gm; | ||
} | ||
|
||
TEST_CASE("binary_ops") { | ||
using namespace ShortFieldTagsNames; | ||
using namespace ekat::units; | ||
|
||
// A world comm | ||
ekat::Comm comm(MPI_COMM_WORLD); | ||
|
||
// A time stamp | ||
util::TimeStamp t0({2024, 1, 1}, {0, 0, 0}); | ||
|
||
// Create a grids manager (column count scales with the number of ranks) | ||
constexpr int nlevs = 201; | ||
const int ngcols = 260 * comm.size(); | ||
|
||
auto gm = create_gm(comm, ngcols, nlevs); | ||
auto grid = gm->get_grid("physics"); | ||
|
||
// Input (randomized) qc, qv | ||
FieldLayout scalar2d_layout{{COL, LEV}, {ngcols, nlevs}}; | ||
FieldIdentifier qc_fid("qc", scalar2d_layout, kg / kg, grid->name()); | ||
FieldIdentifier qv_fid("qv", scalar2d_layout, kg / kg, grid->name()); | ||
|
||
Field qc(qc_fid); | ||
Field qv(qv_fid); | ||
qc.allocate_view(); | ||
qv.allocate_view(); | ||
|
||
// Construct random number generator stuff | ||
using RPDF = std::uniform_real_distribution<Real>; | ||
RPDF pdf(0.0, 200.0); | ||
|
||
auto engine = scream::setup_random_test(); | ||
|
||
// Construct the Diagnostics | ||
std::map<std::string, std::shared_ptr<AtmosphereDiagnostic>> diags; | ||
auto &diag_factory = AtmosphereDiagnosticFactory::instance(); | ||
register_diagnostics(); | ||
|
||
ekat::ParameterList params; | ||
REQUIRE_THROWS(diag_factory.create("BinaryOpsDiag", comm, | ||
params)); // No 'field_1', 'field_2', or 'binary_op' | ||
|
||
// Set the time stamp for qc and qv, then randomize their values | ||
qc.get_header().get_tracking().update_time_stamp(t0); | ||
qv.get_header().get_tracking().update_time_stamp(t0); | ||
randomize(qc, engine, pdf); qc.sync_to_dev(); | ||
randomize(qv, engine, pdf); qv.sync_to_dev(); | ||
|
||
// Create and set up the diagnostic | ||
params.set("grid_name", grid->name()); | ||
params.set<std::string>("field_1", "qc"); | ||
params.set<std::string>("field_2", "qv"); | ||
params.set<std::string>("binary_op", "plus"); | ||
auto plus_diag = diag_factory.create("BinaryOpsDiag", comm, params); | ||
params.set<std::string>("binary_op", "times"); | ||
auto prod_diag = diag_factory.create("BinaryOpsDiag", comm, params); | ||
plus_diag->set_grids(gm); | ||
prod_diag->set_grids(gm); | ||
plus_diag->set_required_field(qc); | ||
prod_diag->set_required_field(qc); | ||
plus_diag->set_required_field(qv); | ||
prod_diag->set_required_field(qv); | ||
plus_diag->initialize(t0, RunType::Initial); | ||
prod_diag->initialize(t0, RunType::Initial); | ||
|
||
// Run diag | ||
plus_diag->compute_diagnostic(); | ||
auto plus_diag_f = plus_diag->get_diagnostic(); plus_diag_f.sync_to_host(); | ||
prod_diag->compute_diagnostic(); | ||
auto prod_diag_f = prod_diag->get_diagnostic(); prod_diag_f.sync_to_host(); | ||
|
||
// Check that the output fields have the right values | ||
const auto &plus_v = plus_diag_f.get_view<Real**, Host>(); | ||
const auto &prod_v = prod_diag_f.get_view<Real**, Host>(); | ||
const auto &qc_v = qc.get_view<Real**, Host>(); | ||
const auto &qv_v = qv.get_view<Real**, Host>(); | ||
for (int icol = 0; icol < ngcols; ++icol) { | ||
for (int ilev = 0; ilev < nlevs; ++ilev) { | ||
// Check plus | ||
REQUIRE(plus_v(icol, ilev) == qc_v(icol, ilev) + qv_v(icol, ilev)); | ||
// Check product | ||
REQUIRE(prod_v(icol, ilev) == qc_v(icol, ilev) * qv_v(icol, ilev)); | ||
} | ||
} | ||
|
||
// Redundant cross-check: compare against Field::update, and assert the result | ||
qc.update(qv, 1, 1); | ||
REQUIRE(views_are_equal(qc, plus_diag_f)); | ||
} | ||
|
||
} // namespace scream |
Now that it is available, should we remind the user about aliasing and "encourage" them to use it for binary-op diags?
No, I would keep the docs for each feature separate. I will write an overarching docs page with tips for using them together, but I think each page should be limited to what it offers as a standalone feature.