made load_stac more lenient to mismatches in band names #755

Closed
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -15,6 +15,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Changed

- When the bands provided to `Connection.load_stac(..., bands=[...])` do not fully match the bands the client extracted from the STAC metadata, a warning will be triggered, but the provided band names will still be used during the client-side preparation of the process graph. This is a pragmatic approach to bridge the gap between differing interpretations of band detection in STAC. Note that this might produce process graphs that are technically invalid and might not work on other backends or future versions of the backend you currently use. It is recommended to consult with the provider of the STAC metadata and openEO backend on the correct and future-proof band names. ([#752](https://github.com/Open-EO/openeo-python-client/issues/752))

### Removed

### Fixed
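For illustration, a minimal usage sketch of the behavior described in the CHANGELOG entry above; the backend URL, STAC URL and band names are hypothetical:

```python
import openeo

connection = openeo.connect("openeo.example.test")  # hypothetical backend

# Assume the STAC metadata only lists bands "B01", "B02" and "B03".
# Requesting the unknown band "B04" now logs a warning instead of failing,
# and the provided band names are used as-is in the client-side metadata.
cube = connection.load_stac(
    "https://stac.example.test/collections/sentinel2",  # hypothetical STAC URL
    bands=["B03", "B04"],
)
print(cube.metadata.band_names)  # ['B03', 'B04']
```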
49 changes: 47 additions & 2 deletions openeo/metadata.py
@@ -218,6 +218,16 @@ def rename_labels(self, target, source) -> Dimension:
def rename(self, name) -> Dimension:
return BandDimension(name=name, bands=self.bands)

def contains_band(self, band: Union[int, str]) -> bool:
"""
Check if the given band name or index is present in the dimension.
"""
try:
self.band_index(band)
return True
except ValueError:
return False


class GeometryDimension(Dimension):
# TODO: how to model/store labels of geometry dimension?
@@ -413,6 +423,41 @@ def add_dimension(self, name: str, label: Union[str, float], type: Optional[str]
dim = Dimension(type=type or "other", name=name)
return self._clone_and_update(dimensions=self._dimensions + [dim])

def ensure_band_dimension(
self, *, name: Optional[str] = None, bands: List[Union[Band, str]], warning: str
Review comment (Contributor Author):
I'm not really a fan of having a function in metadata that takes a warning as an argument just to log it. Logging could happen outside of this function imo

Reply (Member):
Yeah, I was hesitant about that, but forcing the caller to think of a warning message is meant to make it clear that this function should only be used exceptionally.

Reply (Contributor Author):
In that case I would rename it to replace_band_dimension or even set_band_dimension, as "ensure" sounds like it just runs some checks.
Still not a big fan of the warning argument, but if we keep it I'd at least make it optional.

I'm quickly working on a commit that also retains the existing bands so we don't lose fields like wavelength_um

Reply (Member):
I'll add a note to the documentation to better explain this design

Reply (Collaborator):
Note that for cases like this, where there is doubt about the API that is being added, simply marking it as 'private'/'internal' API can help to stay future-proof.

) -> CubeMetadata:
"""
Create a new CubeMetadata object, ensuring a band dimension with the given bands.
This will override any existing band dimension, and is intended for
special cases where pragmatism necessitates ignoring the original metadata,
for example to overrule badly or incompletely detected band names from STAC metadata.

.. note::
It is required to specify a warning message, as this method is only intended
as a temporary stop-gap solution for use cases that are possibly not future-proof.
Enforcing a warning should make that clear and prevent users from unknowingly depending on
metadata handling behavior that is not guaranteed to be stable.
"""
_log.warning(warning or "ensure_band_dimension: overriding band dimension metadata with user-defined bands.")
if name is None:
# Preserve original band dimension name if possible
name = self.band_dimension.name if self.has_band_dimension() else "bands"
new_bands = []
if self.has_band_dimension():
for band in bands:
try:
# Preserve original band fields if possible
original_band_idx = self.band_dimension.band_index(band.name if isinstance(band, Band) else band)
new_bands.append(self.band_dimension.bands[original_band_idx])
except ValueError:
new_bands.append(band if isinstance(band, Band) else Band(name=band))
else:
new_bands = [b if isinstance(b, Band) else Band(name=b) for b in bands]
band_dimension = BandDimension(name=name, bands=new_bands)
return self._clone_and_update(
dimensions=[d for d in self._dimensions if not isinstance(d, BandDimension)] + [band_dimension]
)

def drop_dimension(self, name: str = None) -> CubeMetadata:
"""Create new CubeMetadata object without dropped dimension with given name"""
dimension_names = self.dimension_names()
@@ -656,13 +701,13 @@ def is_band_asset(asset: pystac.Asset) -> bool:
raise ValueError(stac_object)

# At least assume there are spatial dimensions
# TODO: are there conditions in which we even should not assume the presence of spatial dimensions?
# TODO #743: are there conditions in which we even should not assume the presence of spatial dimensions?
dimensions = [
SpatialDimension(name="x", extent=[None, None]),
SpatialDimension(name="y", extent=[None, None]),
]

# TODO: conditionally include band dimension when there was actual indication of band metadata?
# TODO #743: conditionally include band dimension when there was actual indication of band metadata?
band_dimension = BandDimension(name="bands", bands=bands)
dimensions.append(band_dimension)

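As a rough sketch of the two metadata helpers added above, `BandDimension.contains_band` and `CubeMetadata.ensure_band_dimension`, with made-up band names; note how bands that are already known keep their original fields, while unknown ones are added as bare `Band` entries:

```python
from openeo.metadata import Band, BandDimension, CubeMetadata

metadata = CubeMetadata(dimensions=[
    BandDimension(name="bands", bands=[Band("B02", common_name="blue"), Band("B03", common_name="green")]),
])

print(metadata.band_dimension.contains_band("B03"))  # True
print(metadata.band_dimension.contains_band("B99"))  # False

# Override the band dimension: "B03" keeps its original common_name,
# while the unknown "B99" is added as a bare Band.
metadata = metadata.ensure_band_dimension(
    bands=["B03", "B99"],
    warning="Overriding band dimension with user-specified bands (illustration only).",
)
print(metadata.band_names)  # ['B03', 'B99']
```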
23 changes: 20 additions & 3 deletions openeo/rest/datacube.py
@@ -441,10 +441,27 @@ def load_stac(
graph = PGNode("load_stac", arguments=arguments)
try:
metadata = metadata_from_stac(url)
# TODO: also apply spatial/temporal filters to metadata?

if isinstance(bands, list):
# TODO: also apply spatial/temporal filters to metadata?
metadata = metadata.filter_bands(band_names=bands)
except Exception:
if metadata.has_band_dimension():
unknown_bands = [b for b in bands if not metadata.band_dimension.contains_band(b)]
if len(unknown_bands) == 0:
# Ideal case: bands requested by user correspond with bands extracted from metadata.
metadata = metadata.filter_bands(band_names=bands)
else:
metadata = metadata.ensure_band_dimension(
bands=bands,
warning=f"The specified bands {bands} in `load_stac` are not a subset of the bands {metadata.band_dimension.band_names} found in the STAC metadata (unknown bands: {unknown_bands}). Working with specified bands as is.",
)
else:
metadata = metadata.ensure_band_dimension(
name="bands",
bands=bands,
warning=f"Bands {bands} were specified in `load_stac`, but no band dimension was detected in the STAC metadata. Working with band dimension and specified bands.",
)

except Exception as e:
log.warning(f"Failed to extract cube metadata from STAC URL {url}", exc_info=True)
metadata = None
return cls(graph=graph, connection=connection, metadata=metadata)
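A quick sketch of how the new `load_stac` handling above surfaces to users: the warning is emitted through standard logging, and the user-specified bands are passed through to the process graph unchanged (URLs and band names are again hypothetical):

```python
import logging
import openeo

logging.basicConfig(level=logging.WARNING)  # make the mismatch warning visible

connection = openeo.connect("openeo.example.test")  # hypothetical backend
cube = connection.load_stac(
    "https://stac.example.test/collections/sentinel2",  # hypothetical STAC URL
    bands=["B04", "B05"],  # assume these are not listed in the STAC metadata
)

# The specified bands end up as-is in the "load_stac" node of the process graph:
print(cube.flat_graph()["loadstac1"]["arguments"]["bands"])  # ['B04', 'B05']
```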
242 changes: 224 additions & 18 deletions tests/rest/test_connection.py
@@ -19,7 +19,13 @@
from openeo import BatchJob
from openeo.api.process import Parameter
from openeo.internal.graph_building import FlatGraphableMixin, PGNode
from openeo.metadata import _PYSTAC_1_9_EXTENSION_INTERFACE, TemporalDimension
from openeo.metadata import (
_PYSTAC_1_9_EXTENSION_INTERFACE,
Band,
BandDimension,
CubeMetadata,
TemporalDimension,
)
from openeo.rest import (
CapabilitiesException,
OpenEoApiError,
@@ -2737,6 +2743,21 @@ def test_load_result_filters(requests_mock):


class TestLoadStac:
@pytest.fixture
def build_stac_ref(self, tmp_path) -> typing.Callable[[dict], str]:
"""
Helper to dump (STAC) data to a temp file and return the path.
"""
# TODO #738 instead of working with local files: real request mocking of STAC resources compatible with pystac?

def dump(data) -> str:
stac_path = tmp_path / "stac.json"
with stac_path.open("w", encoding="utf8") as f:
json.dump(data, fp=f)
return str(stac_path)

return dump

def test_basic(self, con120):
cube = con120.load_stac("https://provider.test/dataset")
assert cube.flat_graph() == {
@@ -2890,20 +2911,19 @@ def test_load_stac_from_job_empty_result(self, con120, requests_mock):
}

@pytest.mark.parametrize("temporal_dim", ["t", "datezz"])
def test_load_stac_reduce_temporal(self, con120, tmp_path, temporal_dim):
stac_path = tmp_path / "stac.json"
stac_data = StacDummyBuilder.collection(
cube_dimensions={temporal_dim: {"type": "temporal", "extent": ["2024-01-01", "2024-04-04"]}}
def test_load_stac_reduce_temporal(self, con120, build_stac_ref, temporal_dim):
stac_ref = build_stac_ref(
StacDummyBuilder.collection(
cube_dimensions={temporal_dim: {"type": "temporal", "extent": ["2024-01-01", "2024-04-04"]}}
)
)
# TODO #738 real request mocking of STAC resources compatible with pystac?
stac_path.write_text(json.dumps(stac_data))

cube = con120.load_stac(str(stac_path))
cube = con120.load_stac(stac_ref)
reduced = cube.reduce_temporal("max")
assert reduced.flat_graph() == {
"loadstac1": {
"process_id": "load_stac",
"arguments": {"url": str(stac_path)},
"arguments": {"url": stac_ref},
},
"reducedimension1": {
"process_id": "reduce_dimension",
@@ -2957,19 +2977,205 @@ def test_load_stac_no_cube_extension_temporal_dimension(self, con120, tmp_path,
cube = con120.load_stac(str(stac_path))
assert cube.metadata.temporal_dimension == TemporalDimension(name="t", extent=dim_extent)

def test_load_stac_band_filtering(self, con120, tmp_path):
stac_path = tmp_path / "stac.json"
stac_data = StacDummyBuilder.collection(
summaries={"eo:bands": [{"name": "B01"}, {"name": "B02"}, {"name": "B03"}]}
def test_load_stac_default_band_handling(self, dummy_backend, build_stac_ref):
stac_ref = build_stac_ref(
StacDummyBuilder.collection(
# TODO #586 also cover STAC 1.1 style "bands"
summaries={
"eo:bands": [
{"name": "B01", "common_name": "coastal"},
{"name": "B02", "common_name": "blue"},
{"name": "B03", "common_name": "green"},
]
}
)
)
# TODO #738 real request mocking of STAC resources compatible with pystac?
stac_path.write_text(json.dumps(stac_data))

cube = con120.load_stac(str(stac_path))
cube = dummy_backend.connection.load_stac(stac_ref)
assert cube.metadata.band_names == ["B01", "B02", "B03"]
assert cube.metadata.bands == [
Band(name="B01", common_name="coastal"),
Band(name="B02", common_name="blue"),
Band(name="B03", common_name="green"),
]

cube.execute()
assert dummy_backend.get_pg("load_stac")["arguments"] == {
"url": stac_ref,
}

@pytest.mark.parametrize(
"bands, expected_warning, expected_result_bands",
[
(
["B04"],
"The specified bands ['B04'] in `load_stac` are not a subset of the bands ['B01', 'B02', 'B03'] found in the STAC metadata (unknown bands: ['B04']). Working with specified bands as is.",
[Band(name="B04")],
),
(
["B03", "B04", "B05"],
"The specified bands ['B03', 'B04', 'B05'] in `load_stac` are not a subset of the bands ['B01', 'B02', 'B03'] found in the STAC metadata (unknown bands: ['B04', 'B05']). Working with specified bands as is.",
[Band(name="B03", common_name="green"), Band(name="B04"), Band(name="B05")],
),
(["B03", "B02"], None, [Band(name="B03", common_name="green"), Band(name="B02", common_name="blue")]),
(
["B01", "B02", "B03"],
None,
[
Band(name="B01", common_name="coastal"),
Band(name="B02", common_name="blue"),
Band(name="B03", common_name="green"),
],
),
],
)
def test_load_stac_band_filtering(
self, dummy_backend, build_stac_ref, caplog, bands, expected_warning, expected_result_bands
):
stac_ref = build_stac_ref(
StacDummyBuilder.collection(
# TODO #586 also cover STAC 1.1 style "bands"
summaries={
"eo:bands": [
{"name": "B01", "common_name": "coastal"},
{"name": "B02", "common_name": "blue"},
{"name": "B03", "common_name": "green"},
]
}
)
)

caplog.set_level(logging.WARNING)
# Test with non-existing bands in the collection metadata
cube = dummy_backend.connection.load_stac(stac_ref, bands=bands)
assert cube.metadata.band_names == bands
if expected_warning is None:
assert caplog.text == ""
else:
assert expected_warning in caplog.text

assert cube.metadata.bands == expected_result_bands

cube.execute()
assert dummy_backend.get_pg("load_stac")["arguments"] == {
"url": stac_ref,
"bands": bands,
}

def test_load_stac_band_filtering_no_band_metadata_default(self, dummy_backend, build_stac_ref, caplog):
stac_ref = build_stac_ref(StacDummyBuilder.collection())

cube = dummy_backend.connection.load_stac(stac_ref)
# TODO #743: what should the default list of bands be?
assert cube.metadata.band_names == []

cube.execute()
assert dummy_backend.get_pg("load_stac")["arguments"] == {
"url": stac_ref,
}

@pytest.mark.parametrize(
["bands", "has_band_dimension", "expected_pg_args", "expected_warning"],
[
(None, False, {}, None),
(
["B02", "B03"],
True,
{"bands": ["B02", "B03"]},
"Bands ['B02', 'B03'] were specified in `load_stac`, but no band dimension was detected in the STAC metadata. Working with band dimension and specified bands.",
),
],
)
def test_load_stac_band_filtering_no_band_dimension(
self, dummy_backend, build_stac_ref, bands, has_band_dimension, expected_pg_args, expected_warning, caplog
):
stac_ref = build_stac_ref(StacDummyBuilder.collection())

# This is a temporary mock.patch hack to make metadata_from_stac return metadata without a band dimension
# TODO #743: Do this properly through appropriate STAC metadata
from openeo.metadata import metadata_from_stac as orig_metadata_from_stac

def metadata_from_stac(url: str):
metadata = orig_metadata_from_stac(url=url)
assert metadata.has_band_dimension()
metadata = metadata.drop_dimension("bands")
assert not metadata.has_band_dimension()
return metadata

with mock.patch("openeo.rest.datacube.metadata_from_stac", new=metadata_from_stac):
cube = dummy_backend.connection.load_stac(stac_ref, bands=bands)

assert cube.metadata.has_band_dimension() == has_band_dimension

cube.execute()
assert dummy_backend.get_pg("load_stac")["arguments"] == {
**expected_pg_args,
"url": stac_ref,
}

if expected_warning:
assert expected_warning in caplog.text
else:
assert not caplog.text

def test_load_stac_band_filtering_no_band_metadata(self, dummy_backend, build_stac_ref, caplog):
caplog.set_level(logging.WARNING)
stac_ref = build_stac_ref(StacDummyBuilder.collection())

cube = dummy_backend.connection.load_stac(stac_ref, bands=["B01", "B02"])
assert cube.metadata.band_names == ["B01", "B02"]
assert (
# TODO: better warning than confusing "not a subset of the bands []" ?
"The specified bands ['B01', 'B02'] in `load_stac` are not a subset of the bands [] found in the STAC metadata (unknown bands: ['B01', 'B02']). Working with specified bands as is."
in caplog.text
)

cube.execute()
assert dummy_backend.get_pg("load_stac")["arguments"] == {
"url": stac_ref,
"bands": ["B01", "B02"],
}

@pytest.mark.parametrize(
["bands", "expected_pg_args", "expected_warning"],
[
(None, {}, None),
(
["B02", "B03"],
{"bands": ["B02", "B03"]},
"The specified bands ['B02', 'B03'] in `load_stac` are not a subset of the bands ['Bz1', 'Bz2'] found in the STAC metadata (unknown bands: ['B02', 'B03']). Working with specified bands as is",
),
],
)
def test_load_stac_band_filtering_custom_band_dimension(
self, dummy_backend, build_stac_ref, bands, expected_pg_args, expected_warning, caplog
):
stac_ref = build_stac_ref(StacDummyBuilder.collection())

# This is a temporary mock.patch hack to make metadata_from_stac return metadata with a custom band dimension
# TODO #743: Do this properly through appropriate STAC metadata
from openeo.metadata import metadata_from_stac as orig_metadata_from_stac

def metadata_from_stac(url: str):
return CubeMetadata(dimensions=[BandDimension(name="bandzz", bands=[Band("Bz1"), Band("Bz2")])])

with mock.patch("openeo.rest.datacube.metadata_from_stac", new=metadata_from_stac):
cube = dummy_backend.connection.load_stac(stac_ref, bands=bands)

assert cube.metadata.has_band_dimension()
assert cube.metadata.band_dimension.name == "bandzz"

cube.execute()
assert dummy_backend.get_pg("load_stac")["arguments"] == {
**expected_pg_args,
"url": stac_ref,
}

if expected_warning:
assert expected_warning in caplog.text
else:
assert not caplog.text

cube = con120.load_stac(str(stac_path), bands=["B03", "B02"])
assert cube.metadata.band_names == ["B03", "B02"]

@pytest.mark.parametrize(
"bands",