Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 44 additions & 19 deletions odc/stats/plugins/lc_fc_wo_a0.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ class StatsVegCount(StatsPluginInterface):
VERSION = "0.0.1"
PRODUCT_FAMILY = "lccs"

BAD_BITS_MASK = dict(cloud=(1 << 6), cloud_shadow=(1 << 5))
BAD_BITS_MASK = {"cloud": (1 << 6), "cloud_shadow": (1 << 5)}

def __init__(
self,
Expand All @@ -44,7 +44,7 @@ def __init__(

@property
def measurements(self) -> Tuple[str, ...]:
_measurements = ["veg_frequency"]
_measurements = ["veg_frequency", "water_frequency"]
return _measurements

def native_transform(self, xx):
Expand Down Expand Up @@ -80,10 +80,10 @@ def native_transform(self, xx):
# get valid wo pixels, both dry and wet
data = expr_eval(
"where(a|b, a, _nan)",
dict(a=wet.data, b=valid.data),
{"a": wet.data, "b": valid.data},
name="get_valid_pixels",
dtype="float32",
**dict(_nan=np.nan),
**{"_nan": np.nan},
)

# Pick out the fc pixels that have an unmixing error of less than the threshold
Expand Down Expand Up @@ -111,30 +111,49 @@ def _veg_or_not(self, xx: xr.Dataset):
# otherwise 0
data = expr_eval(
"where((a>b)|(c>b), 1, 0)",
dict(a=xx["pv"].data, c=xx["npv"].data, b=xx["bs"].data),
{"a": xx["pv"].data, "c": xx["npv"].data, "b": xx["bs"].data},
name="get_veg",
dtype="uint8",
)

# mark nans
data = expr_eval(
"where(a!=a, nodata, b)",
dict(a=xx["pv"].data, b=data),
{"a": xx["pv"].data, "b": data},
name="get_veg",
dtype="uint8",
**dict(nodata=int(NODATA)),
**{"nodata": int(NODATA)},
)

# mark water freq >= 0.5 as 0
data = expr_eval(
"where(a>0, 0, b)",
dict(a=xx["wet"].data, b=data),
{"a": xx["wet"].data, "b": data},
name="get_veg",
dtype="uint8",
)

return data

def _water_or_not(self, xx: xr.Dataset):
# mark water freq > 0.5 as 1
data = expr_eval(
"where(a>0.5, 1, 0)",
{"a": xx["wet"].data},
name="get_water",
dtype="uint8",
)

# mark nans
data = expr_eval(
"where(a!=a, nodata, b)",
{"a": xx["wet"].data, "b": data},
name="get_water",
dtype="uint8",
**{"nodata": int(NODATA)},
)
return data

def _max_consecutive_months(self, data, nodata):
nan_mask = da.ones(data.shape[1:], chunks=data.chunks[1:], dtype="bool")
tmp = da.zeros(data.shape[1:], chunks=data.chunks[1:], dtype="uint8")
Expand All @@ -144,44 +163,44 @@ def _max_consecutive_months(self, data, nodata):
# +1 if not nodata
tmp = expr_eval(
"where(a==nodata, b, a+b)",
dict(a=t, b=tmp),
{"a": t, "b": tmp},
name="compute_consecutive_month",
dtype="uint8",
**dict(nodata=nodata),
**{"nodata": nodata},
)

# save the max
max_count = expr_eval(
"where(a>b, a, b)",
dict(a=max_count, b=tmp),
{"a": max_count, "b": tmp},
name="compute_consecutive_month",
dtype="uint8",
)

# reset if not veg
tmp = expr_eval(
"where((a<=0), 0, b)",
dict(a=t, b=tmp),
{"a": t, "b": tmp},
name="compute_consecutive_month",
dtype="uint8",
)

# mark nodata
nan_mask = expr_eval(
"where(a==nodata, b, False)",
dict(a=t, b=nan_mask),
{"a": t, "b": nan_mask},
name="mark_nodata",
dtype="bool",
**dict(nodata=nodata),
**{"nodata": nodata},
)

# mark nodata
max_count = expr_eval(
"where(a, nodata, b)",
dict(a=nan_mask, b=max_count),
{"a": nan_mask, "b": max_count},
name="mark_nodata",
dtype="uint8",
**dict(nodata=int(nodata)),
**{"nodata": int(nodata)},
)
return max_count

Expand All @@ -190,14 +209,20 @@ def reduce(self, xx: xr.Dataset) -> xr.Dataset:
xx = xx.groupby("time.month").map(median_ds, dim="spec")

data = self._veg_or_not(xx)
max_count = self._max_consecutive_months(data, NODATA)
max_count_veg = self._max_consecutive_months(data, NODATA)

data = self._water_or_not(xx)
max_count_water = self._max_consecutive_months(data, NODATA)

attrs = xx.attrs.copy()
attrs["nodata"] = int(NODATA)
data_vars = {
"veg_frequency": xr.DataArray(
max_count, dims=xx["wet"].dims[1:], attrs=attrs
)
max_count_veg, dims=xx["wet"].dims[1:], attrs=attrs
),
"water_frequency": xr.DataArray(
max_count_water, dims=xx["wet"].dims[1:], attrs=attrs
),
}
coords = dict((dim, xx.coords[dim]) for dim in xx["wet"].dims[1:])
return xr.Dataset(data_vars=data_vars, coords=coords, attrs=xx.attrs)
Expand Down
31 changes: 31 additions & 0 deletions tests/test_landcover_plugin_a0.py
Original file line number Diff line number Diff line change
Expand Up @@ -381,6 +381,23 @@ def test_veg_or_not(fc_wo_dataset):
i += 1


def test_water_or_not(fc_wo_dataset):
stats_veg = StatsVegCount()
xx = stats_veg.native_transform(fc_wo_dataset)
xx = xx.groupby("solar_day").map(partial(StatsVegCount.fuser, None))
yy = stats_veg._water_or_not(xx).compute()
valid_index = (
np.array([0, 0, 0, 0, 0, 1, 1, 2, 2, 2, 2, 2, 2, 2]),
np.array([1, 1, 3, 5, 6, 2, 6, 0, 0, 2, 2, 3, 5, 6]),
np.array([0, 3, 2, 1, 3, 5, 6, 0, 2, 1, 4, 2, 5, 6]),
)
expected_value = np.array([0, 0, 0, 0, 0, 0, 1, 0, 1, 1, 0, 0, 0, 0])
i = 0
for idx in zip(*valid_index):
assert yy[idx] == expected_value[i]
i += 1


def test_reduce(fc_wo_dataset):
stats_veg = StatsVegCount()
xx = stats_veg.native_transform(fc_wo_dataset)
Expand All @@ -400,6 +417,20 @@ def test_reduce(fc_wo_dataset):

assert (xx.veg_frequency.data == expected_value).all()

expected_value = np.array(
[
[0, 255, 1, 255, 255, 255, 255],
[0, 255, 255, 0, 255, 255, 255],
[255, 1, 255, 255, 0, 0, 255],
[255, 255, 0, 255, 255, 255, 255],
[255, 255, 255, 255, 255, 255, 255],
[255, 0, 255, 255, 255, 0, 255],
[255, 255, 255, 0, 255, 255, 1],
]
)

assert (xx.water_frequency.data == expected_value).all()


def test_consecutive_month(consecutive_count):
stats_veg = StatsVegCount()
Expand Down
Loading