Skip to content

Commit 6a29afa

Browse files
emmaai and Emma Ai authored
Fix bugs in cultivated and woody cover plugins (#149)
* round the predict output to float32 resolution
* correct the data type in woody cover
* exit gracefully if band is missing for cultivated and woody
* fix woody cover aggregation
* add the hacky fix
* comment the docker file
* please docker lint

---------

Co-authored-by: Emma Ai <emma.ai@ga.gov.au>
1 parent db360b0 commit 6a29afa

File tree

4 files changed

+27
-4
lines changed

4 files changed

+27
-4
lines changed

docker/Dockerfile

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,11 @@ ENV GDAL_DRIVER_PATH=/env/lib/gdalplugins \
2222
GDAL_DATA=/env/share/gdal \
2323
PATH=/env/bin:$PATH
2424

25+
# this is a very hacky fix for the threading issue
26+
# MUST follow up with the package owner and address the issue properly
27+
28+
RUN wget -q -O /env/lib/python3.10/site-packages/numexpr/necompiler.py https://raw.githubusercontent.com/emmaai/numexpr/master/numexpr/necompiler.py
29+
2530
WORKDIR /tmp
2631

2732
RUN odc-stats --version

odc/stats/plugins/lc_ml_treelite.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from typing import Dict, Sequence, Optional
77

88
import os
9+
import sys
910
import numpy as np
1011
import numexpr as ne
1112
import xarray as xr
@@ -21,6 +22,7 @@
2122
from ._registry import StatsPluginInterface
2223
from ._worker import TreeliteModelPlugin
2324
import tl2cgen
25+
import logging
2426

2527

2628
def mask_and_predict(
@@ -44,6 +46,8 @@ def mask_and_predict(
4446
if block_masked.shape[0] > 0:
4547
dmat = tl2cgen.DMatrix(block_masked)
4648
output_data = predictor.predict(dmat).squeeze(axis=1)
49+
# round the number to float32 resolution
50+
output_data = np.round(output_data, 6)
4751
if ptype == "categorical":
4852
prediction[mask_flat] = output_data.argmax(axis=-1)[..., np.newaxis]
4953
else:
@@ -70,6 +74,7 @@ def __init__(
7074
self.dask_worker_plugin = TreeliteModelPlugin(model_path)
7175
self.output_classes = output_classes
7276
self.mask_bands = mask_bands
77+
self._log = logging.getLogger(__name__)
7378

7479
def input_data(
7580
self, datasets: Sequence[Dataset], geobox: GeoBox, **kwargs
@@ -117,6 +122,7 @@ def input_data(
117122

118123
def preprocess_predict_input(self, xx: xr.Dataset):
119124
images = []
125+
veg_mask = None
120126
for var in xx.data_vars:
121127
image = xx[var].data
122128
if var not in self.mask_bands:
@@ -140,6 +146,9 @@ def preprocess_predict_input(self, xx: xr.Dataset):
140146
**{"_v": int(self.mask_bands[var])},
141147
)
142148

149+
if veg_mask is None:
150+
raise TypeError("Missing Veg Mask")
151+
143152
images = [
144153
da.concatenate([image, veg_mask[..., np.newaxis]], axis=-1).rechunk(
145154
(None, None, image.shape[-1] + veg_mask.shape[-1])
@@ -157,7 +166,12 @@ def aggregate_results_from_group(self, predict_output):
157166
pass
158167

159168
def reduce(self, xx: xr.Dataset) -> xr.Dataset:
160-
images = self.preprocess_predict_input(xx)
169+
try:
170+
images = self.preprocess_predict_input(xx)
171+
except TypeError as e:
172+
self._log.warning(e)
173+
sys.exit(0)
174+
161175
res = []
162176

163177
for image in images:

odc/stats/plugins/lc_treelite_woody.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -63,13 +63,13 @@ def aggregate_results_from_group(self, predict_output):
6363
)
6464

6565
if m_size > 1:
66-
predict_output = predict_output.sum(axis=0).astype("int")
66+
predict_output = predict_output.sum(axis=0)
6767

6868
predict_output = expr_eval(
6969
"where((a/nodata)>=_l, nodata, a%nodata)",
7070
{"a": predict_output},
7171
name="summary_over_classes",
72-
dtype="uint8",
72+
dtype="float32",
7373
**{
7474
"_l": m_size,
7575
"nodata": NODATA,
@@ -80,7 +80,7 @@ def aggregate_results_from_group(self, predict_output):
8080
"where((a>0)&(a<nodata), _nw, a)",
8181
{"a": predict_output},
8282
name="output_classes_herbaceous",
83-
dtype="uint8",
83+
dtype="float32",
8484
**{"nodata": NODATA, "_nw": self.output_classes["herbaceous"]},
8585
)
8686

tests/test_rf_models.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -492,6 +492,10 @@ def test_cultivated_reduce(
492492
== np.array([[112, 255], [112, 112]], dtype="uint8")
493493
).all()
494494

495+
with pytest.raises(SystemExit) as excinfo:
496+
cultivated.reduce(input_datasets.drop("classes_l3_l4"))
497+
assert excinfo.value.code == 0
498+
495499

496500
def test_woody_aggregate_results(
497501
woody_input_bands,

0 commit comments

Comments
 (0)