Skip to content

Commit 1e304ec

Browse files
committed
EAMxx: add count docs and tests
1 parent a70f66d commit 1e304ec

File tree

3 files changed

+109
-1
lines changed

3 files changed

+109
-1
lines changed

components/eamxx/docs/user/diags/conditional_sampling.md

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,21 @@ Use the special condition field name `lev`.
5151
- `p_mid_where_lev_eq_0`
5252
- `qv_where_lev_le_10`
5353

54+
## Count-based conditional sampling
55+
56+
Count the number of grid points where a condition is met.
57+
Use the special input field name `count`. The output will be `1.0`
58+
where the condition is satisfied and the fill value elsewhere.
59+
This is particularly useful when combined with horizontal or vertical
60+
reductions to count occurrences of specific conditions.
61+
62+
**Examples**:
63+
64+
- `count_where_qv_gt_0.01`
65+
- `count_where_T_mid_le_273.15`
66+
- `count_where_p_mid_lt_50000`
67+
- `count_where_lev_gt_5`
68+
5469
## Caveats
5570

5671
- For now, we only support 1D or 2D fields.
@@ -76,12 +91,18 @@ averaging_type: instant
7691
fields:
7792
physics_pg2:
7893
field_names:
94+
# Field-based conditional sampling
7995
- T_mid_where_qv_gt_0.01
8096
- p_mid_where_T_mid_le_273.15
8197
- qv_where_p_mid_lt_50000
98+
# Level-based conditional sampling
8299
- T_mid_where_lev_gt_5
83100
- p_mid_where_lev_eq_0
84101
- qv_where_lev_le_10
102+
# Count-based conditional sampling
103+
- count_where_qv_gt_0.01
104+
- count_where_T_mid_le_273.15
105+
- count_where_lev_gt_5
85106
output_control:
86107
frequency: 6
87108
frequency_units: nhours

components/eamxx/src/diagnostics/conditional_sampling.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
#include "diagnostics/conditional_sampling.hpp"
22
#include "share/util/eamxx_universal_constants.hpp"
3-
#include <ekat/kokkos/ekat_kokkos_utils.hpp>
3+
#include <ekat_team_policy_utils.hpp>
44
#include <string>
55

66
namespace scream {

components/eamxx/src/diagnostics/tests/conditional_sampling_test.cpp

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -245,6 +245,93 @@ TEST_CASE("conditional_sampling") {
245245
}
246246
}
247247
}
248+
SECTION("count_conditional") {
249+
const auto comp_val = 0.001;
250+
251+
// Test count conditional sampling - count grid points where condition is met
252+
params.clear();
253+
params.set("grid_name", grid->name());
254+
params.set<std::string>("input_field", "count");
255+
params.set<std::string>("condition_field", "qc");
256+
params.set<std::string>("condition_operator", "gt");
257+
params.set<std::string>("condition_value", std::to_string(comp_val));
258+
259+
// Set time for qc and randomize its values
260+
qc11.get_header().get_tracking().update_time_stamp(t0);
261+
qc12.get_header().get_tracking().update_time_stamp(t0);
262+
qc21.get_header().get_tracking().update_time_stamp(t0);
263+
randomize(qc11, engine, pdf);
264+
randomize(qc12, engine, pdf);
265+
randomize(qc21, engine, pdf);
266+
267+
// Create and set up the diagnostic for count
268+
auto count_diag11 = diag_factory.create("ConditionalSampling", comm, params);
269+
auto count_diag12 = diag_factory.create("ConditionalSampling", comm, params);
270+
auto count_diag21 = diag_factory.create("ConditionalSampling", comm, params);
271+
count_diag11->set_grids(gm);
272+
count_diag12->set_grids(gm);
273+
count_diag21->set_grids(gm);
274+
275+
// Set the fields for each diagnostic
276+
count_diag11->set_required_field(qc11);
277+
count_diag11->initialize(t0, RunType::Initial);
278+
count_diag11->compute_diagnostic();
279+
auto count_diag11_f = count_diag11->get_diagnostic();
280+
count_diag11_f.sync_to_host();
281+
auto count_diag11_v = count_diag11_f.get_view<const Real *, Host>();
282+
283+
count_diag12->set_required_field(qc12);
284+
count_diag12->initialize(t0, RunType::Initial);
285+
count_diag12->compute_diagnostic();
286+
auto count_diag12_f = count_diag12->get_diagnostic();
287+
count_diag12_f.sync_to_host();
288+
auto count_diag12_v = count_diag12_f.get_view<const Real *, Host>();
289+
290+
count_diag21->set_required_field(qc21);
291+
count_diag21->initialize(t0, RunType::Initial);
292+
count_diag21->compute_diagnostic();
293+
auto count_diag21_f = count_diag21->get_diagnostic();
294+
count_diag21_f.sync_to_host();
295+
auto count_diag21_v = count_diag21_f.get_view<const Real **, Host>();
296+
297+
auto qc11_v = qc11.get_view<const Real *, Host>();
298+
auto qc12_v = qc12.get_view<const Real *, Host>();
299+
auto qc21_v = qc21.get_view<const Real **, Host>();
300+
301+
// Check the results - count should be 1.0 where condition is met, fill_value otherwise
302+
for (int ilev = 0; ilev < nlevs; ++ilev) {
303+
// check count for qc12
304+
if (qc12_v(ilev) > comp_val) {
305+
REQUIRE(count_diag12_v(ilev) == 1.0);
306+
} else {
307+
REQUIRE(count_diag12_v(ilev) == fill_value);
308+
}
309+
}
310+
311+
for (int icol = 0; icol < ngcols; ++icol) {
312+
// Check count for qc11
313+
if (qc11_v(icol) > comp_val) {
314+
REQUIRE(count_diag11_v(icol) == 1.0);
315+
} else {
316+
REQUIRE(count_diag11_v(icol) == fill_value);
317+
}
318+
319+
for (int ilev = 0; ilev < nlevs; ++ilev) {
320+
// check count for qc21
321+
if (qc21_v(icol, ilev) > comp_val) {
322+
REQUIRE(count_diag21_v(icol, ilev) == 1.0);
323+
} else {
324+
REQUIRE(count_diag21_v(icol, ilev) == fill_value);
325+
}
326+
// check count again, but the negative
327+
if (qc21_v(icol, ilev) <= comp_val) {
328+
REQUIRE_FALSE(count_diag21_v(icol, ilev) == 1.0);
329+
} else {
330+
REQUIRE_FALSE(count_diag21_v(icol, ilev) == fill_value);
331+
}
332+
}
333+
}
334+
}
248335
}
249336

250337
} // namespace scream

0 commit comments

Comments
 (0)