Skip to content

Commit fb9be6c

Browse files
authored
apply_neighborhood: support changing band names via apply_metadata (#1158)
* apply_neighborhood: support changing band names via apply_metadata #1155 * apply_dimension: support changing band names via apply_metadata #1155 * apply_dimension: support changing band names via apply_metadata #1155 * apply_metadata only update metadata in more specific case #1155
1 parent 7cf3857 commit fb9be6c

File tree

5 files changed

+96
-2
lines changed

5 files changed

+96
-2
lines changed

CHANGELOG.md

+2
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,10 @@ without compromising stable operations.
1212

1313
<!-- start-of-changelog -->
1414

15+
1516
## In progress: 0.66.0
1617

18+
- `apply_neighborhood`/`apply_dimension`: support changing band names via `apply_metadata` ([#1155](https://github.com/Open-EO/openeo-geopyspark-driver/issues/1155))
1719

1820
## 0.65.0
1921

openeogeotrellis/geopysparkdatacube.py

+36-2
Original file line numberDiff line numberDiff line change
@@ -630,7 +630,10 @@ def partitionByKey(spatialkey):
630630
_log.info(f"apply_neighborhood created datacube {metadata}")
631631
return gps.TiledRasterLayer.from_numpy_rdd(gps.LayerType.SPACETIME, numpy_rdd, metadata)
632632

633-
return self.apply_to_levels(partial(rdd_function, self.metadata))
633+
updated_cube_metadata = self.metadata
634+
if ("apply_metadata" in udf_code and runtime.lower() is not "python-jep"):
635+
updated_cube_metadata = self.apply_metadata(udf_code, udf_context)
636+
return self.apply_to_levels(partial(rdd_function, self.metadata), updated_cube_metadata)
634637

635638
@callsite
636639
def chunk_polygon(
@@ -828,8 +831,11 @@ def tile_function(metadata: Metadata,
828831
metadata = GeopysparkDataCube._transform_metadata(rdd.layer_metadata, cellType=CellType.FLOAT32)
829832
return gps.TiledRasterLayer.from_numpy_rdd(rdd.layer_type, numpy_rdd, metadata)
830833

834+
updated_cube_metadata = self.metadata
835+
if ("apply_metadata" in udf_code):
836+
updated_cube_metadata = self.apply_metadata(udf_code, context)
831837
# Apply the UDF to every tile for every zoom level of the pyramid.
832-
return self.apply_to_levels(partial(rdd_function, self.metadata))
838+
return self.apply_to_levels(partial(rdd_function, self.metadata),updated_cube_metadata)
833839

834840
def aggregate_time(self, temporal_window, aggregationfunction) -> Series :
835841
#group keys
@@ -1141,6 +1147,34 @@ def apply_kernel(self, kernel: np.ndarray, factor=1, border=0, replace_invalid=0
11411147
lambda rdd, level: pysc._jvm.org.openeo.geotrellis.OpenEOProcesses().apply_kernel_spatial(rdd,geotrellis_tile))
11421148
return result_collection
11431149

1150+
1151+
def apply_metadata(self, udf_code, context):
    """
    Transform this cube's metadata with the ``apply_metadata`` function defined in a UDF.

    The UDF source is loaded and its ``apply_metadata`` is invoked from within a
    one-element Spark RDD ``map``; the single result is collected and validated.

    :param udf_code: UDF source code, expected to define a callable named ``apply_metadata``.
    :param context: context object passed as second argument to the UDF's ``apply_metadata``.
    :return: the metadata object returned by the UDF's ``apply_metadata``.
    :raises ValueError: if the returned object is not a ``GeopysparkCubeMetadata``
        or no longer has a band dimension. A missing ``apply_metadata`` function also
        raises ``ValueError``, but from inside the Spark task.
    """
    pysc = gps.get_spark_context()
    metadata = self.metadata
    _log.info(f"run_udf: detected use of apply_metadata to transform: {self.metadata}")

    def get_metadata(x):
        # `x` is the dummy RDD element; it only exists to trigger a single invocation.
        from openeo.udf.run_code import load_module_from_string

        module = load_module_from_string(udf_code)
        apply_metadata_func = next(
            (v for k, v in module.items() if k == "apply_metadata" and callable(v)),
            None,
        )
        if apply_metadata_func is None:
            raise ValueError("run_udf: apply_metadata function not found in the provided code.")
        return apply_metadata_func(metadata, context)

    metadata_list = pysc.parallelize([0]).map(get_metadata).collect()
    result_metadata: GeopysparkCubeMetadata = metadata_list[0]

    _log.info(f"run_udf: apply_metadata resulted in {result_metadata}")
    # Check the type first: calling has_band_dimension() on e.g. None would raise
    # AttributeError and hide the intended error message.
    if not isinstance(result_metadata, GeopysparkCubeMetadata):
        raise ValueError(f"run_udf: apply_metadata function should retain the type of the input metadata object, received: {result_metadata}.")
    if not result_metadata.has_band_dimension():
        raise ValueError(f"run_udf: apply_metadata function should not remove the band dimension, received metadata: {result_metadata}.")

    return result_metadata
1176+
1177+
11441178
@callsite
11451179
def apply_neighborhood(
11461180
self, process: dict, *, size: List[dict], overlap: List[dict], context: Optional[dict] = None, env: EvalEnv

tests/conftest.py

+30
Original file line numberDiff line numberDiff line change
@@ -482,3 +482,33 @@ def unload_dummy_packages():
482482
if package in sys.modules:
483483
del sys.modules[package]
484484
importlib.invalidate_caches()
485+
486+
487+
@pytest.fixture
def identity_udf_rename_bands():
    """
    Process graph with a single ``run_udf`` node whose UDF returns the cube unchanged
    and whose ``apply_metadata`` renames the "bands" labels to
    ``computed_band_1`` / ``computed_band_2``.
    """
    udf_code = """
from openeo.metadata import CollectionMetadata
from xarray import DataArray

def apply_metadata(metadata: CollectionMetadata, context: dict) -> CollectionMetadata:
    return metadata.rename_labels(
        dimension="bands",
        target=["computed_band_1", "computed_band_2"]
    )

def apply_datacube(cube: DataArray, context: dict) -> DataArray:
    return cube
"""
    # The UDF receives the process's "data" parameter as its input cube.
    run_udf_node = {
        "process_id": "run_udf",
        "arguments": {
            "data": {"from_parameter": "data"},
            "udf": udf_code,
        },
        "result": True,
    }
    return {"udf_process": run_udf_node}

tests/test_apply_dimension.py

+11
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,17 @@ def test_apply_dimension_bands_udf(imagecollection_with_two_bands_and_three_date
6666
assert_array_almost_equal(input, subresult)
6767

6868

69+
def test_apply_metadata(imagecollection_with_two_bands_and_three_dates, identity_udf_rename_bands):
    """apply_dimension over "bands" must expose the band names set by the UDF's apply_metadata."""
    cube = imagecollection_with_two_bands_and_three_dates
    renamed = cube.apply_dimension(
        process=identity_udf_rename_bands,
        dimension="bands",
        target_dimension=None,
        context={},
        env=EvalEnv(),
    )

    expected_bands = ["computed_band_1", "computed_band_2"]
    # Check the cube metadata first, then the labels of the materialized array.
    assert renamed.metadata.band_names == expected_bands
    as_xarray = renamed._to_xarray()
    assert list(as_xarray.bands.values) == expected_bands
78+
79+
6980
def test_apply_dimension_invalid_dimension(imagecollection_with_two_bands_and_three_dates,udf_noop):
7081
the_date = datetime.datetime(2017, 9, 25, 11, 37)
7182
with pytest.raises(FeatureUnsupportedException):

tests/test_apply_neighborhood.py

+17
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,23 @@ def test_apply_neighborhood_overlap_udf(imagecollection_with_two_bands_and_three
6262
assert_array_almost_equal(input, subresult)
6363

6464

65+
def test_apply_metadata(imagecollection_with_two_bands_and_three_dates, identity_udf_rename_bands):
    """apply_neighborhood must expose the band names set by the UDF's apply_metadata."""
    result = imagecollection_with_two_bands_and_three_dates.apply_neighborhood(
        process=identity_udf_rename_bands,
        size=[{'dimension': 'x', 'unit': 'px', 'value': 32}, {'dimension': 'y', 'unit': 'px', 'value': 32}],
        overlap=[{'dimension': 'x', 'unit': 'px', 'value': 8}, {'dimension': 'y', 'unit': 'px', 'value': 8}],
        context={},
        env=EvalEnv()
    )
    # Both the cube metadata and the materialized array must carry the renamed bands.
    assert result.metadata.band_names == ["computed_band_1", "computed_band_2"]
    result_xarray = result._to_xarray()
    assert list(result_xarray.bands.values) == ["computed_band_1", "computed_band_2"]
79+
80+
81+
6582
def test_apply_neighborhood_on_timeseries(imagecollection_with_two_bands_and_three_dates):
6683
the_date = datetime.datetime(2017, 9, 25, 11, 37)
6784
graph = {

0 commit comments

Comments
 (0)