@@ -245,6 +245,93 @@ TEST_CASE("conditional_sampling") {
245
245
}
246
246
}
247
247
}
248
+ SECTION (" count_conditional" ) {
249
+ const auto comp_val = 0.001 ;
250
+
251
+ // Test count conditional sampling - count grid points where condition is met
252
+ params.clear ();
253
+ params.set (" grid_name" , grid->name ());
254
+ params.set <std::string>(" input_field" , " count" );
255
+ params.set <std::string>(" condition_field" , " qc" );
256
+ params.set <std::string>(" condition_operator" , " gt" );
257
+ params.set <std::string>(" condition_value" , std::to_string (comp_val));
258
+
259
+ // Set time for qc and randomize its values
260
+ qc11.get_header ().get_tracking ().update_time_stamp (t0);
261
+ qc12.get_header ().get_tracking ().update_time_stamp (t0);
262
+ qc21.get_header ().get_tracking ().update_time_stamp (t0);
263
+ randomize (qc11, engine, pdf);
264
+ randomize (qc12, engine, pdf);
265
+ randomize (qc21, engine, pdf);
266
+
267
+ // Create and set up the diagnostic for count
268
+ auto count_diag11 = diag_factory.create (" ConditionalSampling" , comm, params);
269
+ auto count_diag12 = diag_factory.create (" ConditionalSampling" , comm, params);
270
+ auto count_diag21 = diag_factory.create (" ConditionalSampling" , comm, params);
271
+ count_diag11->set_grids (gm);
272
+ count_diag12->set_grids (gm);
273
+ count_diag21->set_grids (gm);
274
+
275
+ // Set the fields for each diagnostic
276
+ count_diag11->set_required_field (qc11);
277
+ count_diag11->initialize (t0, RunType::Initial);
278
+ count_diag11->compute_diagnostic ();
279
+ auto count_diag11_f = count_diag11->get_diagnostic ();
280
+ count_diag11_f.sync_to_host ();
281
+ auto count_diag11_v = count_diag11_f.get_view <const Real *, Host>();
282
+
283
+ count_diag12->set_required_field (qc12);
284
+ count_diag12->initialize (t0, RunType::Initial);
285
+ count_diag12->compute_diagnostic ();
286
+ auto count_diag12_f = count_diag12->get_diagnostic ();
287
+ count_diag12_f.sync_to_host ();
288
+ auto count_diag12_v = count_diag12_f.get_view <const Real *, Host>();
289
+
290
+ count_diag21->set_required_field (qc21);
291
+ count_diag21->initialize (t0, RunType::Initial);
292
+ count_diag21->compute_diagnostic ();
293
+ auto count_diag21_f = count_diag21->get_diagnostic ();
294
+ count_diag21_f.sync_to_host ();
295
+ auto count_diag21_v = count_diag21_f.get_view <const Real **, Host>();
296
+
297
+ auto qc11_v = qc11.get_view <const Real *, Host>();
298
+ auto qc12_v = qc12.get_view <const Real *, Host>();
299
+ auto qc21_v = qc21.get_view <const Real **, Host>();
300
+
301
+ // Check the results - count should be 1.0 where condition is met, fill_value otherwise
302
+ for (int ilev = 0 ; ilev < nlevs; ++ilev) {
303
+ // check count for qc12
304
+ if (qc12_v (ilev) > comp_val) {
305
+ REQUIRE (count_diag12_v (ilev) == 1.0 );
306
+ } else {
307
+ REQUIRE (count_diag12_v (ilev) == fill_value);
308
+ }
309
+ }
310
+
311
+ for (int icol = 0 ; icol < ngcols; ++icol) {
312
+ // Check count for qc11
313
+ if (qc11_v (icol) > comp_val) {
314
+ REQUIRE (count_diag11_v (icol) == 1.0 );
315
+ } else {
316
+ REQUIRE (count_diag11_v (icol) == fill_value);
317
+ }
318
+
319
+ for (int ilev = 0 ; ilev < nlevs; ++ilev) {
320
+ // check count for qc21
321
+ if (qc21_v (icol, ilev) > comp_val) {
322
+ REQUIRE (count_diag21_v (icol, ilev) == 1.0 );
323
+ } else {
324
+ REQUIRE (count_diag21_v (icol, ilev) == fill_value);
325
+ }
326
+ // check count again, but the negative
327
+ if (qc21_v (icol, ilev) <= comp_val) {
328
+ REQUIRE_FALSE (count_diag21_v (icol, ilev) == 1.0 );
329
+ } else {
330
+ REQUIRE_FALSE (count_diag21_v (icol, ilev) == fill_value);
331
+ }
332
+ }
333
+ }
334
+ }
248
335
}
249
336
250
337
} // namespace scream
0 commit comments