8
8
"""
9
9
10
10
from functools import partial
11
- from typing import Any , Dict , Iterable , Optional , Sequence , Tuple , Mapping
11
+ from typing import Sequence , Tuple , Mapping
12
12
13
- import datacube
14
13
import numpy as np
15
- import pandas as pd
16
14
import xarray as xr
17
- from datacube .model import Dataset
18
- from odc .geo .xr import assign_crs
19
- from odc .geo .geobox import GeoBox
20
- from odc .algo .io import load_with_native_transform
21
15
from odc .algo ._percentile import xr_quantile_bands
22
16
from odc .stats .plugins ._registry import register , StatsPluginInterface
23
- from odc .algo ._masking import (
24
- erase_bad ,
25
- enum_to_bool
26
- )
17
+ from odc .algo ._masking import erase_bad , enum_to_bool
18
+
27
19
28
20
class S2Cloudless_percentiles (StatsPluginInterface ):
29
21
NAME = "S2Cloudless_percentiles"
@@ -35,36 +27,37 @@ def __init__(
35
27
self ,
36
28
resampling : str = "cubic" ,
37
29
bands : Sequence [str ] = ["oa_s2cloudless_prob" ],
38
- output_bands : Sequence [str ] = ['oa_s2cloudless_prob_pc_5' , 'oa_s2cloudless_prob_pc_10' ,'oa_s2cloudless_prob_pc_25' ],
30
+ output_bands : Sequence [str ] = [
31
+ "oa_s2cloudless_prob_pc_5" ,
32
+ "oa_s2cloudless_prob_pc_10" ,
33
+ "oa_s2cloudless_prob_pc_25" ,
34
+ ],
39
35
mask_band : str = "oa_s2cloudless_mask" ,
40
36
chunks : Mapping [str , int ] = {"y" : 512 , "x" : 512 },
41
37
group_by : str = "solar_day" ,
42
- nodata_classes : Sequence [str ] = ["nodata" ],
38
+ nodata_classes : Sequence [str ] = ["nodata" ],
43
39
output_dtype : str = "float32" ,
44
40
** kwargs ,
45
41
):
46
-
47
- self .resampling = resampling
42
+
43
+ self .resampling = resampling
48
44
self .bands = bands
49
45
self .output_bands = output_bands
50
46
self .mask_band = mask_band
51
47
self .chunks = chunks
52
48
self .group_by = group_by
53
49
self .resampling = resampling
54
- self .nodata_classes = nodata_classes
50
+ self .nodata_classes = nodata_classes
55
51
self .output_dtype = np .dtype (output_dtype )
56
52
self .output_nodata = np .nan
57
53
58
54
super ().__init__ (
59
- input_bands = tuple (bands ),
60
- resampling = resampling ,
61
- chunks = chunks ,
62
- ** kwargs
55
+ input_bands = tuple (bands ), resampling = resampling , chunks = chunks , ** kwargs
63
56
)
64
57
65
58
@property
66
59
def measurements (self ) -> Tuple [str , ...]:
67
- return ( self .output_bands )
60
+ return self .output_bands
68
61
69
62
def native_transform (self , xx : xr .Dataset ) -> xr .Dataset :
70
63
"""
@@ -79,14 +72,13 @@ def native_transform(self, xx: xr.Dataset) -> xr.Dataset:
79
72
mask = xx [self .mask_band ]
80
73
bad = enum_to_bool (mask , self .nodata_classes )
81
74
82
- #drop mask band
75
+ # drop mask band
83
76
xx = xx .drop_vars ([self .mask_band ])
84
77
85
78
# apply the masks
86
79
xx = erase_bad (xx , bad )
87
80
88
81
return xx
89
-
90
82
91
83
def reduce (self , xx : xr .Dataset ) -> xr .Dataset :
92
84
"""
@@ -99,4 +91,5 @@ def reduce(self, xx: xr.Dataset) -> xr.Dataset:
99
91
100
92
return yy
101
93
94
+
102
95
register ("s2_gm_tools.S2Cloudless_percentiles" , S2Cloudless_percentiles )
0 commit comments