|
| 1 | +""" |
| 2 | +Long-term percentiles of S2Cloudless probabilities. |
| 3 | +
|
| 4 | +Useful for locating regions persistently misclassified as |
| 5 | +cloud by S2Cloudless, which is known to have a high false |
| 6 | +positive rate. |
| 7 | +
|
| 8 | +""" |
| 9 | + |
| 10 | +from functools import partial |
| 11 | +from typing import Any, Dict, Iterable, Optional, Sequence, Tuple, Mapping |
| 12 | + |
| 13 | +import datacube |
| 14 | +import numpy as np |
| 15 | +import pandas as pd |
| 16 | +import xarray as xr |
| 17 | +from datacube.model import Dataset |
| 18 | +from odc.geo.xr import assign_crs |
| 19 | +from odc.geo.geobox import GeoBox |
| 20 | +from odc.algo.io import load_with_native_transform |
| 21 | +from odc.algo._percentile import xr_quantile_bands |
| 22 | +from odc.stats.plugins._registry import register, StatsPluginInterface |
| 23 | +from odc.algo._masking import ( |
| 24 | + erase_bad, |
| 25 | + enum_to_bool |
| 26 | +) |
| 27 | + |
class S2Cloudless_percentiles(StatsPluginInterface):
    """
    Per-pixel low percentiles (5th, 10th, 25th) of S2Cloudless probabilities.

    Pixels whose mask band falls in one of ``nodata_classes`` are erased
    before the percentiles are computed over the loaded observations.
    """

    NAME = "S2Cloudless_percentiles"
    SHORT_NAME = NAME
    VERSION = "1.0.0"
    PRODUCT_FAMILY = "percentiles"

    def __init__(
        self,
        resampling: str = "cubic",
        bands: Optional[Sequence[str]] = None,
        output_bands: Optional[Sequence[str]] = None,
        mask_band: str = "oa_s2cloudless_mask",
        chunks: Optional[Mapping[str, int]] = None,
        group_by: str = "solar_day",
        nodata_classes: Optional[Sequence[str]] = None,
        output_dtype: str = "float32",
        **kwargs,
    ):
        """
        Configure the plugin.

        :param resampling: resampling method forwarded to the base plugin.
        :param bands: input probability bands; defaults to
            ``["oa_s2cloudless_prob"]``.
        :param output_bands: names reported by ``measurements``; defaults to
            the 5th/10th/25th percentile names of ``oa_s2cloudless_prob``.
        :param mask_band: categorical band used to find nodata pixels.
        :param chunks: chunk sizes forwarded to the base plugin; defaults to
            ``{"y": 512, "x": 512}``.
        :param group_by: grouping key stored on the instance.
        :param nodata_classes: mask categories to erase; defaults to
            ``["nodata"]``.
        :param output_dtype: dtype name, stored as a ``numpy.dtype``.
        :param kwargs: forwarded unchanged to ``StatsPluginInterface``.
        """
        # Defaults are built per call: mutable objects used as default
        # argument values would be shared by every instance of the plugin.
        if bands is None:
            bands = ["oa_s2cloudless_prob"]
        if output_bands is None:
            output_bands = [
                "oa_s2cloudless_prob_pc_5",
                "oa_s2cloudless_prob_pc_10",
                "oa_s2cloudless_prob_pc_25",
            ]
        if chunks is None:
            chunks = {"y": 512, "x": 512}
        if nodata_classes is None:
            nodata_classes = ["nodata"]

        self.resampling = resampling
        self.bands = bands
        self.output_bands = output_bands
        self.mask_band = mask_band
        self.chunks = chunks
        self.group_by = group_by
        self.nodata_classes = nodata_classes
        self.output_dtype = np.dtype(output_dtype)
        self.output_nodata = np.nan

        # NOTE(review): ``mask_band`` is not part of ``input_bands``. If the
        # base class only loads ``input_bands``, ``native_transform`` never
        # sees the mask and returns its input unchanged - confirm intent.
        # NOTE(review): ``group_by`` is not forwarded here; if the base
        # initializer assigns its own ``group_by``, the value stored above
        # may be overwritten - confirm against StatsPluginInterface.
        super().__init__(
            input_bands=tuple(bands),
            resampling=resampling,
            chunks=chunks,
            **kwargs
        )

    @property
    def measurements(self) -> Tuple[str, ...]:
        """Names of the output bands, as a tuple."""
        # Parentheses alone do not build a tuple; convert explicitly so the
        # returned value matches the annotation.
        return tuple(self.output_bands)

    def native_transform(self, xx: xr.Dataset) -> xr.Dataset:
        """
        Erase nodata pixels using the mask band, then drop the mask band.

        If ``mask_band`` is not among the dataset's variables the dataset is
        returned untouched.
        """
        if self.mask_band not in xx.data_vars:
            return xx

        # Flag pixels whose mask value is one of the nodata classes.
        mask = xx[self.mask_band]
        bad = enum_to_bool(mask, self.nodata_classes)

        # The mask band itself is not wanted in the output.
        xx = xx.drop_vars([self.mask_band])

        # Erase the flagged pixels in the remaining bands.
        xx = erase_bad(xx, bad)

        return xx

    def reduce(self, xx: xr.Dataset) -> xr.Dataset:
        """
        Calculate the percentiles of long-term cloud probabilities.

        Computes the 0.05, 0.10 and 0.25 quantiles of every band in ``xx``
        with NaN as the nodata value.
        """
        # Output variable names are chosen by ``xr_quantile_bands``;
        # presumably ``<band>_pc_<percent>``, matching the default
        # ``output_bands`` - TODO confirm if ``output_bands`` is overridden.
        yy = xr_quantile_bands(xx, [0.05, 0.10, 0.25], nodata=np.nan)

        return yy
| 101 | + |
# Register the plugin class under its fully qualified plugin name.
register(
    "s2_gm_tools.S2Cloudless_percentiles",
    S2Cloudless_percentiles,
)
0 commit comments