Skip to content

Commit 3d41c9e

Browse files
author
Emma Ai
committed
change vector rasterization to gdal
1 parent 4435edd commit 3d41c9e

File tree

5 files changed

+136
-84
lines changed

5 files changed

+136
-84
lines changed

odc/stats/plugins/_utils.py

Lines changed: 34 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,42 @@
11
import dask
2-
import fiona
3-
from rasterio import features
4-
5-
6-
def rasterize_vector_mask(shape_file, transform, dst_shape, threshold=None):
7-
with fiona.open(shape_file) as source_ds:
8-
geoms = [s["geometry"] for s in source_ds]
9-
10-
mask = features.rasterize(
11-
geoms,
12-
transform=transform,
13-
out_shape=dst_shape[1:],
14-
all_touched=False,
15-
fill=0,
16-
default_value=1,
17-
dtype="uint8",
2+
from osgeo import gdal, ogr, osr
3+
4+
5+
def rasterize_vector_mask(
6+
shape_file, transform, dst_shape, filter_expression=None, threshold=None
7+
):
8+
source_ds = ogr.Open(shape_file)
9+
source_layer = source_ds.GetLayer()
10+
11+
if filter_expression is not None:
12+
source_layer.SetAttributeFilter(filter_expression)
13+
14+
yt, xt = dst_shape[1:]
15+
no_data = 0
16+
albers = osr.SpatialReference()
17+
albers.ImportFromEPSG(3577)
18+
19+
geotransform = (
20+
transform.c,
21+
transform.a,
22+
transform.b,
23+
transform.f,
24+
transform.d,
25+
transform.e,
1826
)
27+
target_ds = gdal.GetDriverByName("MEM").Create("", xt, yt, gdal.GDT_Byte)
28+
target_ds.SetGeoTransform(geotransform)
29+
target_ds.SetProjection(albers.ExportToWkt())
30+
mask = target_ds.GetRasterBand(1)
31+
mask.SetNoDataValue(no_data)
32+
gdal.RasterizeLayer(target_ds, [1], source_layer, burn_values=[1])
33+
34+
mask = mask.ReadAsArray()
1935

36+
# used by landcover level3 urban
2037
# if valid area >= threshold
2138
# then the whole tile is valid
39+
2240
if threshold is not None:
2341
if mask.sum() > mask.size * threshold:
2442
return dask.array.ones(dst_shape, name=False)

odc/stats/plugins/lc_level34.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ class StatsLccsLevel4(StatsPluginInterface):
3939
def __init__(
4040
self,
4141
urban_mask: str = None,
42+
filter_expression: str = None,
4243
mask_threshold: Optional[float] = None,
4344
veg_threshold: Optional[List] = None,
4445
bare_threshold: Optional[List] = None,
@@ -51,7 +52,12 @@ def __init__(
5152
raise ValueError("Missing urban mask shapefile")
5253
if not os.path.exists(urban_mask):
5354
raise FileNotFoundError(f"{urban_mask} not found")
55+
56+
if filter_expression is None:
57+
raise ValueError("Missing urban mask filter")
58+
5459
self.urban_mask = urban_mask
60+
self.filter_expression = filter_expression
5561
self.mask_threshold = mask_threshold
5662

5763
self.veg_threshold = (
@@ -90,8 +96,10 @@ def reduce(self, xx: xr.Dataset) -> xr.Dataset:
9096
self.urban_mask,
9197
xx.geobox.transform,
9298
xx.artificial_surface.shape,
99+
filter_expression=self.filter_expression,
93100
threshold=self.mask_threshold,
94101
)
102+
95103
level3 = lc_level3.lc_level3(xx, urban_mask)
96104

97105
# Vegetation cover

tests/conftest.py

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,12 @@
44
import boto3
55
from moto import mock_aws
66
from odc.stats.plugins import register
7+
import json
8+
import tempfile
9+
import os
10+
import fiona
11+
from fiona.crs import CRS
12+
713
from . import DummyPlugin
814

915
TEST_DIR = pathlib.Path(__file__).parent.absolute()
@@ -148,3 +154,64 @@ def usgs_ls8_sr_definition():
148154
],
149155
}
150156
return definition
157+
158+
159+
@pytest.fixture
160+
def urban_shape():
161+
data = """
162+
{
163+
"type":"FeatureCollection",
164+
"features":[
165+
{
166+
"geometry":{
167+
"type":"Polygon",
168+
"coordinates":[
169+
[
170+
[
171+
0,
172+
0
173+
],
174+
[
175+
0,
176+
100
177+
],
178+
[
179+
100,
180+
100
181+
],
182+
[
183+
100,
184+
0
185+
],
186+
[
187+
0,
188+
0
189+
]
190+
]
191+
]
192+
},
193+
"type":"Feature",
194+
"properties":
195+
{
196+
"name": "mock",
197+
"value": 10
198+
}
199+
}
200+
]
201+
}
202+
"""
203+
data = json.loads(data)["features"][0]
204+
tmpdir = tempfile.mkdtemp()
205+
filename = os.path.join(tmpdir, "test.json")
206+
with fiona.open(
207+
filename,
208+
"w",
209+
driver="GeoJSON",
210+
crs=CRS.from_epsg(3577),
211+
schema={
212+
"geometry": "Polygon",
213+
"properties": {"name": "str", "value": "int"},
214+
},
215+
) as dst:
216+
dst.write(data)
217+
return filename

tests/test_lc_l34.py

Lines changed: 4 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,6 @@
33
import pandas as pd
44
import xarray as xr
55
import dask.array as da
6-
import json
7-
import tempfile
8-
import os
9-
import fiona
10-
from fiona.crs import CRS
116
from datacube.utils.geometry import GeoBox
127
from affine import Affine
138

@@ -18,61 +13,6 @@
1813
NODATA = 255
1914

2015

21-
@pytest.fixture(scope="module")
22-
def urban_shape():
23-
data = """
24-
{
25-
"type":"FeatureCollection",
26-
"features":[
27-
{
28-
"geometry":{
29-
"type":"Polygon",
30-
"coordinates":[
31-
[
32-
[
33-
0,
34-
0
35-
],
36-
[
37-
0,
38-
100
39-
],
40-
[
41-
100,
42-
100
43-
],
44-
[
45-
100,
46-
0
47-
],
48-
[
49-
0,
50-
0
51-
]
52-
]
53-
]
54-
},
55-
"type":"Feature"
56-
}
57-
]
58-
}
59-
"""
60-
data = json.loads(data)["features"][0]
61-
tmpdir = tempfile.mkdtemp()
62-
filename = os.path.join(tmpdir, "test.json")
63-
with fiona.open(
64-
filename,
65-
"w",
66-
driver="GeoJSON",
67-
crs=CRS.from_epsg(3577),
68-
schema={
69-
"geometry": "Polygon",
70-
},
71-
) as dst:
72-
dst.write(data)
73-
return filename
74-
75-
7616
@pytest.fixture(scope="module")
7717
def image_groups():
7818
l34 = np.array(
@@ -220,7 +160,10 @@ def test_l4_classes(image_groups, urban_shape):
220160

221161
expected_l4 = [[95, 97, 93], [97, 96, 96], [93, 93, 93], [93, 93, 93]]
222162
stats_l4 = StatsLccsLevel4(
223-
measurements=["level3", "level4"], urban_mask=urban_shape, mask_threshold=0.3
163+
measurements=["level3", "level4"],
164+
urban_mask=urban_shape,
165+
filter_expression="mock > 9",
166+
mask_threshold=0.3,
224167
)
225168
ds = stats_l4.reduce(image_groups)
226169

tests/test_lc_level3.py

Lines changed: 23 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,10 @@
44
import dask.array as da
55

66
from odc.stats.plugins.l34_utils import lc_level3
7+
from odc.stats.plugins._utils import rasterize_vector_mask
8+
from datacube.utils.geometry import GeoBox
9+
from affine import Affine
10+
711
import pytest
812

913
NODATA = 255
@@ -58,18 +62,22 @@ def image_groups():
5862
(np.datetime64("2000-01-01T00"), np.datetime64("2000-01-01")),
5963
]
6064
index = pd.MultiIndex.from_tuples(tuples, names=["time", "solar_day"])
61-
coords = {
62-
"x": np.linspace(10, 20, l34.shape[2]),
63-
"y": np.linspace(0, 5, l34.shape[1]),
64-
}
65+
66+
affine = Affine.translation(10, 0) * Affine.scale(
67+
(20 - 10) / l34.shape[2], (5 - 0) / l34.shape[1]
68+
)
69+
geobox = GeoBox(
70+
crs="epsg:3577", affine=affine, width=l34.shape[2], height=l34.shape[1]
71+
)
72+
coords = geobox.xr_coords()
6573

6674
data_vars = {
6775
"classes_l3_l4": xr.DataArray(
6876
da.from_array(l34, chunks=(1, -1, -1)),
6977
dims=("spec", "y", "x"),
7078
attrs={"nodata": 255},
7179
),
72-
"urban_classes": xr.DataArray(
80+
"artificial_surface": xr.DataArray(
7381
da.from_array(urban, chunks=(1, -1, -1)),
7482
dims=("spec", "y", "x"),
7583
attrs={"nodata": 255},
@@ -85,7 +93,15 @@ def image_groups():
8593
return xx
8694

8795

88-
def test_l3_classes(image_groups):
96+
def test_l3_classes(image_groups, urban_shape):
97+
filter_expression = "mock > 9"
98+
urban_mask = rasterize_vector_mask(
99+
urban_shape,
100+
image_groups.geobox.transform,
101+
image_groups.artificial_surface.shape,
102+
filter_expression=filter_expression,
103+
threshold=0.3,
104+
)
89105

90-
level3_classes = lc_level3.lc_level3(image_groups)
106+
level3_classes = lc_level3.lc_level3(image_groups, urban_mask)
91107
assert (level3_classes == expected_l3_classes).all()

0 commit comments

Comments
 (0)