Skip to content

Commit 50ed73b

Browse files
author
Emma Ai
committed
move columns of dependencies into config
1 parent daf6e19 commit 50ed73b

File tree

1 file changed

+22
-16
lines changed

1 file changed

+22
-16
lines changed

odc/stats/plugins/lc_level34.py

Lines changed: 22 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -2,20 +2,22 @@
22
Plugin of Module A3 in LandCover PipeLine
33
"""
44

5-
from typing import Optional, Dict
5+
from typing import Optional, Dict, List
66

77
import xarray as xr
88
import s3fs
99
import os
1010
import pandas as pd
1111
import dask.array as da
12+
import logging
1213

1314
from ._registry import StatsPluginInterface, register
1415
from ._utils import rasterize_vector_mask, generate_numexpr_expressions
1516
from odc.stats._algebra import expr_eval
1617
from osgeo import gdal
1718

1819
NODATA = 255
20+
_log = logging.getLogger(__name__)
1921

2022

2123
class StatsLccsLevel4(StatsPluginInterface):
@@ -27,6 +29,7 @@ class StatsLccsLevel4(StatsPluginInterface):
2729
def __init__(
2830
self,
2931
class_def_path: str = None,
32+
class_condition: Dict[str, List] = None,
3033
urban_mask: str = None,
3134
filter_expression: str = None,
3235
mask_threshold: Optional[float] = None,
@@ -43,6 +46,9 @@ def __init__(
4346
elif not os.path.exists(class_def_path):
4447
raise FileNotFoundError(f"{class_def_path} not found")
4548

49+
if class_condition is None:
50+
raise ValueError("Missing input to generate classification conditions")
51+
4652
if urban_mask is None:
4753
raise ValueError("Missing urban mask shapefile")
4854

@@ -54,8 +60,12 @@ def __init__(
5460
raise ValueError("Missing urban mask filter")
5561

5662
self.class_def = pd.read_csv(class_def_path)
57-
cols = list(self.class_def.columns[:6]) + list(self.class_def.columns[9:-6])
58-
self.class_def = self.class_def[cols].astype(str).fillna("nan")
63+
self.class_condition = class_condition
64+
cols = set()
65+
for k, v in self.class_condition.items():
66+
cols |= {k} | set(v)
67+
68+
self.class_def = self.class_def[list(cols)].astype(str).fillna("nan")
5969

6070
self.urban_mask = urban_mask
6171
self.filter_expression = filter_expression
@@ -77,6 +87,7 @@ def classification(self, xx, class_def, con_cols, class_col):
7787
res = da.full(xx.level_3_4.shape, 0, dtype="uint8")
7888

7989
for expression in expressions:
90+
_log.info(expression)
8091
local_dict.update({"res": res})
8192
res = expr_eval(
8293
expression,
@@ -98,9 +109,10 @@ def classification(self, xx, class_def, con_cols, class_col):
98109
return res
99110

100111
def reduce(self, xx: xr.Dataset) -> xr.Dataset:
101-
con_cols = ["level1", "artificial_surface", "cultivated"]
102112
class_col = "level3"
103-
level3 = self.classification(xx, self.class_def, con_cols, class_col)
113+
level3 = self.classification(
114+
xx, self.class_def, self.class_condition[class_col], class_col
115+
)
104116

105117
# apply urban mask
106118
# 215 -> 216 if urban_mask == 0
@@ -119,6 +131,8 @@ def reduce(self, xx: xr.Dataset) -> xr.Dataset:
119131
dtype="uint8",
120132
)
121133

134+
# append level3 to the input dataset so it can be used
135+
# to classify level4
122136
attrs = xx.attrs.copy()
123137
attrs["nodata"] = NODATA
124138
dims = xx.level_3_4.dims[1:]
@@ -127,18 +141,10 @@ def reduce(self, xx: xr.Dataset) -> xr.Dataset:
127141
level3.squeeze(), dims=dims, attrs=attrs, coords=coords
128142
)
129143

130-
con_cols = [
131-
"level1",
132-
"level3",
133-
"woody",
134-
"water_season",
135-
"water_frequency",
136-
"pv_pc_50",
137-
"bs_pc_50",
138-
]
139144
class_col = "level4"
140-
141-
level4 = self.classification(xx, self.class_def, con_cols, class_col)
145+
level4 = self.classification(
146+
xx, self.class_def, self.class_condition[class_col], class_col
147+
)
142148

143149
data_vars = {
144150
k: xr.DataArray(v, dims=dims, attrs=attrs)

0 commit comments

Comments
 (0)