Skip to content

Commit f792ffa

Browse files
author
Emma Ai
committed
make output band name as config options
1 parent d45be2a commit f792ffa

9 files changed

+36
-58
lines changed

odc/stats/plugins/_base.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ def __init__(
2929
rgb_clamp: Tuple[float, float] = (1.0, 3_000.0),
3030
transform_code: Optional[str] = None,
3131
area_of_interest: Optional[Sequence[float]] = None,
32+
measurements: Optional[Sequence[str]] = None,
3233
):
3334
self.resampling = resampling
3435
self.input_bands = input_bands if input_bands is not None else []
@@ -40,12 +41,14 @@ def __init__(
4041
self.rgb_clamp = rgb_clamp
4142
self.transform_code = transform_code
4243
self.area_of_interest = area_of_interest
44+
self._measurements = measurements
4345
self.dask_worker_plugin = None
4446

4547
@property
46-
@abstractmethod
4748
def measurements(self) -> Tuple[str, ...]:
48-
pass
49+
if self._measurements is None:
50+
raise NotImplementedError("Plugins must provide 'measurements'")
51+
return self._measurements
4952

5053
def native_transform(self, xx: xr.Dataset) -> xr.Dataset:
5154
for var in xx.data_vars:

odc/stats/plugins/lc_fc_wo_a0.py

Lines changed: 2 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -42,11 +42,6 @@ def __init__(
4242
self.ue_threshold = ue_threshold if ue_threshold is not None else 30
4343
self.cloud_filters = cloud_filters if cloud_filters is not None else {}
4444

45-
@property
46-
def measurements(self) -> Tuple[str, ...]:
47-
_measurements = ["veg_frequency", "water_frequency"]
48-
return _measurements
49-
5045
def native_transform(self, xx):
5146
"""
5247
Loads data in its native projection. It performs the following:
@@ -217,12 +212,8 @@ def reduce(self, xx: xr.Dataset) -> xr.Dataset:
217212
attrs = xx.attrs.copy()
218213
attrs["nodata"] = int(NODATA)
219214
data_vars = {
220-
"veg_frequency": xr.DataArray(
221-
max_count_veg, dims=xx["wet"].dims[1:], attrs=attrs
222-
),
223-
"water_frequency": xr.DataArray(
224-
max_count_water, dims=xx["wet"].dims[1:], attrs=attrs
225-
),
215+
k: xr.DataArray(v, dims=xx["wet"].dims[1:], attrs=attrs)
216+
for k, v in zip(self.measurements, [max_count_veg, max_count_water])
226217
}
227218
coords = dict((dim, xx.coords[dim]) for dim in xx["wet"].dims[1:])
228219
return xr.Dataset(data_vars=data_vars, coords=coords, attrs=xx.attrs)

odc/stats/plugins/lc_tf_urban.py

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
Plugin of TF urban model in LandCover PipeLine
33
"""
44

5-
from typing import Tuple, Dict, Sequence
5+
from typing import Dict, Sequence
66

77
import os
88
import numpy as np
@@ -91,11 +91,6 @@ def __init__(
9191
else:
9292
self.crop_size = crop_size
9393

94-
@property
95-
def measurements(self) -> Tuple[str, ...]:
96-
_measurements = ["urban_classes"]
97-
return _measurements
98-
9994
def input_data(
10095
self, datasets: Sequence[Dataset], geobox: GeoBox, **kwargs
10196
) -> xr.Dataset:
@@ -219,7 +214,7 @@ def reduce(self, xx: xr.Dataset) -> xr.Dataset:
219214
attrs = xx.attrs.copy()
220215
attrs["nodata"] = int(NODATA)
221216
dims = list(xx.dims.keys())[:2]
222-
data_vars = {"urban_classes": xr.DataArray(um, dims=dims, attrs=attrs)}
217+
data_vars = {self.measurements[0]: xr.DataArray(um, dims=dims, attrs=attrs)}
223218
coords = {dim: xx.coords[dim] for dim in dims}
224219
return xr.Dataset(data_vars=data_vars, coords=coords, attrs=attrs)
225220

odc/stats/plugins/lc_treelite_cultivated.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
Plugin of RFclassfication cultivated model in LandCover PipeLine
33
"""
44

5-
from typing import Tuple
65
import numpy as np
76
import xarray as xr
87
import dask.array as da
@@ -226,11 +225,6 @@ class StatsCultivatedClass(StatsMLTree):
226225
VERSION = "0.0.1"
227226
PRODUCT_FAMILY = "lccs"
228227

229-
@property
230-
def measurements(self) -> Tuple[str, ...]:
231-
_measurements = ["cultivated"]
232-
return _measurements
233-
234228
def predict(self, input_array):
235229
bands_indices = dict(zip(self.input_bands, np.arange(len(self.input_bands))))
236230
input_features = da.map_blocks(

odc/stats/plugins/lc_treelite_woody.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
Plugin of RFregressor woody cover model in LandCover PipeLine
33
"""
44

5-
from typing import Tuple
65
import xarray as xr
76
import dask.array as da
87

@@ -19,11 +18,6 @@ class StatsWoodyCover(StatsMLTree):
1918
VERSION = "0.0.1"
2019
PRODUCT_FAMILY = "lccs"
2120

22-
@property
23-
def measurements(self) -> Tuple[str, ...]:
24-
_measurements = ["woody"]
25-
return _measurements
26-
2721
def predict(self, input_array):
2822
wc = da.map_blocks(
2923
mask_and_predict,

odc/stats/plugins/lc_veg_class_a1.py

Lines changed: 4 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -25,13 +25,6 @@ def __init__(
2525
**kwargs,
2626
):
2727
super().__init__(**kwargs)
28-
self._measurements = (
29-
measurements if measurements is not None else self.input_bands
30-
)
31-
32-
@property
33-
def measurements(self) -> Tuple[str, ...]:
34-
return self._measurements
3528

3629
def native_transform(self, xx):
3730
# reproject cannot work with nodata being int for float
@@ -89,11 +82,6 @@ def __init__(
8982
)
9083
self.output_classes = output_classes
9184

92-
@property
93-
def measurements(self) -> Tuple[str, ...]:
94-
_measurements = ["classes_l3_l4", "water_seasonality"]
95-
return _measurements
96-
9785
def fuser(self, xx):
9886
return xx
9987

@@ -249,12 +237,10 @@ def reduce(self, xx: xr.Dataset) -> xr.Dataset:
249237
attrs = xx.attrs.copy()
250238
attrs["nodata"] = int(NODATA)
251239
data_vars = {
252-
"classes_l3_l4": xr.DataArray(
253-
l3_mask[0], dims=xx["veg_frequency"].dims[1:], attrs=attrs
254-
),
255-
"water_seasonality": xr.DataArray(
256-
water_seasonality[0], dims=xx["veg_frequency"].dims[1:], attrs=attrs
257-
),
240+
k: xr.DataArray(v, dims=xx["veg_frequency"].dims[1:], attrs=attrs)
241+
for k, v in zip(
242+
self.measurements, [l3_mask.squeeze(0), water_seasonality.squeeze(0)]
243+
)
258244
}
259245
coords = dict((dim, xx.coords[dim]) for dim in xx["veg_frequency"].dims[1:])
260246
return xr.Dataset(data_vars=data_vars, coords=coords, attrs=xx.attrs)

tests/test_landcover_plugin_a0.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -319,7 +319,7 @@ def fc_wo_dataset():
319319
def test_native_transform(fc_wo_dataset, bits):
320320
xx = fc_wo_dataset.copy()
321321
xx["water"] = da.bitwise_or(xx["water"], bits)
322-
stats_veg = StatsVegCount()
322+
stats_veg = StatsVegCount(measurements=["veg_frequency", "water_frequency"])
323323
out_xx = stats_veg.native_transform(xx).compute()
324324

325325
expected_valid = (
@@ -349,7 +349,7 @@ def test_native_transform(fc_wo_dataset, bits):
349349

350350

351351
def test_fusing(fc_wo_dataset):
352-
stats_veg = StatsVegCount()
352+
stats_veg = StatsVegCount(measurements=["veg_frequency", "water_frequency"])
353353
xx = stats_veg.native_transform(fc_wo_dataset)
354354
xx = xx.groupby("solar_day").map(partial(StatsVegCount.fuser, None)).compute()
355355
valid_index = (
@@ -369,7 +369,7 @@ def test_fusing(fc_wo_dataset):
369369

370370

371371
def test_veg_or_not(fc_wo_dataset):
372-
stats_veg = StatsVegCount()
372+
stats_veg = StatsVegCount(measurements=["veg_frequency", "water_frequency"])
373373
xx = stats_veg.native_transform(fc_wo_dataset)
374374
xx = xx.groupby("solar_day").map(partial(StatsVegCount.fuser, None))
375375
yy = stats_veg._veg_or_not(xx).compute()
@@ -386,7 +386,7 @@ def test_veg_or_not(fc_wo_dataset):
386386

387387

388388
def test_water_or_not(fc_wo_dataset):
389-
stats_veg = StatsVegCount()
389+
stats_veg = StatsVegCount(measurements=["veg_frequency", "water_frequency"])
390390
xx = stats_veg.native_transform(fc_wo_dataset)
391391
xx = xx.groupby("solar_day").map(partial(StatsVegCount.fuser, None))
392392
yy = stats_veg._water_or_not(xx).compute()
@@ -403,7 +403,7 @@ def test_water_or_not(fc_wo_dataset):
403403

404404

405405
def test_reduce(fc_wo_dataset):
406-
stats_veg = StatsVegCount()
406+
stats_veg = StatsVegCount(measurements=["veg_frequency", "water_frequency"])
407407
xx = stats_veg.native_transform(fc_wo_dataset)
408408
xx = xx.groupby("solar_day").map(partial(StatsVegCount.fuser, None))
409409
xx = stats_veg.reduce(xx).compute()
@@ -437,7 +437,7 @@ def test_reduce(fc_wo_dataset):
437437

438438

439439
def test_consecutive_month(consecutive_count):
440-
stats_veg = StatsVegCount()
440+
stats_veg = StatsVegCount(measurements=["veg_frequency", "water_frequency"])
441441
xx = stats_veg._max_consecutive_months(consecutive_count, 255).compute()
442442
expected_value = np.array(
443443
[

tests/test_landcover_plugin_a1.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,7 @@ def test_l3_classes(dataset):
135135
"surface": 210,
136136
},
137137
optional_bands=["canopy_cover_class", "elevation"],
138+
measurements=["level_3_4", "water_season"],
138139
)
139140

140141
expected_res = np.array(
@@ -163,6 +164,7 @@ def test_l4_water_seasonality(dataset):
163164
"surface": 210,
164165
},
165166
optional_bands=["canopy_cover_class", "elevation"],
167+
measurements=["level_3_4", "water_season"],
166168
)
167169

168170
wo_fq = np.array(
@@ -208,6 +210,7 @@ def test_reduce(dataset):
208210
"surface": 210,
209211
},
210212
optional_bands=["canopy_cover_class", "elevation"],
213+
measurements=["level_3_4", "water_season"],
211214
)
212215
res = stats_l3.reduce(dataset)
213216

tests/test_rf_models.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -421,6 +421,7 @@ def test_preprocess_predict_intput(
421421
cultivated_model_path,
422422
mask_bands,
423423
input_bands=cultivated_input_bands,
424+
measurements=["cultivated"],
424425
)
425426
res = cultivated.preprocess_predict_input(input_datasets)
426427
for r in res:
@@ -440,6 +441,7 @@ def test_cultivated_predict(
440441
cultivated_model_path,
441442
mask_bands,
442443
input_bands=cultivated_input_bands,
444+
measurements=["cultivated"],
443445
)
444446
dask_client.register_plugin(cultivated.dask_worker_plugin)
445447
imgs = cultivated.preprocess_predict_input(input_datasets)
@@ -462,6 +464,7 @@ def test_cultivated_aggregate_results(
462464
cultivated_model_path,
463465
mask_bands,
464466
input_bands=cultivated_input_bands,
467+
measurements=["cultivated"],
465468
)
466469
res = cultivated.aggregate_results_from_group([cultivated_results[0]])
467470
assert (res.compute() == np.array([[112, 255], [111, 112]], dtype="uint8")).all()
@@ -482,6 +485,7 @@ def test_cultivated_reduce(
482485
cultivated_model_path,
483486
mask_bands,
484487
input_bands=cultivated_input_bands,
488+
measurements=["cultivated"],
485489
)
486490
dask_client.register_plugin(cultivated.dask_worker_plugin)
487491
res = cultivated.reduce(input_datasets)
@@ -506,7 +510,11 @@ def test_woody_aggregate_results(
506510
):
507511

508512
woody_cover = StatsWoodyCover(
509-
woody_classes, woody_model_path, mask_bands, input_bands=woody_input_bands
513+
woody_classes,
514+
woody_model_path,
515+
mask_bands,
516+
input_bands=woody_input_bands,
517+
measurements=["woody"],
510518
)
511519
res = woody_cover.aggregate_results_from_group([woody_results[0]])
512520
assert (res.compute() == np.array([[113, 255], [114, 113]], dtype="uint8")).all()
@@ -524,7 +532,11 @@ def test_woody_reduce(
524532
):
525533
woody_inputs = input_datasets.sel(bands=woody_input_bands[:-1])
526534
woody_cover = StatsWoodyCover(
527-
woody_classes, woody_model_path, mask_bands, input_bands=woody_input_bands
535+
woody_classes,
536+
woody_model_path,
537+
mask_bands,
538+
input_bands=woody_input_bands,
539+
measurements=["woody"],
528540
)
529541
dask_client.register_plugin(woody_cover.dask_worker_plugin)
530542
res = woody_cover.reduce(woody_inputs)

0 commit comments

Comments
 (0)