2
2
Plugin of Module A3 in LandCover PipeLine
3
3
"""
4
4
5
- from typing import Optional , Dict
5
+ from typing import Optional , Dict , List
6
6
7
7
import xarray as xr
8
8
import s3fs
9
9
import os
10
10
import pandas as pd
11
11
import dask .array as da
12
+ import logging
12
13
13
14
from ._registry import StatsPluginInterface , register
14
15
from ._utils import rasterize_vector_mask , generate_numexpr_expressions
15
16
from odc .stats ._algebra import expr_eval
16
17
from osgeo import gdal
17
18
18
19
NODATA = 255
20
+ _log = logging .getLogger (__name__ )
19
21
20
22
21
23
class StatsLccsLevel4 (StatsPluginInterface ):
@@ -27,6 +29,7 @@ class StatsLccsLevel4(StatsPluginInterface):
27
29
def __init__ (
28
30
self ,
29
31
class_def_path : str = None ,
32
+ class_condition : Dict [str , List ] = None ,
30
33
urban_mask : str = None ,
31
34
filter_expression : str = None ,
32
35
mask_threshold : Optional [float ] = None ,
@@ -43,6 +46,9 @@ def __init__(
43
46
elif not os .path .exists (class_def_path ):
44
47
raise FileNotFoundError (f"{ class_def_path } not found" )
45
48
49
+ if class_condition is None :
50
+ raise ValueError ("Missing input to generate classification conditions" )
51
+
46
52
if urban_mask is None :
47
53
raise ValueError ("Missing urban mask shapefile" )
48
54
@@ -54,8 +60,12 @@ def __init__(
54
60
raise ValueError ("Missing urban mask filter" )
55
61
56
62
self .class_def = pd .read_csv (class_def_path )
57
- cols = list (self .class_def .columns [:6 ]) + list (self .class_def .columns [9 :- 6 ])
58
- self .class_def = self .class_def [cols ].astype (str ).fillna ("nan" )
63
+ self .class_condition = class_condition
64
+ cols = set ()
65
+ for k , v in self .class_condition .items ():
66
+ cols |= {k } | set (v )
67
+
68
+ self .class_def = self .class_def [list (cols )].astype (str ).fillna ("nan" )
59
69
60
70
self .urban_mask = urban_mask
61
71
self .filter_expression = filter_expression
@@ -77,6 +87,7 @@ def classification(self, xx, class_def, con_cols, class_col):
77
87
res = da .full (xx .level_3_4 .shape , 0 , dtype = "uint8" )
78
88
79
89
for expression in expressions :
90
+ _log .info (expression )
80
91
local_dict .update ({"res" : res })
81
92
res = expr_eval (
82
93
expression ,
@@ -98,9 +109,10 @@ def classification(self, xx, class_def, con_cols, class_col):
98
109
return res
99
110
100
111
def reduce (self , xx : xr .Dataset ) -> xr .Dataset :
101
- con_cols = ["level1" , "artificial_surface" , "cultivated" ]
102
112
class_col = "level3"
103
- level3 = self .classification (xx , self .class_def , con_cols , class_col )
113
+ level3 = self .classification (
114
+ xx , self .class_def , self .class_condition [class_col ], class_col
115
+ )
104
116
105
117
# apply urban mask
106
118
# 215 -> 216 if urban_mask == 0
@@ -119,6 +131,8 @@ def reduce(self, xx: xr.Dataset) -> xr.Dataset:
119
131
dtype = "uint8" ,
120
132
)
121
133
134
+ # append level3 to the input dataset so it can be used
135
+ # to classify level4
122
136
attrs = xx .attrs .copy ()
123
137
attrs ["nodata" ] = NODATA
124
138
dims = xx .level_3_4 .dims [1 :]
@@ -127,18 +141,10 @@ def reduce(self, xx: xr.Dataset) -> xr.Dataset:
127
141
level3 .squeeze (), dims = dims , attrs = attrs , coords = coords
128
142
)
129
143
130
- con_cols = [
131
- "level1" ,
132
- "level3" ,
133
- "woody" ,
134
- "water_season" ,
135
- "water_frequency" ,
136
- "pv_pc_50" ,
137
- "bs_pc_50" ,
138
- ]
139
144
class_col = "level4"
140
-
141
- level4 = self .classification (xx , self .class_def , con_cols , class_col )
145
+ level4 = self .classification (
146
+ xx , self .class_def , self .class_condition [class_col ], class_col
147
+ )
142
148
143
149
data_vars = {
144
150
k : xr .DataArray (v , dims = dims , attrs = attrs )
0 commit comments