Skip to content

Commit 70c6ab0

Browse files
authored
Merge branch 'mahf708/eamxx/diags-fixes-n-docs' (PR #7632)
[BFB]
2 parents 137af93 + 176eb9d commit 70c6ab0

File tree

5 files changed

+48
-38
lines changed

5 files changed

+48
-38
lines changed

components/eamxx/docs/user/diags/conditional_sampling.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Conditional sampling
1+
# Conditional sampling diagnostics
22

33
The conditional sampling diagnostic allows you to extract field values
44
where a specified condition is met, filling other locations with a
@@ -55,7 +55,7 @@ Use the special condition field name `lev`.
5555

5656
Count the number of grid points where a condition is met.
5757
Use the special input field name `count`. The output will be `1.0`
58-
where the condition is satisfied and the fill value elsewhere.
58+
where the condition is satisfied and `0.0` elsewhere.
5959
This is particularly useful when combined with horizontal or vertical
6060
reductions to count occurrences of specific conditions.
6161

components/eamxx/docs/user/diags/index.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@ are designed generically and composably, and are requestable by users.
77
## Available diagnostics
88

99
- [Field contraction](field_contraction.md)
10+
- [Conditional sampling](conditional_sampling.md)
1011
- [Binary arithmetics](binary_ops.md)
12+
- [Vertical derivative](vert_derivative.md)
1113

1214
More details to follow.

components/eamxx/docs/user/io_aliases.md

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# EAMxx Field Aliasing Feature
1+
# Field Aliasing
22

33
This document demonstrates the field aliasing feature for EAMxx I/O operations.
44

@@ -52,16 +52,16 @@ When using aliases:
5252
1. **NetCDF Variables**: The netcdf file will contain variables
5353
named according to the aliases
5454
55-
- `LWP` instead of `LiqWaterPath`
56-
- `T` instead of `T_mid`
57-
- `RH` instead of `RelativeHumidity`
55+
- `LWP` instead of `LiqWaterPath`
56+
- `T` instead of `T_mid`
57+
- `RH` instead of `RelativeHumidity`
5858

5959
2. **Internal Processing**: All internal model operations use the
6060
original field names
6161

62-
- Field validation uses `LiqWaterPath`, `T_mid`, etc.
63-
- Diagnostic calculations use original names
64-
- Memory management uses original field structures
62+
- Field validation uses `LiqWaterPath`, `T_mid`, etc.
63+
- Diagnostic calculations use original names
64+
- Memory management uses original field structures
6565

6666
3. **Metadata**: Variable attributes (units, long_name, etc.)
6767
are preserved from the original fields, and `eamxx_name`

components/eamxx/src/diagnostics/conditional_sampling.cpp

Lines changed: 32 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -37,8 +37,10 @@ void apply_conditional_sampling_1d(
3737
const std::string &condition_op, const Real &condition_val,
3838
const Real &fill_value = constants::fill_value<Real>) {
3939

40+
// if fill_value is 0, we are counting
41+
const auto is_counting = (fill_value == 0);
4042
const auto output_v = output_field.get_view<Real *>();
41-
const auto mask_v = output_field.get_header().get_extra_data<Field>("mask_data").get_view<Real *>();
43+
const auto mask_v = !is_counting ? output_field.get_header().get_extra_data<Field>("mask_data").get_view<Real *>() : output_v;
4244
const auto input_v = input_field.get_view<const Real *>();
4345
const auto condition_v = condition_field.get_view<const Real *>();
4446

@@ -59,13 +61,13 @@ void apply_conditional_sampling_1d(
5961
bool condition_masked = has_condition_mask && (condition_mask_v(idx) == 0);
6062
if (input_masked || condition_masked) {
6163
output_v(idx) = fill_value;
62-
mask_v(idx) = 0;
64+
if (!is_counting) mask_v(idx) = 0;
6365
} else if (evaluate_condition(condition_v(idx), op_code, condition_val)) {
6466
output_v(idx) = input_v(idx);
65-
mask_v(idx) = 1;
67+
if (!is_counting) mask_v(idx) = 1;
6668
} else {
6769
output_v(idx) = fill_value;
68-
mask_v(idx) = 0;
70+
if (!is_counting) mask_v(idx) = 0;
6971
}
7072
});
7173
}
@@ -76,8 +78,11 @@ void apply_conditional_sampling_2d(
7678
const std::string &condition_op, const Real &condition_val,
7779
const Real &fill_value = constants::fill_value<Real>) {
7880

81+
// if fill_value is 0, we are counting
82+
const auto is_counting = (fill_value == 0);
83+
7984
const auto output_v = output_field.get_view<Real **>();
80-
const auto mask_v = output_field.get_header().get_extra_data<Field>("mask_data").get_view<Real **>();
85+
const auto mask_v = !is_counting ? output_field.get_header().get_extra_data<Field>("mask_data").get_view<Real **>() : output_v;
8186
const auto input_v = input_field.get_view<const Real **>();
8287
const auto condition_v = condition_field.get_view<const Real **>();
8388

@@ -102,13 +107,13 @@ void apply_conditional_sampling_2d(
102107
bool condition_masked = has_condition_mask && (condition_mask_v(icol, ilev) == 0);
103108
if (input_masked || condition_masked) {
104109
output_v(icol, ilev) = fill_value;
105-
mask_v(icol, ilev) = 0;
110+
if (!is_counting) mask_v(icol, ilev) = 0;
106111
} else if (evaluate_condition(condition_v(icol, ilev), op_code, condition_val)) {
107112
output_v(icol, ilev) = input_v(icol, ilev);
108-
mask_v(icol, ilev) = 1;
113+
if (!is_counting) mask_v(icol, ilev) = 1;
109114
} else {
110115
output_v(icol, ilev) = fill_value;
111-
mask_v(icol, ilev) = 0;
116+
if (!is_counting) mask_v(icol, ilev) = 0;
112117
}
113118
});
114119
}
@@ -119,8 +124,11 @@ void apply_conditional_sampling_1d_lev(
119124
const std::string &condition_op, const Real &condition_val,
120125
const Real &fill_value = constants::fill_value<Real>) {
121126

127+
// if fill_value is 0, we are counting
128+
const auto is_counting = (fill_value == 0);
129+
122130
const auto output_v = output_field.get_view<Real *>();
123-
const auto mask_v = output_field.get_header().get_extra_data<Field>("mask_data").get_view<Real *>();
131+
const auto mask_v = !is_counting ? output_field.get_header().get_extra_data<Field>("mask_data").get_view<Real *>() : output_v;
124132
const auto input_v = input_field.get_view<const Real *>();
125133

126134
// Try to get input mask, if present
@@ -140,13 +148,13 @@ void apply_conditional_sampling_1d_lev(
140148
bool input_masked = has_input_mask && (input_mask_v(idx) == 0);
141149
if (input_masked) {
142150
output_v(idx) = fill_value;
143-
mask_v(idx) = 0;
151+
if (!is_counting) mask_v(idx) = 0;
144152
} else if (evaluate_condition(level_idx, op_code, condition_val)) {
145153
output_v(idx) = input_v(idx);
146-
mask_v(idx) = 1;
154+
if (!is_counting) mask_v(idx) = 1;
147155
} else {
148156
output_v(idx) = fill_value;
149-
mask_v(idx) = 0;
157+
if (!is_counting) mask_v(idx) = 0;
150158
}
151159
});
152160
}
@@ -157,8 +165,11 @@ void apply_conditional_sampling_2d_lev(
157165
const std::string &condition_op, const Real &condition_val,
158166
const Real &fill_value = constants::fill_value<Real>) {
159167

168+
// if fill_value is 0, we are counting
169+
const auto is_counting = (fill_value == 0);
170+
160171
const auto output_v = output_field.get_view<Real **>();
161-
const auto mask_v = output_field.get_header().get_extra_data<Field>("mask_data").get_view<Real **>();
172+
const auto mask_v = !is_counting ? output_field.get_header().get_extra_data<Field>("mask_data").get_view<Real **>() : output_v;
162173
const auto input_v = input_field.get_view<const Real **>();
163174

164175
// Try to get input mask, if present
@@ -181,13 +192,13 @@ void apply_conditional_sampling_2d_lev(
181192
bool input_masked = has_input_mask && (input_mask_v(icol, ilev) == 0);
182193
if (input_masked) {
183194
output_v(icol, ilev) = fill_value;
184-
mask_v(icol, ilev) = 0;
195+
if (!is_counting) mask_v(icol, ilev) = 0;
185196
} else if (evaluate_condition(level_idx, op_code, condition_val)) {
186197
output_v(icol, ilev) = input_v(icol, ilev);
187-
mask_v(icol, ilev) = 1;
198+
if (!is_counting) mask_v(icol, ilev) = 1;
188199
} else {
189200
output_v(icol, ilev) = fill_value;
190-
mask_v(icol, ilev) = 0;
201+
if (!is_counting) mask_v(icol, ilev) = 0;
191202
}
192203
});
193204
}
@@ -259,17 +270,14 @@ void ConditionalSampling::initialize_impl(const RunType /*run_type*/) {
259270

260271
const auto var_fill_value = constants::fill_value<Real>;
261272
m_mask_val = m_params.get<double>("mask_value", var_fill_value);
262-
263-
m_diagnostic_output.get_header().set_extra_data("mask_data", diag_mask);
264-
m_diagnostic_output.get_header().set_extra_data("mask_value", m_mask_val);
265-
273+
if (m_input_f != "count") {
274+
m_diagnostic_output.get_header().set_extra_data("mask_data", diag_mask);
275+
m_diagnostic_output.get_header().set_extra_data("mask_value", m_mask_val);
276+
}
266277
// Special case: if the input field is "count", let's create a field of 1s
267278
if (m_input_f == "count") {
268279
ones = m_diagnostic_output.clone("count_ones");
269280
ones.deep_copy(1.0);
270-
auto ones_mask = ones.clone("count_ones_mask");
271-
ones.get_header().set_extra_data("mask_data", ones_mask);
272-
ones.get_header().set_extra_data("mask_value", m_mask_val);
273281
}
274282

275283
// Special case: if condition field is "lev", we don't need to check layout compatibility
@@ -308,7 +316,7 @@ void ConditionalSampling::compute_diagnostic_impl() {
308316
"Valid operators are: eq, ==, ne, !=, gt, >, ge, >=, lt, <, le, <=\n");
309317

310318
// Get the fill value from constants
311-
const Real fill_value = m_mask_val;
319+
const Real fill_value = (m_input_f == "count") ? 0.0 : m_mask_val;
312320

313321
// Determine field layout and apply appropriate conditional sampling
314322
const auto &layout = f.get_header().get_identifier().get_layout();

components/eamxx/src/diagnostics/tests/conditional_sampling_test.cpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -296,13 +296,13 @@ TEST_CASE("conditional_sampling") {
296296
auto qc12_v = qc12.get_view<const Real *, Host>();
297297
auto qc21_v = qc21.get_view<const Real **, Host>();
298298

299-
// Check the results - count should be 1.0 where condition is met, fill_value otherwise
299+
// Check the results - count should be 1.0 where condition is met, 0 otherwise
300300
for (int ilev = 0; ilev < nlevs; ++ilev) {
301301
// check count for qc12
302302
if (qc12_v(ilev) > comp_val) {
303303
REQUIRE(count_diag12_v(ilev) == 1.0);
304304
} else {
305-
REQUIRE(count_diag12_v(ilev) == fill_value);
305+
REQUIRE(count_diag12_v(ilev) == 0.0);
306306
}
307307
}
308308

@@ -311,21 +311,21 @@ TEST_CASE("conditional_sampling") {
311311
if (qc11_v(icol) > comp_val) {
312312
REQUIRE(count_diag11_v(icol) == 1.0);
313313
} else {
314-
REQUIRE(count_diag11_v(icol) == fill_value);
314+
REQUIRE(count_diag11_v(icol) == 0.0);
315315
}
316316

317317
for (int ilev = 0; ilev < nlevs; ++ilev) {
318318
// check count for qc21
319319
if (qc21_v(icol, ilev) > comp_val) {
320320
REQUIRE(count_diag21_v(icol, ilev) == 1.0);
321321
} else {
322-
REQUIRE(count_diag21_v(icol, ilev) == fill_value);
322+
REQUIRE(count_diag21_v(icol, ilev) == 0.0);
323323
}
324324
// check count again, but the negative
325325
if (qc21_v(icol, ilev) <= comp_val) {
326326
REQUIRE_FALSE(count_diag21_v(icol, ilev) == 1.0);
327327
} else {
328-
REQUIRE_FALSE(count_diag21_v(icol, ilev) == fill_value);
328+
REQUIRE_FALSE(count_diag21_v(icol, ilev) == 0.0);
329329
}
330330
}
331331
}

0 commit comments

Comments
 (0)