Made load_stac nicer about mismatches in band names #755

Closed

23 changes: 21 additions & 2 deletions openeo/metadata.py
@@ -423,6 +423,25 @@ def add_dimension(self, name: str, label: Union[str, float], type: Optional[str]
dim = Dimension(type=type or "other", name=name)
return self._clone_and_update(dimensions=self._dimensions + [dim])

def ensure_band_dimension(
self, *, name: Optional[str] = None, bands: List[Union[Band, str]], warning: str

Contributor Author:
I'm not really a fan of having a function in metadata that takes a warning as an argument just to log it. Logging could happen outside of this function imo
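
For illustration, a minimal sketch of that alternative split (hypothetical, not the merged design): the helper keeps the dimension-rebuilding logic from this diff but drops the warning argument, and the caller owns the log message.

def ensure_band_dimension(self, *, name: Optional[str] = None, bands: List[Union[Band, str]]) -> CubeMetadata:
    # Same body as in this diff, minus the _log.warning call.
    if name is None:
        name = self.band_dimension.name if self.has_band_dimension() else "bands"
    bands = [b if isinstance(b, Band) else Band(name=b) for b in bands]
    band_dimension = BandDimension(name=name, bands=bands)
    return self._clone_and_update(
        dimensions=[d for d in self._dimensions if not isinstance(d, BandDimension)] + [band_dimension]
    )

# Caller side (e.g. in load_stac), where the context for the message is known:
#     _log.warning(f"Specified bands {bands} do not match the STAC metadata; working with them as is.")
#     metadata = metadata.ensure_band_dimension(bands=bands)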

Member:
Yeah, I was hesitant about that, but forcing the caller to think of a warning message is meant to make clear that this function should only be used exceptionally.

Contributor Author:
In that case I would rename it to replace_band_dimension or even set_band_dimension, as "ensure" sounds like it just runs some checks.
Still not a big fan of the warning argument, but if we keep it I'd at least make it optional (see the sketch below).

I'm quickly working on a commit that also retains the existing bands so we don't lose fields like wavelength_um
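
A rough sketch of where those two suggestions could lead (hypothetical, not the merged code): the rename, an optional warning, and reuse of existing Band objects so extra fields like wavelength_um survive.

def replace_band_dimension(
    self, *, name: Optional[str] = None, bands: List[Union[Band, str]], warning: Optional[str] = None
) -> CubeMetadata:
    if warning:
        _log.warning(warning)
    if name is None:
        name = self.band_dimension.name if self.has_band_dimension() else "bands"
    # Reuse existing Band objects (keeping e.g. wavelength_um) when the band name matches.
    existing = {b.name: b for b in self.band_dimension.bands} if self.has_band_dimension() else {}
    bands = [b if isinstance(b, Band) else existing.get(b, Band(name=b)) for b in bands]
    band_dimension = BandDimension(name=name, bands=bands)
    return self._clone_and_update(
        dimensions=[d for d in self._dimensions if not isinstance(d, BandDimension)] + [band_dimension]
    )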

Member:
I'll add a note to the documentation to better explain this design

Collaborator:
Note that for cases like this, where there is doubt about the API being added, simply making it 'private'/'internal' API can help with staying future-proof.
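
For example (hypothetical naming, using the usual leading-underscore convention for internal API):

# Internal helper: the name signals it may change or disappear without deprecation.
def _ensure_band_dimension(self, *, name: Optional[str] = None, bands: List[Union[Band, str]], warning: str) -> CubeMetadata:
    ...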

) -> CubeMetadata:
"""
Create new CubeMetadata object, ensuring a band dimension with given bands.
This will override any existing band dimension, and is intended for
special cases where pragmatism necessitates ignoring the original metadata.
For example, to overrule badly or incompletely detected band names from STAC metadata.
"""
_log.warning(warning)
if name is None:
# Preserve original band dimension name if possible
name = self.band_dimension.name if self.has_band_dimension() else "bands"
bands = [b if isinstance(b, Band) else Band(name=b) for b in bands]
band_dimension = BandDimension(name=name, bands=bands)
return self._clone_and_update(
dimensions=[d for d in self._dimensions if not isinstance(d, BandDimension)] + [band_dimension]
)

def drop_dimension(self, name: str = None) -> CubeMetadata:
"""Create new CubeMetadata object without dropped dimension with given name"""
dimension_names = self.dimension_names()
@@ -666,13 +685,13 @@ def is_band_asset(asset: pystac.Asset) -> bool:
raise ValueError(stac_object)

# At least assume there are spatial dimensions
# TODO: are there conditions in which we even should not assume the presence of spatial dimensions?
# TODO #743: are there conditions in which we even should not assume the presence of spatial dimensions?
dimensions = [
SpatialDimension(name="x", extent=[None, None]),
SpatialDimension(name="y", extent=[None, None]),
]

# TODO: conditionally include band dimension when there was actual indication of band metadata?
# TODO #743: conditionally include band dimension when there was actual indication of band metadata?
band_dimension = BandDimension(name="bands", bands=bands)
dimensions.append(band_dimension)

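A quick usage sketch of the new helper, mirroring the unit tests further down (the warning text here is just an example message):

from openeo.metadata import CubeMetadata, TemporalDimension

metadata = CubeMetadata(dimensions=[TemporalDimension(name="t", extent=None)])
# Force a band dimension onto metadata that has none (e.g. sparse STAC metadata):
patched = metadata.ensure_band_dimension(
    bands=["B01", "B02"],
    warning="No band dimension detected in STAC metadata; working with specified bands.",
)
assert patched.dimension_names() == ["t", "bands"]
assert patched.band_names == ["B01", "B02"]
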
21 changes: 15 additions & 6 deletions openeo/rest/datacube.py
@@ -444,14 +444,23 @@ def load_stac(
# TODO: also apply spatial/temporal filters to metadata?

if isinstance(bands, list):
unknown_bands = [b for b in bands if not metadata.band_dimension.contains_band(b)]
if len(unknown_bands) == 0:
metadata = metadata.filter_bands(band_names=bands)
if metadata.has_band_dimension():
unknown_bands = [b for b in bands if not metadata.band_dimension.contains_band(b)]
if len(unknown_bands) == 0:
# Ideal case: bands requested by user correspond with bands extracted from metadata.
metadata = metadata.filter_bands(band_names=bands)
else:
metadata = metadata.ensure_band_dimension(
bands=bands,
warning=f"The specified bands {bands} in `load_stac` are not a subset of the bands {metadata.band_dimension.band_names} found in the STAC metadata (unknown bands: {unknown_bands}). Working with specified bands as is.",
)
else:
logging.warning(
f"The specified bands {bands} are not a subset of the bands {metadata.band_dimension.band_names} found in the STAC metadata (unknown bands: {unknown_bands}). Using specified bands as is."
metadata = metadata.ensure_band_dimension(
name="bands",
bands=bands,
warning=f"Bands {bands} were specified in `load_stac`, but no band dimension was detected in the STAC metadata. Working with band dimension and specified bands.",
)
metadata = metadata.rename_labels(dimension="bands", target=bands)

except Exception as e:
log.warning(f"Failed to extract cube metadata from STAC URL {url}", exc_info=True)
metadata = None
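
For end users, the net effect looks roughly like this (a sketch with placeholder backend and STAC URLs; the warning text follows the diff above):

import openeo

connection = openeo.connect("https://openeo.example")
# Requesting a band that the STAC metadata does not list no longer breaks client-side
# metadata handling: a warning is logged and the specified bands are used as-is.
cube = connection.load_stac("https://stac.example/collection.json", bands=["B04"])
assert cube.metadata.band_names == ["B04"]
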
200 changes: 164 additions & 36 deletions tests/rest/test_connection.py
@@ -19,7 +19,13 @@
from openeo import BatchJob
from openeo.api.process import Parameter
from openeo.internal.graph_building import FlatGraphableMixin, PGNode
from openeo.metadata import _PYSTAC_1_9_EXTENSION_INTERFACE, TemporalDimension
from openeo.metadata import (
_PYSTAC_1_9_EXTENSION_INTERFACE,
Band,
BandDimension,
CubeMetadata,
TemporalDimension,
)
from openeo.rest import (
CapabilitiesException,
OpenEoApiError,
@@ -2737,6 +2743,21 @@ def test_load_result_filters(requests_mock):


class TestLoadStac:
@pytest.fixture
def build_stac_ref(self, tmp_path) -> typing.Callable[[dict], str]:
"""
Helper to dump (STAC) data to a temp file and return the path.
"""
# TODO #738 instead of working with local files: real request mocking of STAC resources compatible with pystac?

def dump(data) -> str:
stac_path = tmp_path / "stac.json"
with stac_path.open("w", encoding="utf8") as f:
json.dump(data, fp=f)
return str(stac_path)

return dump

def test_basic(self, con120):
cube = con120.load_stac("https://provider.test/dataset")
assert cube.flat_graph() == {
@@ -2890,20 +2911,19 @@ def test_load_stac_from_job_empty_result(self, con120, requests_mock):
}

@pytest.mark.parametrize("temporal_dim", ["t", "datezz"])
def test_load_stac_reduce_temporal(self, con120, tmp_path, temporal_dim):
stac_path = tmp_path / "stac.json"
stac_data = StacDummyBuilder.collection(
cube_dimensions={temporal_dim: {"type": "temporal", "extent": ["2024-01-01", "2024-04-04"]}}
def test_load_stac_reduce_temporal(self, con120, build_stac_ref, temporal_dim):
stac_ref = build_stac_ref(
StacDummyBuilder.collection(
cube_dimensions={temporal_dim: {"type": "temporal", "extent": ["2024-01-01", "2024-04-04"]}}
)
)
# TODO #738 real request mocking of STAC resources compatible with pystac?
stac_path.write_text(json.dumps(stac_data))

cube = con120.load_stac(str(stac_path))
cube = con120.load_stac(stac_ref)
reduced = cube.reduce_temporal("max")
assert reduced.flat_graph() == {
"loadstac1": {
"process_id": "load_stac",
"arguments": {"url": str(stac_path)},
"arguments": {"url": stac_ref},
},
"reducedimension1": {
"process_id": "reduce_dimension",
@@ -2957,66 +2977,174 @@ def test_load_stac_no_cube_extension_temporal_dimension(self, con120, tmp_path,
cube = con120.load_stac(str(stac_path))
assert cube.metadata.temporal_dimension == TemporalDimension(name="t", extent=dim_extent)

def test_load_stac_default_band_handling(self, dummy_backend, build_stac_ref):
stac_ref = build_stac_ref(
StacDummyBuilder.collection(
# TODO #586 also cover STAC 1.1 style "bands"
summaries={"eo:bands": [{"name": "B01"}, {"name": "B02"}, {"name": "B03"}]}
)
)

cube = dummy_backend.connection.load_stac(stac_ref)
assert cube.metadata.band_names == ["B01", "B02", "B03"]

cube.execute()
assert dummy_backend.get_pg("load_stac")["arguments"] == {
"url": stac_ref,
}

@pytest.mark.parametrize(
"bands, expected_warning",
[
(
["B04"],
"The specified bands ['B04'] are not a subset of the bands ['B01', 'B02', 'B03'] found in the STAC metadata (unknown bands: ['B04']). Using specified bands as is.",
"The specified bands ['B04'] in `load_stac` are not a subset of the bands ['B01', 'B02', 'B03'] found in the STAC metadata (unknown bands: ['B04']). Working with specified bands as is.",
),
(
["B03", "B04", "B05"],
"The specified bands ['B03', 'B04', 'B05'] are not a subset of the bands ['B01', 'B02', 'B03'] found in the STAC metadata (unknown bands: ['B04', 'B05']). Using specified bands as is.",
"The specified bands ['B03', 'B04', 'B05'] in `load_stac` are not a subset of the bands ['B01', 'B02', 'B03'] found in the STAC metadata (unknown bands: ['B04', 'B05']). Working with specified bands as is.",
),
(["B03", "B02"], None),
(["B01", "B02", "B03"], None),
],
)
def test_load_stac_band_filtering(self, con120, tmp_path, caplog, bands, expected_warning):
stac_path = tmp_path / "stac.json"
stac_data = StacDummyBuilder.collection(
summaries={"eo:bands": [{"name": "B01"}, {"name": "B02"}, {"name": "B03"}]}
def test_load_stac_band_filtering(self, dummy_backend, build_stac_ref, caplog, bands, expected_warning):
stac_ref = build_stac_ref(
StacDummyBuilder.collection(
# TODO #586 also cover STAC 1.1 style "bands"
summaries={"eo:bands": [{"name": "B01"}, {"name": "B02"}, {"name": "B03"}]}
)
)
# TODO #738 real request mocking of STAC resources compatible with pystac?
stac_path.write_text(json.dumps(stac_data))

caplog.set_level(logging.WARNING)
# Test with non-existing bands in the collection metadata
cube = con120.load_stac(str(stac_path), bands=bands)
cube = dummy_backend.connection.load_stac(stac_ref, bands=bands)
assert cube.metadata.band_names == bands
if expected_warning is None:
assert caplog.text == ""
else:
assert expected_warning in caplog.text

def test_load_stac_band_filtering_no_requested_bands(self, con120, tmp_path):
stac_path = tmp_path / "stac.json"
stac_data = StacDummyBuilder.collection(
summaries={"eo:bands": [{"name": "B01"}, {"name": "B02"}, {"name": "B03"}]}
)
# TODO #738 real request mocking of STAC resources compatible with pystac?
stac_path.write_text(json.dumps(stac_data))

cube = con120.load_stac(str(stac_path))
assert cube.metadata.band_names == ["B01", "B02", "B03"]
cube.execute()
assert dummy_backend.get_pg("load_stac")["arguments"] == {
"url": stac_ref,
"bands": bands,
}

def test_load_stac_band_filtering_no_metadata(self, con120, tmp_path, caplog):
stac_path = tmp_path / "stac.json"
stac_data = StacDummyBuilder.collection()
# TODO #738 real request mocking of STAC resources compatible with pystac?
stac_path.write_text(json.dumps(stac_data))
def test_load_stac_band_filtering_no_band_metadata_default(self, dummy_backend, build_stac_ref, caplog):
stac_ref = build_stac_ref(StacDummyBuilder.collection())

cube = con120.load_stac(str(stac_path))
cube = dummy_backend.connection.load_stac(stac_ref)
# TODO #743: what should the default list of bands be?
assert cube.metadata.band_names == []

cube.execute()
assert dummy_backend.get_pg("load_stac")["arguments"] == {
"url": stac_ref,
}

@pytest.mark.parametrize(
["bands", "has_band_dimension", "expected_pg_args", "expected_warning"],
[
(None, False, {}, None),
(
["B02", "B03"],
True,
{"bands": ["B02", "B03"]},
"Bands ['B02', 'B03'] were specified in `load_stac`, but no band dimension was detected in the STAC metadata. Working with band dimension and specified bands.",
),
],
)
def test_load_stac_band_filtering_no_band_dimension(
self, dummy_backend, build_stac_ref, bands, has_band_dimension, expected_pg_args, expected_warning, caplog
):
stac_ref = build_stac_ref(StacDummyBuilder.collection())

# This is a temporary mock.patch hack to make metadata_from_stac return metadata without a band dimension
# TODO #743: Do this properly through appropriate STAC metadata
from openeo.metadata import metadata_from_stac as orig_metadata_from_stac

def metadata_from_stac(url: str):
metadata = orig_metadata_from_stac(url=url)
assert metadata.has_band_dimension()
metadata = metadata.drop_dimension("bands")
assert not metadata.has_band_dimension()
return metadata

with mock.patch("openeo.rest.datacube.metadata_from_stac", new=metadata_from_stac):
cube = dummy_backend.connection.load_stac(stac_ref, bands=bands)

assert cube.metadata.has_band_dimension() == has_band_dimension

cube.execute()
assert dummy_backend.get_pg("load_stac")["arguments"] == {
**expected_pg_args,
"url": stac_ref,
}

if expected_warning:
assert expected_warning in caplog.text
else:
assert not caplog.text

def test_load_stac_band_filtering_no_band_metadata(self, dummy_backend, build_stac_ref, caplog):
caplog.set_level(logging.WARNING)
cube = con120.load_stac(str(stac_path), bands=["B01", "B02"])
stac_ref = build_stac_ref(StacDummyBuilder.collection())

cube = dummy_backend.connection.load_stac(stac_ref, bands=["B01", "B02"])
assert cube.metadata.band_names == ["B01", "B02"]
assert (
"The specified bands ['B01', 'B02'] are not a subset of the bands [] found in the STAC metadata (unknown bands: ['B01', 'B02']). Using specified bands as is."
# TODO: better warning than confusing "not a subset of the bands []" ?
"The specified bands ['B01', 'B02'] in `load_stac` are not a subset of the bands [] found in the STAC metadata (unknown bands: ['B01', 'B02']). Working with specified bands as is."
in caplog.text
)

cube.execute()
assert dummy_backend.get_pg("load_stac")["arguments"] == {
"url": stac_ref,
"bands": ["B01", "B02"],
}

@pytest.mark.parametrize(
["bands", "expected_pg_args", "expected_warning"],
[
(None, {}, None),
(
["B02", "B03"],
{"bands": ["B02", "B03"]},
"The specified bands ['B02', 'B03'] in `load_stac` are not a subset of the bands ['Bz1', 'Bz2'] found in the STAC metadata (unknown bands: ['B02', 'B03']). Working with specified bands as is",
),
],
)
def test_load_stac_band_filtering_custom_band_dimension(
self, dummy_backend, build_stac_ref, bands, expected_pg_args, expected_warning, caplog
):
stac_ref = build_stac_ref(StacDummyBuilder.collection())

# This is a temporary mock.patch hack to make metadata_from_stac return metadata with a custom band dimension
# TODO #743: Do this properly through appropriate STAC metadata
from openeo.metadata import metadata_from_stac as orig_metadata_from_stac

def metadata_from_stac(url: str):
return CubeMetadata(dimensions=[BandDimension(name="bandzz", bands=[Band("Bz1"), Band("Bz2")])])

with mock.patch("openeo.rest.datacube.metadata_from_stac", new=metadata_from_stac):
cube = dummy_backend.connection.load_stac(stac_ref, bands=bands)

assert cube.metadata.has_band_dimension()
assert cube.metadata.band_dimension.name == "bandzz"

cube.execute()
assert dummy_backend.get_pg("load_stac")["arguments"] == {
**expected_pg_args,
"url": stac_ref,
}

if expected_warning:
assert expected_warning in caplog.text
else:
assert not caplog.text


@pytest.mark.parametrize(
"bands",
52 changes: 52 additions & 0 deletions tests/test_metadata.py
@@ -779,6 +779,58 @@ def test_cubemetadata_drop_dimension():
metadata.drop_dimension("x")


def test_cubemetadata_ensure_band_dimension_add_bands():
metadata = CubeMetadata(
dimensions=[
TemporalDimension(name="t", extent=None),
]
)
new = metadata.ensure_band_dimension(bands=["red", "green"], warning="ensure_band_dimension at work")
assert new.has_band_dimension()
assert new.dimension_names() == ["t", "bands"]
assert new.band_names == ["red", "green"]


def test_cubemetadata_ensure_band_dimension_add_name_and_bands():
metadata = CubeMetadata(
dimensions=[
TemporalDimension(name="t", extent=None),
]
)
new = metadata.ensure_band_dimension(name="bandzz", bands=["red", "green"], warning="ensure_band_dimension at work")
assert new.has_band_dimension()
assert new.dimension_names() == ["t", "bandzz"]
assert new.band_names == ["red", "green"]


def test_cubemetadata_ensure_band_dimension_override_bands():
metadata = CubeMetadata(
dimensions=[
TemporalDimension(name="t", extent=None),
BandDimension(name="bands", bands=[Band("red"), Band("green")]),
]
)
new = metadata.ensure_band_dimension(bands=["tomato", "lettuce"], warning="ensure_band_dimension at work")
assert new.has_band_dimension()
assert new.dimension_names() == ["t", "bands"]
assert new.band_names == ["tomato", "lettuce"]


def test_cubemetadata_ensure_band_dimension_override_name_and_bands():
metadata = CubeMetadata(
dimensions=[
TemporalDimension(name="t", extent=None),
BandDimension(name="bands", bands=[Band("red"), Band("green")]),
]
)
new = metadata.ensure_band_dimension(
name="bandzz", bands=["tomato", "lettuce"], warning="ensure_band_dimension at work"
)
assert new.has_band_dimension()
assert new.dimension_names() == ["t", "bandzz"]
assert new.band_names == ["tomato", "lettuce"]


def test_collectionmetadata_subclass():
class MyCollectionMetadata(CollectionMetadata):
def __init__(self, metadata: dict, dimensions: List[Dimension] = None, bbox=None):