Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions components/eamxx/docs/user/diags/conditional_sampling.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Conditional sampling
# Conditional sampling diagnostics

The conditional sampling diagnostic allows you to extract field values
where a specified condition is met, filling other locations with a
Expand Down Expand Up @@ -55,7 +55,7 @@ Use the special condition field name `lev`.

Count the number of grid points where a condition is met.
Use the special input field name `count`. The output will be `1.0`
where the condition is satisfied and the fill value elsewhere.
where the condition is satisfied and `0.0` elsewhere.
This is particularly useful when combined with horizontal or vertical
reductions to count occurrences of specific conditions.

Expand Down
2 changes: 2 additions & 0 deletions components/eamxx/docs/user/diags/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ are designed generically and composably, and are requestable by users.
## Available diagnostics

- [Field contraction](field_contraction.md)
- [Conditional sampling](conditional_sampling.md)
- [Binary arithmetics](binary_ops.md)
- [Vertical derivative](vertical_derivative.md)

More details to follow.
14 changes: 7 additions & 7 deletions components/eamxx/docs/user/io_aliases.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# EAMxx Field Aliasing Feature
# Field Aliasing

This document demonstrates the field aliasing feature for EAMxx I/O operations.

Expand Down Expand Up @@ -52,16 +52,16 @@ When using aliases:
1. **NetCDF Variables**: The netcdf file will contain variables
named according to the aliases

- `LWP` instead of `LiqWaterPath`
- `T` instead of `T_mid`
- `RH` instead of `RelativeHumidity`
- `LWP` instead of `LiqWaterPath`
- `T` instead of `T_mid`
- `RH` instead of `RelativeHumidity`

2. **Internal Processing**: All internal model operations use the
original field names

- Field validation uses `LiqWaterPath`, `T_mid`, etc.
- Diagnostic calculations use original names
- Memory management uses original field structures
- Field validation uses `LiqWaterPath`, `T_mid`, etc.
- Diagnostic calculations use original names
- Memory management uses original field structures

3. **Metadata**: Variable attributes (units, long_name, etc.)
are preserved from the original fields, and `eamxx_name`
Expand Down
56 changes: 32 additions & 24 deletions components/eamxx/src/diagnostics/conditional_sampling.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,10 @@ void apply_conditional_sampling_1d(
const std::string &condition_op, const Real &condition_val,
const Real &fill_value = constants::fill_value<Real>) {

// if fill_value is 0, we are counting
const auto is_counting = (fill_value == 0);
const auto output_v = output_field.get_view<Real *>();
const auto mask_v = output_field.get_header().get_extra_data<Field>("mask_data").get_view<Real *>();
const auto mask_v = !is_counting ? output_field.get_header().get_extra_data<Field>("mask_data").get_view<Real *>() : output_v;
const auto input_v = input_field.get_view<const Real *>();
const auto condition_v = condition_field.get_view<const Real *>();

Expand All @@ -59,13 +61,13 @@ void apply_conditional_sampling_1d(
bool condition_masked = has_condition_mask && (condition_mask_v(idx) == 0);
if (input_masked || condition_masked) {
output_v(idx) = fill_value;
mask_v(idx) = 0;
if (!is_counting) mask_v(idx) = 0;
} else if (evaluate_condition(condition_v(idx), op_code, condition_val)) {
output_v(idx) = input_v(idx);
mask_v(idx) = 1;
if (!is_counting) mask_v(idx) = 1;
} else {
output_v(idx) = fill_value;
mask_v(idx) = 0;
if (!is_counting) mask_v(idx) = 0;
}
});
}
Expand All @@ -76,8 +78,11 @@ void apply_conditional_sampling_2d(
const std::string &condition_op, const Real &condition_val,
const Real &fill_value = constants::fill_value<Real>) {

// if fill_value is 0, we are counting
const auto is_counting = (fill_value == 0);

const auto output_v = output_field.get_view<Real **>();
const auto mask_v = output_field.get_header().get_extra_data<Field>("mask_data").get_view<Real **>();
const auto mask_v = !is_counting ? output_field.get_header().get_extra_data<Field>("mask_data").get_view<Real **>() : output_v;
const auto input_v = input_field.get_view<const Real **>();
const auto condition_v = condition_field.get_view<const Real **>();

Expand All @@ -102,13 +107,13 @@ void apply_conditional_sampling_2d(
bool condition_masked = has_condition_mask && (condition_mask_v(icol, ilev) == 0);
if (input_masked || condition_masked) {
output_v(icol, ilev) = fill_value;
mask_v(icol, ilev) = 0;
if (!is_counting) mask_v(icol, ilev) = 0;
} else if (evaluate_condition(condition_v(icol, ilev), op_code, condition_val)) {
output_v(icol, ilev) = input_v(icol, ilev);
mask_v(icol, ilev) = 1;
if (!is_counting) mask_v(icol, ilev) = 1;
} else {
output_v(icol, ilev) = fill_value;
mask_v(icol, ilev) = 0;
if (!is_counting) mask_v(icol, ilev) = 0;
}
});
}
Expand All @@ -119,8 +124,11 @@ void apply_conditional_sampling_1d_lev(
const std::string &condition_op, const Real &condition_val,
const Real &fill_value = constants::fill_value<Real>) {

// if fill_value is 0, we are counting
const auto is_counting = (fill_value == 0);

const auto output_v = output_field.get_view<Real *>();
const auto mask_v = output_field.get_header().get_extra_data<Field>("mask_data").get_view<Real *>();
const auto mask_v = !is_counting ? output_field.get_header().get_extra_data<Field>("mask_data").get_view<Real *>() : output_v;
const auto input_v = input_field.get_view<const Real *>();

// Try to get input mask, if present
Expand All @@ -140,13 +148,13 @@ void apply_conditional_sampling_1d_lev(
bool input_masked = has_input_mask && (input_mask_v(idx) == 0);
if (input_masked) {
output_v(idx) = fill_value;
mask_v(idx) = 0;
if (!is_counting) mask_v(idx) = 0;
} else if (evaluate_condition(level_idx, op_code, condition_val)) {
output_v(idx) = input_v(idx);
mask_v(idx) = 1;
if (!is_counting) mask_v(idx) = 1;
} else {
output_v(idx) = fill_value;
mask_v(idx) = 0;
if (!is_counting) mask_v(idx) = 0;
}
});
}
Expand All @@ -157,8 +165,11 @@ void apply_conditional_sampling_2d_lev(
const std::string &condition_op, const Real &condition_val,
const Real &fill_value = constants::fill_value<Real>) {

// if fill_value is 0, we are counting
const auto is_counting = (fill_value == 0);

const auto output_v = output_field.get_view<Real **>();
const auto mask_v = output_field.get_header().get_extra_data<Field>("mask_data").get_view<Real **>();
const auto mask_v = !is_counting ? output_field.get_header().get_extra_data<Field>("mask_data").get_view<Real **>() : output_v;
const auto input_v = input_field.get_view<const Real **>();

// Try to get input mask, if present
Expand All @@ -181,13 +192,13 @@ void apply_conditional_sampling_2d_lev(
bool input_masked = has_input_mask && (input_mask_v(icol, ilev) == 0);
if (input_masked) {
output_v(icol, ilev) = fill_value;
mask_v(icol, ilev) = 0;
if (!is_counting) mask_v(icol, ilev) = 0;
} else if (evaluate_condition(level_idx, op_code, condition_val)) {
output_v(icol, ilev) = input_v(icol, ilev);
mask_v(icol, ilev) = 1;
if (!is_counting) mask_v(icol, ilev) = 1;
} else {
output_v(icol, ilev) = fill_value;
mask_v(icol, ilev) = 0;
if (!is_counting) mask_v(icol, ilev) = 0;
}
});
}
Expand Down Expand Up @@ -259,17 +270,14 @@ void ConditionalSampling::initialize_impl(const RunType /*run_type*/) {

const auto var_fill_value = constants::fill_value<Real>;
m_mask_val = m_params.get<double>("mask_value", var_fill_value);

m_diagnostic_output.get_header().set_extra_data("mask_data", diag_mask);
m_diagnostic_output.get_header().set_extra_data("mask_value", m_mask_val);

if (m_input_f != "count") {
m_diagnostic_output.get_header().set_extra_data("mask_data", diag_mask);
m_diagnostic_output.get_header().set_extra_data("mask_value", m_mask_val);
}
// Special case: if the input field is "count", let's create a field of 1s
if (m_input_f == "count") {
ones = m_diagnostic_output.clone("count_ones");
ones.deep_copy(1.0);
auto ones_mask = ones.clone("count_ones_mask");
ones.get_header().set_extra_data("mask_data", ones_mask);
ones.get_header().set_extra_data("mask_value", m_mask_val);
}

// Special case: if condition field is "lev", we don't need to check layout compatibility
Expand Down Expand Up @@ -308,7 +316,7 @@ void ConditionalSampling::compute_diagnostic_impl() {
"Valid operators are: eq, ==, ne, !=, gt, >, ge, >=, lt, <, le, <=\n");

// Get the fill value from constants
const Real fill_value = m_mask_val;
const Real fill_value = (m_input_f == "count") ? 0.0 : m_mask_val;

// Determine field layout and apply appropriate conditional sampling
const auto &layout = f.get_header().get_identifier().get_layout();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -296,13 +296,13 @@ TEST_CASE("conditional_sampling") {
auto qc12_v = qc12.get_view<const Real *, Host>();
auto qc21_v = qc21.get_view<const Real **, Host>();

// Check the results - count should be 1.0 where condition is met, fill_value otherwise
// Check the results - count should be 1.0 where condition is met, 0 otherwise
for (int ilev = 0; ilev < nlevs; ++ilev) {
// check count for qc12
if (qc12_v(ilev) > comp_val) {
REQUIRE(count_diag12_v(ilev) == 1.0);
} else {
REQUIRE(count_diag12_v(ilev) == fill_value);
REQUIRE(count_diag12_v(ilev) == 0.0);
}
}

Expand All @@ -311,21 +311,21 @@ TEST_CASE("conditional_sampling") {
if (qc11_v(icol) > comp_val) {
REQUIRE(count_diag11_v(icol) == 1.0);
} else {
REQUIRE(count_diag11_v(icol) == fill_value);
REQUIRE(count_diag11_v(icol) == 0.0);
}

for (int ilev = 0; ilev < nlevs; ++ilev) {
// check count for qc21
if (qc21_v(icol, ilev) > comp_val) {
REQUIRE(count_diag21_v(icol, ilev) == 1.0);
} else {
REQUIRE(count_diag21_v(icol, ilev) == fill_value);
REQUIRE(count_diag21_v(icol, ilev) == 0.0);
}
// check count again, but the negative
if (qc21_v(icol, ilev) <= comp_val) {
REQUIRE_FALSE(count_diag21_v(icol, ilev) == 1.0);
} else {
REQUIRE_FALSE(count_diag21_v(icol, ilev) == fill_value);
REQUIRE_FALSE(count_diag21_v(icol, ilev) == 0.0);
}
}
}
Expand Down
Loading