Skip to content

Commit 97df5ad

Browse files
emmaai and Emma Ai authored
correct veg/non logic in landcover veg frequency (#176)
* correct veg/non logic
* fix veg frequency test

---------

Co-authored-by: Emma Ai <emma.ai@ga.gov.au>
1 parent d04b1c0 commit 97df5ad

File tree

3 files changed

+46
-56
lines changed

3 files changed

+46
-56
lines changed

odc/stats/plugins/lc_fc_wo_a0.py

Lines changed: 28 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -34,12 +34,14 @@ class StatsVegCount(StatsPluginInterface):
3434
def __init__(
3535
self,
3636
ue_threshold: Optional[int] = None,
37+
veg_threshold: Optional[int] = None,
3738
cloud_filters: Dict[str, Iterable[Tuple[str, int]]] = None,
3839
**kwargs,
3940
):
4041
super().__init__(input_bands=["water", "pv", "bs", "npv", "ue"], **kwargs)
4142

4243
self.ue_threshold = ue_threshold if ue_threshold is not None else 30
44+
self.veg_threshold = veg_threshold if veg_threshold is not None else 2
4345
self.cloud_filters = cloud_filters if cloud_filters is not None else {}
4446

4547
def native_transform(self, xx):
@@ -63,15 +65,6 @@ def native_transform(self, xx):
6365
# clear dry pixels
6466
clear = xx["water"].data == 0
6567

66-
# get "valid" wo pixels, both dry and wet used in veg_frequency
67-
wet_valid = expr_eval(
68-
"where(a|b, a, _nan)",
69-
{"a": wet, "b": valid},
70-
name="get_valid_pixels",
71-
dtype="float32",
72-
**{"_nan": np.nan},
73-
)
74-
7568
# get "clear" wo pixels, both dry and wet used in water_frequency
7669
wet_clear = expr_eval(
7770
"where(a|b, a, _nan)",
@@ -101,13 +94,7 @@ def native_transform(self, xx):
10194
dtype="float32",
10295
**{"_nan": np.nan},
10396
)
104-
wet_valid = expr_eval(
105-
"where(b>0, _nan, a)",
106-
{"a": wet_valid, "b": raw_mask.data},
107-
name="get_valid_pixels",
108-
dtype="float32",
109-
**{"_nan": np.nan},
110-
)
97+
11198
xx = xx.drop_vars(["water"])
11299

113100
# Pick out the fc pixels that have an unmixing error of less than the threshold
@@ -124,9 +111,6 @@ def native_transform(self, xx):
124111
xx = keep_good_only(xx, valid, nodata=NODATA)
125112
xx = to_float(xx, dtype="float32")
126113

127-
xx["wet_valid"] = xr.DataArray(
128-
wet_valid, dims=xx["pv"].dims, coords=xx["pv"].coords
129-
)
130114
xx["wet_clear"] = xr.DataArray(
131115
wet_clear, dims=xx["pv"].dims, coords=xx["pv"].coords
132116
)
@@ -135,16 +119,14 @@ def native_transform(self, xx):
135119

136120
def fuser(self, xx):
137121

138-
wet_valid = xx["wet_valid"]
139122
wet_clear = xx["wet_clear"]
140123

141124
xx = _xr_fuse(
142-
xx.drop_vars(["wet_valid", "wet_clear"]),
125+
xx.drop_vars(["wet_clear"]),
143126
partial(_fuse_mean_np, nodata=np.nan),
144127
"",
145128
)
146129

147-
xx["wet_valid"] = _nodata_fuser(wet_valid, nodata=np.nan)
148130
xx["wet_clear"] = _nodata_fuser(wet_clear, nodata=np.nan)
149131

150132
return xx
@@ -168,14 +150,6 @@ def _veg_or_not(self, xx: xr.Dataset):
168150
**{"nodata": int(NODATA)},
169151
)
170152

171-
# mark water freq >= 0.5 as 0
172-
data = expr_eval(
173-
"where(a>0, 0, b)",
174-
{"a": xx["wet_valid"].data, "b": data},
175-
name="get_veg",
176-
dtype="uint8",
177-
)
178-
179153
return data
180154

181155
def _water_or_not(self, xx: xr.Dataset):
@@ -262,8 +236,30 @@ def reduce(self, xx: xr.Dataset) -> xr.Dataset:
262236

263237
xx = xx.groupby("time.month").map(median_ds, dim="spec")
264238

265-
data = self._veg_or_not(xx)
266-
max_count_veg = self._max_consecutive_months(data, NODATA)
239+
# consecutive observation of veg
240+
veg_data = self._veg_or_not(xx)
241+
max_count_veg = self._max_consecutive_months(veg_data, NODATA)
242+
243+
# consecutive observation of non-veg
244+
non_veg_data = expr_eval(
245+
"where(a<nodata, 1-a, nodata)",
246+
{"a": veg_data},
247+
name="invert_veg",
248+
dtype="uint8",
249+
**{"nodata": NODATA},
250+
)
251+
max_count_non_veg = self._max_consecutive_months(non_veg_data, NODATA)
252+
253+
# non-veg < threshold implies veg >= threshold
254+
# implies any "wet" area potentially veg
255+
256+
max_count_veg = expr_eval(
257+
"where((a<_v)&(b<_v), _v, b)",
258+
{"a": max_count_non_veg, "b": max_count_veg},
259+
name="clip_veg",
260+
dtype="uint8",
261+
**{"_v": self.veg_threshold},
262+
)
267263

268264
data = self._water_or_not(xx)
269265
max_count_water = self._max_consecutive_months(data, NODATA, normalize=True)

odc/stats/plugins/lc_ml_treelite.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,9 @@ def __init__(
7676
self.dask_worker_plugin = TreeliteModelPlugin(model_path)
7777
self.output_classes = output_classes
7878
self.mask_bands = mask_bands
79-
self.temporal_coverage = temporal_coverage
79+
self.temporal_coverage = (
80+
temporal_coverage if temporal_coverage is not None else {}
81+
)
8082
self._log = logging.getLogger(__name__)
8183

8284
def input_data(
@@ -160,8 +162,8 @@ def convert_dtype(var):
160162

161163
for var in xx.data_vars:
162164
if var not in self.mask_bands:
163-
if self.temporal_coverage is not None:
164-
# filter and impute by sensors
165+
# filter and impute by sensors
166+
if self.temporal_coverage.get(var) is not None:
165167
temporal_range = [
166168
DateTimeRange(v) for v in self.temporal_coverage.get(var)
167169
]

tests/test_landcover_plugin_a0.py

Lines changed: 13 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -322,17 +322,8 @@ def test_native_transform(fc_wo_dataset, bits):
322322
stats_veg = StatsVegCount(measurements=["veg_frequency", "water_frequency"])
323323
out_xx = stats_veg.native_transform(xx).compute()
324324

325-
expected_valid = (
326-
np.array([0, 0, 0, 0, 0, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3]),
327-
np.array([1, 1, 3, 5, 6, 2, 6, 2, 2, 5, 6, 0, 0, 2, 3]),
328-
np.array([0, 3, 2, 1, 3, 5, 6, 1, 4, 5, 6, 0, 2, 4, 2]),
329-
)
330-
result = np.where(out_xx["wet_valid"].data == out_xx["wet_valid"].data)
331-
for a, b in zip(expected_valid, result):
332-
assert (a == b).all()
333-
334325
expected_valid = (np.array([1, 2, 3]), np.array([6, 2, 0]), np.array([6, 1, 2]))
335-
result = np.where(out_xx["wet_valid"].data == 1)
326+
result = np.where(out_xx["wet_clear"].data == 1)
336327

337328
for a, b in zip(expected_valid, result):
338329
assert (a == b).all()
@@ -374,11 +365,11 @@ def test_veg_or_not(fc_wo_dataset):
374365
xx = xx.groupby("solar_day").map(partial(StatsVegCount.fuser, None))
375366
yy = stats_veg._veg_or_not(xx).compute()
376367
valid_index = (
377-
np.array([0, 0, 1, 2, 2, 2, 2, 2]),
378-
np.array([1, 5, 6, 0, 0, 2, 2, 3]),
379-
np.array([0, 1, 6, 0, 2, 1, 4, 2]),
368+
np.array([0, 0, 2, 2, 2]),
369+
np.array([1, 5, 0, 2, 3]),
370+
np.array([0, 1, 0, 4, 2]),
380371
)
381-
expected_value = np.array([1, 1, 0, 1, 0, 0, 1, 1])
372+
expected_value = np.array([1, 1, 1, 1, 1])
382373
i = 0
383374
for idx in zip(*valid_index):
384375
assert yy[idx] == expected_value[i]
@@ -409,14 +400,15 @@ def test_reduce(fc_wo_dataset):
409400
xx = stats_veg.reduce(xx).compute()
410401
expected_value = np.array(
411402
[
412-
[1, 255, 0, 255, 255, 255, 255],
413-
[1, 255, 255, 255, 255, 255, 255],
414-
[255, 0, 255, 255, 1, 255, 255],
415-
[255, 255, 1, 255, 255, 255, 255],
403+
[2, 255, 255, 255, 255, 255, 255],
404+
[2, 255, 255, 255, 255, 255, 255],
405+
[255, 255, 255, 255, 2, 255, 255],
406+
[255, 255, 2, 255, 255, 255, 255],
416407
[255, 255, 255, 255, 255, 255, 255],
417-
[255, 1, 255, 255, 255, 255, 255],
418-
[255, 255, 255, 255, 255, 255, 0],
419-
]
408+
[255, 2, 255, 255, 255, 255, 255],
409+
[255, 255, 255, 255, 255, 255, 255],
410+
],
411+
dtype="uint8",
420412
)
421413

422414
assert (xx.veg_frequency.data == expected_value).all()

0 commit comments

Comments
 (0)