diff --git a/CHANGELOG.md b/CHANGELOG.md index f99fce721..81cce00b7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -15,6 +15,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Changed +- When the bands provided to `Connection.load_stac(..., bands=[...])` do not fully match the bands the client extracted from the STAC metadata, a warning will be triggered, but the provided band names will still be used during the client-side preparation of the process graph. This is a pragmatic approach to bridge the gap between differing interpretations of band detection in STAC. Note that this might produce process graphs that are technically invalid and might not work on other backends or future versions of the backend you currently use. It is recommended to consult with the provider of the STAC metadata and openEO backend on the correct and future-proof band names. ([#752](https://github.com/Open-EO/openeo-python-client/issues/752)) + ### Removed ### Fixed diff --git a/openeo/metadata.py b/openeo/metadata.py index ee0e47ea0..24d076412 100644 --- a/openeo/metadata.py +++ b/openeo/metadata.py @@ -218,6 +218,16 @@ def rename_labels(self, target, source) -> Dimension: def rename(self, name) -> Dimension: return BandDimension(name=name, bands=self.bands) + def contains_band(self, band: Union[int, str]) -> bool: + """ + Check if the given band name or index is present in the dimension. + """ + try: + self.band_index(band) + return True + except ValueError: + return False + class GeometryDimension(Dimension): # TODO: how to model/store labels of geometry dimension? @@ -411,6 +421,31 @@ def add_dimension(self, name: str, label: Union[str, float], type: Optional[str] dim = Dimension(type=type or "other", name=name) return self._clone_and_update(dimensions=self._dimensions + [dim]) + def _ensure_band_dimension( + self, *, name: Optional[str] = None, bands: List[Union[Band, str]], warning: str + ) -> CubeMetadata: + """ + Create new CubeMetadata object, ensuring a band dimension with given bands. + This will override any existing band dimension, and is intended for + special cases where pragmatism necessitates to ignore the original metadata. + For example, to overrule badly/incomplete detected band names from STAC metadata. + + .. note:: + It is required to specify a warning message as this method is only intended + to be used as temporary stop-gap solution for use cases that are possibly not future-proof. + Enforcing a warning should make that clear and avoid that users unknowingly depend on + metadata handling behavior that is not guaranteed to be stable. + """ + _log.warning(warning or "ensure_band_dimension: overriding band dimension metadata with user-defined bands.") + if name is None: + # Preserve original band dimension name if possible + name = self.band_dimension.name if self.has_band_dimension() else "bands" + bands = [b if isinstance(b, Band) else Band(name=b) for b in bands] + band_dimension = BandDimension(name=name, bands=bands) + return self._clone_and_update( + dimensions=[d for d in self._dimensions if not isinstance(d, BandDimension)] + [band_dimension] + ) + def drop_dimension(self, name: str = None) -> CubeMetadata: """Create new CubeMetadata object without dropped dimension with given name""" dimension_names = self.dimension_names() @@ -654,13 +689,13 @@ def is_band_asset(asset: pystac.Asset) -> bool: raise ValueError(stac_object) # At least assume there are spatial dimensions - # TODO: are there conditions in which we even should not assume the presence of spatial dimensions? + # TODO #743: are there conditions in which we even should not assume the presence of spatial dimensions? dimensions = [ SpatialDimension(name="x", extent=[None, None]), SpatialDimension(name="y", extent=[None, None]), ] - # TODO: conditionally include band dimension when there was actual indication of band metadata? + # TODO #743: conditionally include band dimension when there was actual indication of band metadata? band_dimension = BandDimension(name="bands", bands=bands) dimensions.append(band_dimension) diff --git a/openeo/rest/datacube.py b/openeo/rest/datacube.py index 8788e4f51..04c936564 100644 --- a/openeo/rest/datacube.py +++ b/openeo/rest/datacube.py @@ -441,10 +441,27 @@ def load_stac( graph = PGNode("load_stac", arguments=arguments) try: metadata = metadata_from_stac(url) + # TODO: also apply spatial/temporal filters to metadata? + if isinstance(bands, list): - # TODO: also apply spatial/temporal filters to metadata? - metadata = metadata.filter_bands(band_names=bands) - except Exception: + if metadata.has_band_dimension(): + unknown_bands = [b for b in bands if not metadata.band_dimension.contains_band(b)] + if len(unknown_bands) == 0: + # Ideal case: bands requested by user correspond with bands extracted from metadata. + metadata = metadata.filter_bands(band_names=bands) + else: + metadata = metadata._ensure_band_dimension( + bands=bands, + warning=f"The specified bands {bands} in `load_stac` are not a subset of the bands {metadata.band_dimension.band_names} found in the STAC metadata (unknown bands: {unknown_bands}). Working with specified bands as is.", + ) + else: + metadata = metadata._ensure_band_dimension( + name="bands", + bands=bands, + warning=f"Bands {bands} were specified in `load_stac`, but no band dimension was detected in the STAC metadata. Working with band dimension and specified bands.", + ) + + except Exception as e: log.warning(f"Failed to extract cube metadata from STAC URL {url}", exc_info=True) metadata = None return cls(graph=graph, connection=connection, metadata=metadata) diff --git a/tests/rest/test_connection.py b/tests/rest/test_connection.py index a1ea8470e..cf7d5e65a 100644 --- a/tests/rest/test_connection.py +++ b/tests/rest/test_connection.py @@ -19,7 +19,13 @@ from openeo import BatchJob from openeo.api.process import Parameter from openeo.internal.graph_building import FlatGraphableMixin, PGNode -from openeo.metadata import _PYSTAC_1_9_EXTENSION_INTERFACE, TemporalDimension +from openeo.metadata import ( + _PYSTAC_1_9_EXTENSION_INTERFACE, + Band, + BandDimension, + CubeMetadata, + TemporalDimension, +) from openeo.rest import ( CapabilitiesException, OpenEoApiError, @@ -2737,6 +2743,21 @@ def test_load_result_filters(requests_mock): class TestLoadStac: + @pytest.fixture + def build_stac_ref(self, tmp_path) -> typing.Callable[[dict], str]: + """ + Helper to dump (STAC) data to a temp file and return the path. + """ + # TODO #738 instead of working with local files: real request mocking of STAC resources compatible with pystac? + + def dump(data) -> str: + stac_path = tmp_path / "stac.json" + with stac_path.open("w", encoding="utf8") as f: + json.dump(data, fp=f) + return str(stac_path) + + return dump + def test_basic(self, con120): cube = con120.load_stac("https://provider.test/dataset") assert cube.flat_graph() == { @@ -2890,20 +2911,19 @@ def test_load_stac_from_job_empty_result(self, con120, requests_mock): } @pytest.mark.parametrize("temporal_dim", ["t", "datezz"]) - def test_load_stac_reduce_temporal(self, con120, tmp_path, temporal_dim): - stac_path = tmp_path / "stac.json" - stac_data = StacDummyBuilder.collection( - cube_dimensions={temporal_dim: {"type": "temporal", "extent": ["2024-01-01", "2024-04-04"]}} + def test_load_stac_reduce_temporal(self, con120, build_stac_ref, temporal_dim): + stac_ref = build_stac_ref( + StacDummyBuilder.collection( + cube_dimensions={temporal_dim: {"type": "temporal", "extent": ["2024-01-01", "2024-04-04"]}} + ) ) - # TODO #738 real request mocking of STAC resources compatible with pystac? - stac_path.write_text(json.dumps(stac_data)) - cube = con120.load_stac(str(stac_path)) + cube = con120.load_stac(stac_ref) reduced = cube.reduce_temporal("max") assert reduced.flat_graph() == { "loadstac1": { "process_id": "load_stac", - "arguments": {"url": str(stac_path)}, + "arguments": {"url": stac_ref}, }, "reducedimension1": { "process_id": "reduce_dimension", @@ -2957,19 +2977,174 @@ def test_load_stac_no_cube_extension_temporal_dimension(self, con120, tmp_path, cube = con120.load_stac(str(stac_path)) assert cube.metadata.temporal_dimension == TemporalDimension(name="t", extent=dim_extent) - def test_load_stac_band_filtering(self, con120, tmp_path): - stac_path = tmp_path / "stac.json" - stac_data = StacDummyBuilder.collection( - summaries={"eo:bands": [{"name": "B01"}, {"name": "B02"}, {"name": "B03"}]} + def test_load_stac_default_band_handling(self, dummy_backend, build_stac_ref): + stac_ref = build_stac_ref( + StacDummyBuilder.collection( + # TODO #586 also cover STAC 1.1 style "bands" + summaries={"eo:bands": [{"name": "B01"}, {"name": "B02"}, {"name": "B03"}]} + ) ) - # TODO #738 real request mocking of STAC resources compatible with pystac? - stac_path.write_text(json.dumps(stac_data)) - cube = con120.load_stac(str(stac_path)) + cube = dummy_backend.connection.load_stac(stac_ref) assert cube.metadata.band_names == ["B01", "B02", "B03"] - cube = con120.load_stac(str(stac_path), bands=["B03", "B02"]) - assert cube.metadata.band_names == ["B03", "B02"] + cube.execute() + assert dummy_backend.get_pg("load_stac")["arguments"] == { + "url": stac_ref, + } + + @pytest.mark.parametrize( + "bands, expected_warning", + [ + ( + ["B04"], + "The specified bands ['B04'] in `load_stac` are not a subset of the bands ['B01', 'B02', 'B03'] found in the STAC metadata (unknown bands: ['B04']). Working with specified bands as is.", + ), + ( + ["B03", "B04", "B05"], + "The specified bands ['B03', 'B04', 'B05'] in `load_stac` are not a subset of the bands ['B01', 'B02', 'B03'] found in the STAC metadata (unknown bands: ['B04', 'B05']). Working with specified bands as is.", + ), + (["B03", "B02"], None), + (["B01", "B02", "B03"], None), + ], + ) + def test_load_stac_band_filtering(self, dummy_backend, build_stac_ref, caplog, bands, expected_warning): + stac_ref = build_stac_ref( + StacDummyBuilder.collection( + # TODO #586 also cover STAC 1.1 style "bands" + summaries={"eo:bands": [{"name": "B01"}, {"name": "B02"}, {"name": "B03"}]} + ) + ) + + caplog.set_level(logging.WARNING) + # Test with non-existing bands in the collection metadata + cube = dummy_backend.connection.load_stac(stac_ref, bands=bands) + assert cube.metadata.band_names == bands + if expected_warning is None: + assert caplog.text == "" + else: + assert expected_warning in caplog.text + + cube.execute() + assert dummy_backend.get_pg("load_stac")["arguments"] == { + "url": stac_ref, + "bands": bands, + } + + def test_load_stac_band_filtering_no_band_metadata_default(self, dummy_backend, build_stac_ref, caplog): + stac_ref = build_stac_ref(StacDummyBuilder.collection()) + + cube = dummy_backend.connection.load_stac(stac_ref) + # TODO #743: what should the default list of bands be? + assert cube.metadata.band_names == [] + + cube.execute() + assert dummy_backend.get_pg("load_stac")["arguments"] == { + "url": stac_ref, + } + + @pytest.mark.parametrize( + ["bands", "has_band_dimension", "expected_pg_args", "expected_warning"], + [ + (None, False, {}, None), + ( + ["B02", "B03"], + True, + {"bands": ["B02", "B03"]}, + "Bands ['B02', 'B03'] were specified in `load_stac`, but no band dimension was detected in the STAC metadata. Working with band dimension and specified bands.", + ), + ], + ) + def test_load_stac_band_filtering_no_band_dimension( + self, dummy_backend, build_stac_ref, bands, has_band_dimension, expected_pg_args, expected_warning, caplog + ): + stac_ref = build_stac_ref(StacDummyBuilder.collection()) + + # This is a temporary mock.patch hack to make metadata_from_stac return metadata without a band dimension + # TODO #743: Do this properly through appropriate STAC metadata + from openeo.metadata import metadata_from_stac as orig_metadata_from_stac + + def metadata_from_stac(url: str): + metadata = orig_metadata_from_stac(url=url) + assert metadata.has_band_dimension() + metadata = metadata.drop_dimension("bands") + assert not metadata.has_band_dimension() + return metadata + + with mock.patch("openeo.rest.datacube.metadata_from_stac", new=metadata_from_stac): + cube = dummy_backend.connection.load_stac(stac_ref, bands=bands) + + assert cube.metadata.has_band_dimension() == has_band_dimension + + cube.execute() + assert dummy_backend.get_pg("load_stac")["arguments"] == { + **expected_pg_args, + "url": stac_ref, + } + + if expected_warning: + assert expected_warning in caplog.text + else: + assert not caplog.text + + def test_load_stac_band_filtering_no_band_metadata(self, dummy_backend, build_stac_ref, caplog): + caplog.set_level(logging.WARNING) + stac_ref = build_stac_ref(StacDummyBuilder.collection()) + + cube = dummy_backend.connection.load_stac(stac_ref, bands=["B01", "B02"]) + assert cube.metadata.band_names == ["B01", "B02"] + assert ( + # TODO: better warning than confusing "not a subset of the bands []" ? + "The specified bands ['B01', 'B02'] in `load_stac` are not a subset of the bands [] found in the STAC metadata (unknown bands: ['B01', 'B02']). Working with specified bands as is." + in caplog.text + ) + + cube.execute() + assert dummy_backend.get_pg("load_stac")["arguments"] == { + "url": stac_ref, + "bands": ["B01", "B02"], + } + + @pytest.mark.parametrize( + ["bands", "expected_pg_args", "expected_warning"], + [ + (None, {}, None), + ( + ["B02", "B03"], + {"bands": ["B02", "B03"]}, + "The specified bands ['B02', 'B03'] in `load_stac` are not a subset of the bands ['Bz1', 'Bz2'] found in the STAC metadata (unknown bands: ['B02', 'B03']). Working with specified bands as is", + ), + ], + ) + def test_load_stac_band_filtering_custom_band_dimension( + self, dummy_backend, build_stac_ref, bands, expected_pg_args, expected_warning, caplog + ): + stac_ref = build_stac_ref(StacDummyBuilder.collection()) + + # This is a temporary mock.patch hack to make metadata_from_stac return metadata with a custom band dimension + # TODO #743: Do this properly through appropriate STAC metadata + from openeo.metadata import metadata_from_stac as orig_metadata_from_stac + + def metadata_from_stac(url: str): + return CubeMetadata(dimensions=[BandDimension(name="bandzz", bands=[Band("Bz1"), Band("Bz2")])]) + + with mock.patch("openeo.rest.datacube.metadata_from_stac", new=metadata_from_stac): + cube = dummy_backend.connection.load_stac(stac_ref, bands=bands) + + assert cube.metadata.has_band_dimension() + assert cube.metadata.band_dimension.name == "bandzz" + + cube.execute() + assert dummy_backend.get_pg("load_stac")["arguments"] == { + **expected_pg_args, + "url": stac_ref, + } + + if expected_warning: + assert expected_warning in caplog.text + else: + assert not caplog.text + @pytest.mark.parametrize( "bands", diff --git a/tests/test_metadata.py b/tests/test_metadata.py index 229c659a2..aae162d83 100644 --- a/tests/test_metadata.py +++ b/tests/test_metadata.py @@ -100,6 +100,28 @@ def test_band_dimension_band_index(): bdim.band_index("yellow") +def test_band_dimension_contains_band(): + bdim = BandDimension( + name="spectral", + bands=[ + Band("B02", "blue", 0.490), + Band("B03", "green", 0.560), + Band("B04", "red", 0.665), + ], + ) + + # Test band names + assert bdim.contains_band("B02") + assert not bdim.contains_band("B05") + + # Test indexes + assert bdim.contains_band(0) + assert not bdim.contains_band(4) + + # Test common names + assert bdim.contains_band("blue") + assert not bdim.contains_band("yellow") + def test_band_dimension_band_name(): bdim = BandDimension( name="spectral", @@ -806,6 +828,60 @@ def test_cubemetadata_drop_dimension(): metadata.drop_dimension("x") +def test_cubemetadata_ensure_band_dimension_add_bands(): + metadata = CubeMetadata( + dimensions=[ + TemporalDimension(name="t", extent=None), + ] + ) + new = metadata._ensure_band_dimension(bands=["red", "green"], warning="ensure_band_dimension at work") + assert new.has_band_dimension() + assert new.dimension_names() == ["t", "bands"] + assert new.band_names == ["red", "green"] + + +def test_cubemetadata_ensure_band_dimension_add_name_and_bands(): + metadata = CubeMetadata( + dimensions=[ + TemporalDimension(name="t", extent=None), + ] + ) + new = metadata._ensure_band_dimension( + name="bandzz", bands=["red", "green"], warning="ensure_band_dimension at work" + ) + assert new.has_band_dimension() + assert new.dimension_names() == ["t", "bandzz"] + assert new.band_names == ["red", "green"] + + +def test_cubemetadata_ensure_band_dimension_override_bands(): + metadata = CubeMetadata( + dimensions=[ + TemporalDimension(name="t", extent=None), + BandDimension(name="bands", bands=[Band("red"), Band("green")]), + ] + ) + new = metadata._ensure_band_dimension(bands=["tomato", "lettuce"], warning="ensure_band_dimension at work") + assert new.has_band_dimension() + assert new.dimension_names() == ["t", "bands"] + assert new.band_names == ["tomato", "lettuce"] + + +def test_cubemetadata_ensure_band_dimension_override_name_and_bands(): + metadata = CubeMetadata( + dimensions=[ + TemporalDimension(name="t", extent=None), + BandDimension(name="bands", bands=[Band("red"), Band("green")]), + ] + ) + new = metadata._ensure_band_dimension( + name="bandzz", bands=["tomato", "lettuce"], warning="ensure_band_dimension at work" + ) + assert new.has_band_dimension() + assert new.dimension_names() == ["t", "bandzz"] + assert new.band_names == ["tomato", "lettuce"] + + def test_collectionmetadata_subclass(): class MyCollectionMetadata(CollectionMetadata): def __init__(self, metadata: dict, dimensions: List[Dimension] = None, bbox=None):