Skip to content

Commit 8ca489c

Browse files
committed
add cloud probability percentiles plugin
1 parent 7f34c86 commit 8ca489c

File tree

1 file changed

+102
-0
lines changed

1 file changed

+102
-0
lines changed
Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
"""
2+
Long-term percentiles of S2Cloudless probabilities.
3+
4+
Useful for locating regions persistently misclassified as
5+
cloud by S2Cloudless, which is known to have a high false
6+
positive rate.
7+
8+
"""
9+
10+
from functools import partial
11+
from typing import Any, Dict, Iterable, Optional, Sequence, Tuple, Mapping
12+
13+
import datacube
14+
import numpy as np
15+
import pandas as pd
16+
import xarray as xr
17+
from datacube.model import Dataset
18+
from odc.geo.xr import assign_crs
19+
from odc.geo.geobox import GeoBox
20+
from odc.algo.io import load_with_native_transform
21+
from odc.algo._percentile import xr_quantile_bands
22+
from odc.stats.plugins._registry import register, StatsPluginInterface
23+
from odc.algo._masking import (
24+
erase_bad,
25+
enum_to_bool
26+
)
27+
28+
class S2Cloudless_percentiles(StatsPluginInterface):
29+
NAME = "S2Cloudless_percentiles"
30+
SHORT_NAME = NAME
31+
VERSION = "1.0.0"
32+
PRODUCT_FAMILY = "percentiles"
33+
34+
def __init__(
35+
self,
36+
resampling: str = "cubic",
37+
bands: Sequence[str] = ["oa_s2cloudless_prob"],
38+
output_bands: Sequence[str] = ['oa_s2cloudless_prob_pc_5', 'oa_s2cloudless_prob_pc_10','oa_s2cloudless_prob_pc_25'],
39+
mask_band: str = "oa_s2cloudless_mask",
40+
chunks: Mapping[str, int] = {"y": 512, "x": 512},
41+
group_by: str = "solar_day",
42+
nodata_classes: Sequence[str] = ["nodata"],
43+
output_dtype: str = "float32",
44+
**kwargs,
45+
):
46+
47+
self.resampling=resampling
48+
self.bands = bands
49+
self.output_bands = output_bands
50+
self.mask_band = mask_band
51+
self.chunks = chunks
52+
self.group_by = group_by
53+
self.resampling = resampling
54+
self.nodata_classes= nodata_classes
55+
self.output_dtype = np.dtype(output_dtype)
56+
self.output_nodata = np.nan
57+
58+
super().__init__(
59+
input_bands=tuple(bands),
60+
resampling=resampling,
61+
chunks=chunks,
62+
**kwargs
63+
)
64+
65+
@property
66+
def measurements(self) -> Tuple[str, ...]:
67+
return (self.output_bands)
68+
69+
def native_transform(self, xx: xr.Dataset) -> xr.Dataset:
70+
"""
71+
erases nodata
72+
"""
73+
74+
# step 1-----------------
75+
if self.mask_band not in xx.data_vars:
76+
return xx
77+
78+
# Erase Data Pixels for which mask == nodata
79+
mask = xx[self.mask_band]
80+
bad = enum_to_bool(mask, self.nodata_classes)
81+
82+
#drop mask band
83+
xx = xx.drop_vars([self.mask_band])
84+
85+
# apply the masks
86+
xx = erase_bad(xx, bad)
87+
88+
return xx
89+
90+
91+
def reduce(self, xx: xr.Dataset) -> xr.Dataset:
92+
"""
93+
Calculate the percentiles of long-term cloud probabilities
94+
95+
"""
96+
97+
# Compute the percentiles of long-term cloud probabilities.
98+
yy = xr_quantile_bands(xx, [0.05, 0.10, 0.25], nodata=np.nan)
99+
100+
return yy
101+
102+
register("s2_gm_tools.S2Cloudless_percentiles", S2Cloudless_percentiles)

0 commit comments

Comments
 (0)