Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions components/eamxx/docs/user/diags/conditional_sampling.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Conditional sampling
# Conditional sampling diagnostics

The conditional sampling diagnostic allows you to extract field values
where a specified condition is met, filling other locations with a
Expand Down Expand Up @@ -55,7 +55,7 @@ Use the special condition field name `lev`.

Count the number of grid points where a condition is met.
Use the special input field name `count`. The output will be `1.0`
where the condition is satisfied and the fill value elsewhere.
where the condition is satisfied and `0.0` elsewhere.
This is particularly useful when combined with horizontal or vertical
reductions to count occurrences of specific conditions.

Expand Down
2 changes: 2 additions & 0 deletions components/eamxx/docs/user/diags/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ are designed generically and composably, and are requestable by users.
## Available diagnostics

- [Field contraction](field_contraction.md)
- [Conditional sampling](conditional_sampling.md)
- [Binary arithmetics](binary_ops.md)
- [Vertical derivative](vert_derivative.md)

More details to follow.
14 changes: 7 additions & 7 deletions components/eamxx/docs/user/io_aliases.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# EAMxx Field Aliasing Feature
# Field Aliasing

This document demonstrates the field aliasing feature for EAMxx I/O operations.

Expand Down Expand Up @@ -52,16 +52,16 @@ When using aliases:
1. **NetCDF Variables**: The netcdf file will contain variables
named according to the aliases

- `LWP` instead of `LiqWaterPath`
- `T` instead of `T_mid`
- `RH` instead of `RelativeHumidity`
- `LWP` instead of `LiqWaterPath`
- `T` instead of `T_mid`
- `RH` instead of `RelativeHumidity`

2. **Internal Processing**: All internal model operations use the
original field names

- Field validation uses `LiqWaterPath`, `T_mid`, etc.
- Diagnostic calculations use original names
- Memory management uses original field structures
- Field validation uses `LiqWaterPath`, `T_mid`, etc.
- Diagnostic calculations use original names
- Memory management uses original field structures

3. **Metadata**: Variable attributes (units, long_name, etc.)
are preserved from the original fields, and `eamxx_name`
Expand Down
56 changes: 32 additions & 24 deletions components/eamxx/src/diagnostics/conditional_sampling.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,10 @@ void apply_conditional_sampling_1d(
const std::string &condition_op, const Real &condition_val,
const Real &fill_value = constants::fill_value<Real>) {

// if fill_value is 0, we are counting
const auto is_counting = (fill_value == 0);
const auto output_v = output_field.get_view<Real *>();
const auto mask_v = output_field.get_header().get_extra_data<Field>("mask_data").get_view<Real *>();
const auto mask_v = !is_counting ? output_field.get_header().get_extra_data<Field>("mask_data").get_view<Real *>() : output_v;
const auto input_v = input_field.get_view<const Real *>();
const auto condition_v = condition_field.get_view<const Real *>();

Expand All @@ -59,13 +61,13 @@ void apply_conditional_sampling_1d(
bool condition_masked = has_condition_mask && (condition_mask_v(idx) == 0);
if (input_masked || condition_masked) {
output_v(idx) = fill_value;
mask_v(idx) = 0;
if (!is_counting) mask_v(idx) = 0;
} else if (evaluate_condition(condition_v(idx), op_code, condition_val)) {
output_v(idx) = input_v(idx);
mask_v(idx) = 1;
if (!is_counting) mask_v(idx) = 1;
} else {
output_v(idx) = fill_value;
mask_v(idx) = 0;
if (!is_counting) mask_v(idx) = 0;
}
});
}
Expand All @@ -76,8 +78,11 @@ void apply_conditional_sampling_2d(
const std::string &condition_op, const Real &condition_val,
const Real &fill_value = constants::fill_value<Real>) {

// if fill_value is 0, we are counting
const auto is_counting = (fill_value == 0);

const auto output_v = output_field.get_view<Real **>();
const auto mask_v = output_field.get_header().get_extra_data<Field>("mask_data").get_view<Real **>();
const auto mask_v = !is_counting ? output_field.get_header().get_extra_data<Field>("mask_data").get_view<Real **>() : output_v;
const auto input_v = input_field.get_view<const Real **>();
const auto condition_v = condition_field.get_view<const Real **>();

Expand All @@ -102,13 +107,13 @@ void apply_conditional_sampling_2d(
bool condition_masked = has_condition_mask && (condition_mask_v(icol, ilev) == 0);
if (input_masked || condition_masked) {
output_v(icol, ilev) = fill_value;
mask_v(icol, ilev) = 0;
if (!is_counting) mask_v(icol, ilev) = 0;
} else if (evaluate_condition(condition_v(icol, ilev), op_code, condition_val)) {
output_v(icol, ilev) = input_v(icol, ilev);
mask_v(icol, ilev) = 1;
if (!is_counting) mask_v(icol, ilev) = 1;
} else {
output_v(icol, ilev) = fill_value;
mask_v(icol, ilev) = 0;
if (!is_counting) mask_v(icol, ilev) = 0;
}
});
}
Expand All @@ -119,8 +124,11 @@ void apply_conditional_sampling_1d_lev(
const std::string &condition_op, const Real &condition_val,
const Real &fill_value = constants::fill_value<Real>) {

// if fill_value is 0, we are counting
const auto is_counting = (fill_value == 0);

const auto output_v = output_field.get_view<Real *>();
const auto mask_v = output_field.get_header().get_extra_data<Field>("mask_data").get_view<Real *>();
const auto mask_v = !is_counting ? output_field.get_header().get_extra_data<Field>("mask_data").get_view<Real *>() : output_v;
const auto input_v = input_field.get_view<const Real *>();

// Try to get input mask, if present
Expand All @@ -140,13 +148,13 @@ void apply_conditional_sampling_1d_lev(
bool input_masked = has_input_mask && (input_mask_v(idx) == 0);
if (input_masked) {
output_v(idx) = fill_value;
mask_v(idx) = 0;
if (!is_counting) mask_v(idx) = 0;
} else if (evaluate_condition(level_idx, op_code, condition_val)) {
output_v(idx) = input_v(idx);
mask_v(idx) = 1;
if (!is_counting) mask_v(idx) = 1;
} else {
output_v(idx) = fill_value;
mask_v(idx) = 0;
if (!is_counting) mask_v(idx) = 0;
}
});
}
Expand All @@ -157,8 +165,11 @@ void apply_conditional_sampling_2d_lev(
const std::string &condition_op, const Real &condition_val,
const Real &fill_value = constants::fill_value<Real>) {

// if fill_value is 0, we are counting
const auto is_counting = (fill_value == 0);

const auto output_v = output_field.get_view<Real **>();
const auto mask_v = output_field.get_header().get_extra_data<Field>("mask_data").get_view<Real **>();
const auto mask_v = !is_counting ? output_field.get_header().get_extra_data<Field>("mask_data").get_view<Real **>() : output_v;
const auto input_v = input_field.get_view<const Real **>();

// Try to get input mask, if present
Expand All @@ -181,13 +192,13 @@ void apply_conditional_sampling_2d_lev(
bool input_masked = has_input_mask && (input_mask_v(icol, ilev) == 0);
if (input_masked) {
output_v(icol, ilev) = fill_value;
mask_v(icol, ilev) = 0;
if (!is_counting) mask_v(icol, ilev) = 0;
} else if (evaluate_condition(level_idx, op_code, condition_val)) {
output_v(icol, ilev) = input_v(icol, ilev);
mask_v(icol, ilev) = 1;
if (!is_counting) mask_v(icol, ilev) = 1;
} else {
output_v(icol, ilev) = fill_value;
mask_v(icol, ilev) = 0;
if (!is_counting) mask_v(icol, ilev) = 0;
}
});
}
Expand Down Expand Up @@ -259,17 +270,14 @@ void ConditionalSampling::initialize_impl(const RunType /*run_type*/) {

const auto var_fill_value = constants::fill_value<Real>;
m_mask_val = m_params.get<double>("mask_value", var_fill_value);

m_diagnostic_output.get_header().set_extra_data("mask_data", diag_mask);
m_diagnostic_output.get_header().set_extra_data("mask_value", m_mask_val);

if (m_input_f != "count") {
m_diagnostic_output.get_header().set_extra_data("mask_data", diag_mask);
m_diagnostic_output.get_header().set_extra_data("mask_value", m_mask_val);
}
// Special case: if the input field is "count", let's create a field of 1s
if (m_input_f == "count") {
ones = m_diagnostic_output.clone("count_ones");
ones.deep_copy(1.0);
auto ones_mask = ones.clone("count_ones_mask");
ones.get_header().set_extra_data("mask_data", ones_mask);
ones.get_header().set_extra_data("mask_value", m_mask_val);
}

// Special case: if condition field is "lev", we don't need to check layout compatibility
Expand Down Expand Up @@ -308,7 +316,7 @@ void ConditionalSampling::compute_diagnostic_impl() {
"Valid operators are: eq, ==, ne, !=, gt, >, ge, >=, lt, <, le, <=\n");

// Get the fill value from constants
const Real fill_value = m_mask_val;
const Real fill_value = (m_input_f == "count") ? 0.0 : m_mask_val;

// Determine field layout and apply appropriate conditional sampling
const auto &layout = f.get_header().get_identifier().get_layout();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -296,13 +296,13 @@ TEST_CASE("conditional_sampling") {
auto qc12_v = qc12.get_view<const Real *, Host>();
auto qc21_v = qc21.get_view<const Real **, Host>();

// Check the results - count should be 1.0 where condition is met, fill_value otherwise
// Check the results - count should be 1.0 where condition is met, 0 otherwise
for (int ilev = 0; ilev < nlevs; ++ilev) {
// check count for qc12
if (qc12_v(ilev) > comp_val) {
REQUIRE(count_diag12_v(ilev) == 1.0);
} else {
REQUIRE(count_diag12_v(ilev) == fill_value);
REQUIRE(count_diag12_v(ilev) == 0.0);
}
}

Expand All @@ -311,21 +311,21 @@ TEST_CASE("conditional_sampling") {
if (qc11_v(icol) > comp_val) {
REQUIRE(count_diag11_v(icol) == 1.0);
} else {
REQUIRE(count_diag11_v(icol) == fill_value);
REQUIRE(count_diag11_v(icol) == 0.0);
}

for (int ilev = 0; ilev < nlevs; ++ilev) {
// check count for qc21
if (qc21_v(icol, ilev) > comp_val) {
REQUIRE(count_diag21_v(icol, ilev) == 1.0);
} else {
REQUIRE(count_diag21_v(icol, ilev) == fill_value);
REQUIRE(count_diag21_v(icol, ilev) == 0.0);
}
// check count again, but the negative
if (qc21_v(icol, ilev) <= comp_val) {
REQUIRE_FALSE(count_diag21_v(icol, ilev) == 1.0);
} else {
REQUIRE_FALSE(count_diag21_v(icol, ilev) == fill_value);
REQUIRE_FALSE(count_diag21_v(icol, ilev) == 0.0);
}
}
}
Expand Down
Loading