
Commit d7307d9

Author: Emma Ai (authored and committed)
Commit message: add tolerance
1 parent 4480ab1 commit d7307d9

File tree: 4 files changed, 26 additions, 27 deletions


docker/requirements.txt

Lines changed: 1 addition & 1 deletion
@@ -11,7 +11,7 @@ odc-dscache>=0.2.3
 odc-stac @ git+https://github.com/opendatacube/odc-stac@69bdf64

 # odc-stac is in PyPI
-odc-stats[ows] @ git+https://github.com/opendatacube/odc-stats@eee2ed1
+odc-stats[ows] @ git+https://github.com/opendatacube/odc-stats@4480ab1

 # For ML
 tflite-runtime

odc/stats/plugins/lc_tf_urban.py

Lines changed: 3 additions & 3 deletions
@@ -178,15 +178,15 @@ def aggregate_results_from_group(self, urban_masks):
             urban_masks = urban_masks[0]

         urban_masks = expr_eval(
-            "where((a/nodata)>=_l, nodata, a%nodata)",
+            "where((a/nodata)+0.5>=_l, nodata, a%nodata)",
             {"a": urban_masks},
             name="mark_nodata",
             dtype="float32",
             **{"_l": m_size, "nodata": NODATA},
         )

         urban_masks = expr_eval(
-            "where((a>0)&(a<nodata), _u, a)",
+            "where((a>0.5)&(a<nodata), _u, a)",
             {"a": urban_masks},
             name="output_classes_artificial",
             dtype="float32",
@@ -197,7 +197,7 @@ def aggregate_results_from_group(self, urban_masks):
         )

         urban_masks = expr_eval(
-            "where(a<=0, _nu, a)",
+            "where(a<0.5, _nu, a)",
             {"a": urban_masks},
             name="output_classes_natrual",
             dtype="uint8",

odc/stats/plugins/lc_treelite_cultivated.py

Lines changed: 19 additions & 20 deletions
@@ -164,21 +164,6 @@ def generate_features(input_block, bands_indices):
        norm, 1e-8
    )  # Avoid division by zero, fine if it's nan

-    # reassemble the array
-    output_block = np.concatenate(
-        [output_block, input_block[..., bands_indices["sdev"]][..., np.newaxis]],
-        axis=-1,
-    ).astype("float32")
-    # scale edev \in [0, 1]
-    edev = input_block[..., bands_indices["edev"]] / 1e4
-    output_block = np.concatenate(
-        [output_block, edev[..., np.newaxis]], axis=-1
-    ).astype("float32")
-    output_block = np.concatenate(
-        [output_block, input_block[..., bands_indices["bcdev"]][..., np.newaxis]],
-        axis=-1,
-    ).astype("float32")
-
     feature_block = None
     for f, p in zip(
         [
@@ -193,15 +178,29 @@ def generate_features(input_block, bands_indices):
         ],
         feature_input_indices,
     ):
-        ib = f(output_block[..., : bands_indices["nbart_swir_2"] + 1], *p)
+        ib = f(output_block, *p)
         if feature_block is None:
             feature_block = ib[..., np.newaxis]
         else:
             feature_block = np.concatenate(
                 [feature_block, ib[..., np.newaxis]], axis=-1
             )
-
+    # reassemble the array
+    output_block = np.concatenate(
+        [output_block, input_block[..., bands_indices["sdev"]][..., np.newaxis]],
+        axis=-1,
+    ).astype("float32")
+    # scale edev \in [0, 1]
+    edev = input_block[..., bands_indices["edev"]] / 1e4
+    output_block = np.concatenate(
+        [output_block, edev[..., np.newaxis]], axis=-1
+    ).astype("float32")
+    output_block = np.concatenate(
+        [output_block, input_block[..., bands_indices["bcdev"]][..., np.newaxis]],
+        axis=-1,
+    ).astype("float32")
     output_block = np.concatenate([output_block, feature_block], axis=-1)
+
     selected_indices = np.r_[
         [
             bands_indices[k]
@@ -288,23 +287,23 @@ def aggregate_results_from_group(self, predict_output):
         predict_output = predict_output.sum(axis=0)

         predict_output = expr_eval(
-            "where((m/nodata)>=_l, nodata, m%nodata)",
+            "where((m/nodata)+0.5>=_l, nodata, m%nodata)",
             {"m": predict_output},
             name="mark_nodata",
             dtype="float32",
             **{"_l": m_size, "nodata": NODATA},
         )

         predict_output = expr_eval(
-            "where((m>0)&(m<nodata), _u, m)",
+            "where((m>0.5)&(m<nodata), _u, m)",
             {"m": predict_output},
             name="output_classes_cultivated",
             dtype="float32",
             **{"_u": self.output_classes["cultivated"], "nodata": NODATA},
         )

         predict_output = expr_eval(
-            "where(m<=0, _nu, m)",
+            "where(m<0.5, _nu, m)",
             {"m": predict_output},
             name="output_classes_natural",
             dtype="uint8",

odc/stats/plugins/lc_treelite_woody.py

Lines changed: 3 additions & 3 deletions
@@ -66,7 +66,7 @@ def aggregate_results_from_group(self, predict_output):
         predict_output = predict_output.sum(axis=0)

         predict_output = expr_eval(
-            "where((a/nodata)>=_l, nodata, a%nodata)",
+            "where((a/nodata)+0.5>=_l, nodata, a%nodata)",
             {"a": predict_output},
             name="summary_over_classes",
             dtype="float32",
@@ -77,15 +77,15 @@ def aggregate_results_from_group(self, predict_output):
         )

         predict_output = expr_eval(
-            "where((a>0)&(a<nodata), _nw, a)",
+            "where((a>0.5)&(a<nodata), _nw, a)",
             {"a": predict_output},
             name="output_classes_herbaceous",
             dtype="float32",
             **{"nodata": NODATA, "_nw": self.output_classes["herbaceous"]},
         )

         predict_output = expr_eval(
-            "where(a<=0, _nw, a)",
+            "where(a<0.5, _nw, a)",
             {"a": predict_output},
             name="output_classes_woody",
             dtype="uint8",
