@@ -37,8 +37,10 @@ void apply_conditional_sampling_1d(
37
37
const std::string &condition_op, const Real &condition_val,
38
38
const Real &fill_value = constants::fill_value<Real>) {
39
39
40
+ // if fill_value is 0, we are counting
41
+ const auto is_counting = (fill_value == 0 );
40
42
const auto output_v = output_field.get_view <Real *>();
41
- const auto mask_v = output_field.get_header ().get_extra_data <Field>(" mask_data" ).get_view <Real *>();
43
+ const auto mask_v = !is_counting ? output_field.get_header ().get_extra_data <Field>(" mask_data" ).get_view <Real *>() : output_v ;
42
44
const auto input_v = input_field.get_view <const Real *>();
43
45
const auto condition_v = condition_field.get_view <const Real *>();
44
46
@@ -59,13 +61,13 @@ void apply_conditional_sampling_1d(
59
61
bool condition_masked = has_condition_mask && (condition_mask_v (idx) == 0 );
60
62
if (input_masked || condition_masked) {
61
63
output_v (idx) = fill_value;
62
- mask_v (idx) = 0 ;
64
+ if (!is_counting) mask_v (idx) = 0 ;
63
65
} else if (evaluate_condition (condition_v (idx), op_code, condition_val)) {
64
66
output_v (idx) = input_v (idx);
65
- mask_v (idx) = 1 ;
67
+ if (!is_counting) mask_v (idx) = 1 ;
66
68
} else {
67
69
output_v (idx) = fill_value;
68
- mask_v (idx) = 0 ;
70
+ if (!is_counting) mask_v (idx) = 0 ;
69
71
}
70
72
});
71
73
}
@@ -76,8 +78,11 @@ void apply_conditional_sampling_2d(
76
78
const std::string &condition_op, const Real &condition_val,
77
79
const Real &fill_value = constants::fill_value<Real>) {
78
80
81
+ // if fill_value is 0, we are counting
82
+ const auto is_counting = (fill_value == 0 );
83
+
79
84
const auto output_v = output_field.get_view <Real **>();
80
- const auto mask_v = output_field.get_header ().get_extra_data <Field>(" mask_data" ).get_view <Real **>();
85
+ const auto mask_v = !is_counting ? output_field.get_header ().get_extra_data <Field>(" mask_data" ).get_view <Real **>() : output_v ;
81
86
const auto input_v = input_field.get_view <const Real **>();
82
87
const auto condition_v = condition_field.get_view <const Real **>();
83
88
@@ -102,13 +107,13 @@ void apply_conditional_sampling_2d(
102
107
bool condition_masked = has_condition_mask && (condition_mask_v (icol, ilev) == 0 );
103
108
if (input_masked || condition_masked) {
104
109
output_v (icol, ilev) = fill_value;
105
- mask_v (icol, ilev) = 0 ;
110
+ if (!is_counting) mask_v (icol, ilev) = 0 ;
106
111
} else if (evaluate_condition (condition_v (icol, ilev), op_code, condition_val)) {
107
112
output_v (icol, ilev) = input_v (icol, ilev);
108
- mask_v (icol, ilev) = 1 ;
113
+ if (!is_counting) mask_v (icol, ilev) = 1 ;
109
114
} else {
110
115
output_v (icol, ilev) = fill_value;
111
- mask_v (icol, ilev) = 0 ;
116
+ if (!is_counting) mask_v (icol, ilev) = 0 ;
112
117
}
113
118
});
114
119
}
@@ -119,8 +124,11 @@ void apply_conditional_sampling_1d_lev(
119
124
const std::string &condition_op, const Real &condition_val,
120
125
const Real &fill_value = constants::fill_value<Real>) {
121
126
127
+ // if fill_value is 0, we are counting
128
+ const auto is_counting = (fill_value == 0 );
129
+
122
130
const auto output_v = output_field.get_view <Real *>();
123
- const auto mask_v = output_field.get_header ().get_extra_data <Field>(" mask_data" ).get_view <Real *>();
131
+ const auto mask_v = !is_counting ? output_field.get_header ().get_extra_data <Field>(" mask_data" ).get_view <Real *>() : output_v ;
124
132
const auto input_v = input_field.get_view <const Real *>();
125
133
126
134
// Try to get input mask, if present
@@ -140,13 +148,13 @@ void apply_conditional_sampling_1d_lev(
140
148
bool input_masked = has_input_mask && (input_mask_v (idx) == 0 );
141
149
if (input_masked) {
142
150
output_v (idx) = fill_value;
143
- mask_v (idx) = 0 ;
151
+ if (!is_counting) mask_v (idx) = 0 ;
144
152
} else if (evaluate_condition (level_idx, op_code, condition_val)) {
145
153
output_v (idx) = input_v (idx);
146
- mask_v (idx) = 1 ;
154
+ if (!is_counting) mask_v (idx) = 1 ;
147
155
} else {
148
156
output_v (idx) = fill_value;
149
- mask_v (idx) = 0 ;
157
+ if (!is_counting) mask_v (idx) = 0 ;
150
158
}
151
159
});
152
160
}
@@ -157,8 +165,11 @@ void apply_conditional_sampling_2d_lev(
157
165
const std::string &condition_op, const Real &condition_val,
158
166
const Real &fill_value = constants::fill_value<Real>) {
159
167
168
+ // if fill_value is 0, we are counting
169
+ const auto is_counting = (fill_value == 0 );
170
+
160
171
const auto output_v = output_field.get_view <Real **>();
161
- const auto mask_v = output_field.get_header ().get_extra_data <Field>(" mask_data" ).get_view <Real **>();
172
+ const auto mask_v = !is_counting ? output_field.get_header ().get_extra_data <Field>(" mask_data" ).get_view <Real **>() : output_v ;
162
173
const auto input_v = input_field.get_view <const Real **>();
163
174
164
175
// Try to get input mask, if present
@@ -181,13 +192,13 @@ void apply_conditional_sampling_2d_lev(
181
192
bool input_masked = has_input_mask && (input_mask_v (icol, ilev) == 0 );
182
193
if (input_masked) {
183
194
output_v (icol, ilev) = fill_value;
184
- mask_v (icol, ilev) = 0 ;
195
+ if (!is_counting) mask_v (icol, ilev) = 0 ;
185
196
} else if (evaluate_condition (level_idx, op_code, condition_val)) {
186
197
output_v (icol, ilev) = input_v (icol, ilev);
187
- mask_v (icol, ilev) = 1 ;
198
+ if (!is_counting) mask_v (icol, ilev) = 1 ;
188
199
} else {
189
200
output_v (icol, ilev) = fill_value;
190
- mask_v (icol, ilev) = 0 ;
201
+ if (!is_counting) mask_v (icol, ilev) = 0 ;
191
202
}
192
203
});
193
204
}
@@ -259,17 +270,14 @@ void ConditionalSampling::initialize_impl(const RunType /*run_type*/) {
259
270
260
271
const auto var_fill_value = constants::fill_value<Real>;
261
272
m_mask_val = m_params.get <double >(" mask_value" , var_fill_value);
262
-
263
- m_diagnostic_output.get_header ().set_extra_data (" mask_data" , diag_mask);
264
- m_diagnostic_output.get_header ().set_extra_data (" mask_value" , m_mask_val);
265
-
273
+ if (m_input_f != " count " ) {
274
+ m_diagnostic_output.get_header ().set_extra_data (" mask_data" , diag_mask);
275
+ m_diagnostic_output.get_header ().set_extra_data (" mask_value" , m_mask_val);
276
+ }
266
277
// Special case: if the input field is "count", let's create a field of 1s
267
278
if (m_input_f == " count" ) {
268
279
ones = m_diagnostic_output.clone (" count_ones" );
269
280
ones.deep_copy (1.0 );
270
- auto ones_mask = ones.clone (" count_ones_mask" );
271
- ones.get_header ().set_extra_data (" mask_data" , ones_mask);
272
- ones.get_header ().set_extra_data (" mask_value" , m_mask_val);
273
281
}
274
282
275
283
// Special case: if condition field is "lev", we don't need to check layout compatibility
@@ -308,7 +316,7 @@ void ConditionalSampling::compute_diagnostic_impl() {
308
316
" Valid operators are: eq, ==, ne, !=, gt, >, ge, >=, lt, <, le, <=\n " );
309
317
310
318
// Get the fill value from constants
311
- const Real fill_value = m_mask_val;
319
+ const Real fill_value = (m_input_f == " count " ) ? 0.0 : m_mask_val;
312
320
313
321
// Determine field layout and apply appropriate conditional sampling
314
322
const auto &layout = f.get_header ().get_identifier ().get_layout ();
0 commit comments