From 62700029d63fd10fc60d14ef2a597157f0a27688 Mon Sep 17 00:00:00 2001
From: "Peter A. Jonsson"
Date: Thu, 11 Sep 2025 09:36:45 +0200
Subject: [PATCH] aws: import odc-cloud modules

This imports the remaining synchronous S3 functionality from odc-cloud
that is still used and is not already present in datacube. I have added
some partial type annotations that I know are correct, and applied a
light reformat plus lint fixes.
---
 datacube/utils/aws/__init__.py  |  21 +++
 datacube/utils/aws/inventory.py | 115 ++++++++++++++++
 datacube/utils/aws/queue.py     | 224 ++++++++++++++++++++++++++++++++
 tests/test_utils_aws.py         | 106 +++++++++++++++
 4 files changed, 466 insertions(+)
 create mode 100644 datacube/utils/aws/inventory.py
 create mode 100644 datacube/utils/aws/queue.py

diff --git a/datacube/utils/aws/__init__.py b/datacube/utils/aws/__init__.py
index 009a4b2a8..258c64192 100644
--- a/datacube/utils/aws/__init__.py
+++ b/datacube/utils/aws/__init__.py
@@ -9,6 +9,7 @@
 import functools
 import os
 import time
+from collections.abc import Generator
 from typing import IO, Any, TypeAlias
 from urllib.parse import urlparse
 from urllib.request import urlopen
@@ -417,6 +418,26 @@ def s3_dump(data: bytes | str | IO, url: str, s3: MaybeS3 = None, **kwargs):
     return 200 <= code < 300
 
 
+def s3_ls_dir(uri: str, s3: BaseClient | None = None, **kw) -> Generator[str]:
+    bucket, prefix = s3_url_parse(uri)
+
+    if len(prefix) > 0 and not prefix.endswith("/"):
+        prefix = prefix + "/"
+
+    s3 = s3 or s3_client()
+    paginator = s3.get_paginator("list_objects_v2")
+
+    for page in paginator.paginate(Bucket=bucket, Prefix=prefix, Delimiter="/", **kw):
+        sub_dirs = page.get("CommonPrefixes", [])
+        files = page.get("Contents", [])
+
+        for p in sub_dirs:
+            yield f"s3://{bucket}/{p['Prefix']}"
+
+        for o in files:
+            yield f"s3://{bucket}/{o['Key']}"
+
+
 def get_aws_settings(
     profile: str | None = None,
     region_name: str = "auto",
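
A minimal usage sketch for the new s3_ls_dir helper; the bucket name and
prefix below are illustrative only:

    # List one "directory" level of a bucket; assumes the bucket is readable
    # anonymously (otherwise drop aws_unsigned and rely on normal credentials).
    from datacube.utils.aws import s3_client, s3_ls_dir

    s3 = s3_client(region_name="us-west-2", aws_unsigned=True)

    # For each page, sub-directories (CommonPrefixes) are yielded first,
    # then object keys, all as fully qualified s3:// URLs.
    for uri in s3_ls_dir("s3://my-bucket/some/prefix", s3=s3):
        print(uri)
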
diff --git a/datacube/utils/aws/inventory.py b/datacube/utils/aws/inventory.py
new file mode 100644
index 000000000..17ca225a0
--- /dev/null
+++ b/datacube/utils/aws/inventory.py
@@ -0,0 +1,115 @@
+# This file is part of the Open Data Cube, see https://opendatacube.org for more information
+#
+# Copyright (c) 2015-2025 ODC Contributors
+# SPDX-License-Identifier: Apache-2.0
+import csv
+import json
+from collections.abc import Generator, Iterable
+from concurrent.futures import ThreadPoolExecutor, as_completed
+from gzip import GzipFile
+from io import BytesIO
+from types import SimpleNamespace
+
+from botocore.client import BaseClient
+
+from . import s3_client, s3_fetch, s3_ls_dir
+
+
+def find_latest_manifest(prefix: str, s3: BaseClient | None, **kw) -> str:
+    """
+    Find the latest inventory manifest under the given prefix
+    """
+    manifest_dirs = sorted(s3_ls_dir(prefix, s3=s3, **kw), reverse=True)
+    for d in manifest_dirs:
+        if d.endswith("/"):
+            leaf = d.split("/")[-2]
+            if leaf.endswith("Z"):
+                return d + "manifest.json"
+    return ""
+
+
+def retrieve_manifest_files(
+    key: str, s3: BaseClient | None, schema: Iterable, **kw
+) -> Generator[SimpleNamespace]:
+    """
+    Read the gzipped CSV data file at ``key`` and yield one namespace per record
+
+    namespace(
+        Bucket=,
+        Key=,
+        LastModifiedDate=,
+        Size=
+    )
+    """
+    bb = s3_fetch(key, s3=s3, **kw)
+    gz = GzipFile(fileobj=BytesIO(bb), mode="r")
+    csv_rdr = csv.reader(line.decode("utf8") for line in gz)
+    for rec in csv_rdr:
+        yield SimpleNamespace(**dict(zip(schema, rec)))
+
+
+def list_inventory(
+    manifest: str,
+    s3: BaseClient | None = None,
+    prefix: str = "",
+    suffix: str = "",
+    contains: str = "",
+    n_threads: int | None = None,
+    **kw,
+) -> Generator[SimpleNamespace]:
+    """
+    Returns a generator of inventory records
+
+    manifest -- s3:// url of a manifest.json, or a folder, in which case the latest manifest is chosen.
+
+    :param manifest: manifest location as described above
+    :param s3: pre-configured s3 client; one is created if not supplied
+    :param prefix: only yield records whose key starts with this prefix
+    :param suffix: only yield records whose key ends with this suffix
+    :param contains: only yield records whose key contains this substring
+    :param n_threads: number of worker threads; if not supplied, data files are read sequentially
+    :return: generator of SimpleNamespace records
+    """
+    # TODO: refactor parallel execution part out of this function
+    # pylint: disable=too-many-locals
+    s3 = s3 or s3_client()
+
+    if manifest.endswith("/"):
+        manifest = find_latest_manifest(manifest, s3, **kw)
+
+    info = json.loads(s3_fetch(manifest, s3=s3, **kw))
+
+    must_have_keys = {"fileFormat", "fileSchema", "files", "destinationBucket"}
+    missing_keys = must_have_keys - set(info)
+    if missing_keys:
+        raise ValueError(f"Manifest is missing required keys: {missing_keys}")
+
+    if info["fileFormat"].upper() != "CSV":
+        raise ValueError("Data is not in CSV format")
+
+    s3_prefix = "s3://" + info["destinationBucket"].split(":")[-1] + "/"
+    data_urls = [s3_prefix + f["key"] for f in info["files"]]
+    schema = tuple(info["fileSchema"].split(", "))
+
+    if n_threads:
+        with ThreadPoolExecutor(max_workers=n_threads) as executor:
+            tasks = [
+                executor.submit(retrieve_manifest_files, key, s3, schema)
+                for key in data_urls
+            ]
+
+            for future in as_completed(tasks):
+                for namespace in future.result():
+                    key = namespace.Key
+                    if (
+                        key.startswith(prefix)
+                        and key.endswith(suffix)
+                        and contains in key
+                    ):
+                        yield namespace
+    else:
+        for u in data_urls:
+            for namespace in retrieve_manifest_files(u, s3, schema):
+                key = namespace.Key
+                if key.startswith(prefix) and key.endswith(suffix) and contains in key:
+                    yield namespace
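
A usage sketch for the inventory helper, assuming S3 Inventory has been
configured to write gzipped CSV listings; the bucket names and suffix are
illustrative, and the record fields follow the manifest's fileSchema:

    from datacube.utils.aws import s3_client
    from datacube.utils.aws.inventory import list_inventory

    s3 = s3_client(aws_unsigned=True)

    # A trailing "/" points at a folder of manifests, so the latest one is used.
    for rec in list_inventory(
        "s3://my-inventory-bucket/my-data-bucket/daily-inventory/",
        s3=s3,
        suffix=".odc-metadata.yaml",
        n_threads=8,
    ):
        print(rec.Key, rec.Size)
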
diff --git a/datacube/utils/aws/queue.py b/datacube/utils/aws/queue.py
new file mode 100644
index 000000000..b4fba90d4
--- /dev/null
+++ b/datacube/utils/aws/queue.py
@@ -0,0 +1,224 @@
+# This file is part of the Open Data Cube, see https://opendatacube.org for more information
+#
+# Copyright (c) 2015-2025 ODC Contributors
+# SPDX-License-Identifier: Apache-2.0
+import itertools
+import json
+import logging
+from collections.abc import Generator, Iterable, Iterator, Mapping
+from typing import Any
+
+import boto3
+from toolz import dicttoolz
+
+_LOG: logging.Logger = logging.getLogger(__name__)
+
+
+class ODCSQSError(Exception):
+    """Something wrong with ODC/AWS SQS handling"""
+
+
+def redrive_queue(
+    queue_name: str,
+    to_queue_name: str | None = None,
+    limit: int | None = None,
+    dryrun: bool = False,
+    max_wait: int = 5,
+    messages_per_request: int = 10,
+) -> int:
+    """
+    Redrive messages from one queue to another. By default the named queue is
+    treated as a deadletter queue and messages are sent back to its single
+    "alive" source queue; pass ``to_queue_name`` to choose the target explicitly.
+    """
+
+    def post_messages(to_queue, messages) -> list:
+        message_bodies = [
+            {"Id": str(n), "MessageBody": m.body} for n, m in enumerate(messages)
+        ]
+        to_queue.send_messages(Entries=message_bodies)
+        # Delete after sending, not before
+        for message in messages:
+            message.delete()
+        return []
+
+    dead_queue = get_queue(queue_name)
+
+    if to_queue_name is not None:
+        alive_queue = get_queue(to_queue_name)
+    else:
+        source_queues = list(dead_queue.dead_letter_source_queues.all())
+        if len(source_queues) == 0:
+            raise ODCSQSError(
+                "No alive queue found for the deadletter queue, please check your configuration."
+            )
+        if len(source_queues) > 1:
+            raise ODCSQSError(
+                "Deadletter queue has more than one source, please specify the target queue name."
+            )
+        alive_queue = source_queues[0]
+
+    messages = get_messages(
+        dead_queue,
+        limit=limit,
+        max_wait=max_wait,
+        messages_per_request=messages_per_request,
+    )
+    count_messages = 0
+    approx_n_messages = dead_queue.attributes.get("ApproximateNumberOfMessages")
+    try:
+        count_messages = int(approx_n_messages)
+    except TypeError:
+        _LOG.warning("Couldn't get approximate number of messages, setting to 0")
+
+    # If there are no messages then there's no work to do. If it's a dryrun, we
+    # don't do anything either.
+    if count_messages == 0 or dryrun:
+        return count_messages
+
+    count = 0
+    message_group = []
+    for message in messages:
+        # Assume this works. Exception handling elsewhere.
+        message_group.append(message)
+        count += 1
+
+        if count % 10 == 0:
+            message_group = post_messages(alive_queue, message_group)
+
+    # Post the last few messages
+    if len(message_group) > 0:
+        _ = post_messages(alive_queue, message_group)
+
+    # Return the number of messages that were re-driven.
+    return count
+
+
+def get_queue(queue_name: str):
+    """
+    Return a queue resource by name, e.g., alex-really-secret-queue
+    """
+    return boto3.resource("sqs").get_queue_by_name(QueueName=queue_name)
+
+
+def get_queues(prefix: str | None = None, contains: str | None = None) -> Generator:
+    """
+    Return a generator of SQS queues that the user is allowed to see,
+    filtered by the parameters provided
+    """
+    queues = boto3.resource("sqs").queues.all()
+    if prefix is not None:
+        queues = queues.filter(QueueNamePrefix=prefix)
+
+    if contains is not None:
+        for queue in queues:
+            if contains in queue.attributes.get("QueueArn").split(":")[-1]:
+                yield queue
+    else:
+        yield from queues
+
+
+def publish_message(
+    queue, message: str, message_attributes: Mapping[str, Any] | None = None
+) -> None:
+    """
+    Publish a message to a queue resource. Message should be a JSON object dumped as a
+    string.
+    """
+    if message_attributes is None:
+        message_attributes = {}
+    queue.send_message(
+        QueueUrl=queue.url, MessageBody=message, MessageAttributes=message_attributes
+    )
+
+
+def publish_messages(queue, messages) -> None:
+    """
+    Publish messages to a queue resource.
+    """
+    queue.send_messages(Entries=messages)
+
+
+def _sqs_message_stream(queue, **kw) -> Generator:
+    while True:
+        messages = queue.receive_messages(**kw)
+        if len(messages) == 0:
+            return
+
+        yield from messages
+
+
+def get_messages(
+    queue,
+    limit: int | None = None,
+    visibility_timeout: int = 60,
+    message_attributes: Iterable[str] | None = None,
+    max_wait: int = 1,
+    messages_per_request: int = 1,
+    **kw,
+) -> Iterator:
+    """
+    Get messages from SQS queue resource. Returns a lazy sequence of message objects.
+
+    :param queue: queue resource
+    :param limit: the maximum number of messages to return from the queue (defaults to all)
+    :param visibility_timeout: A period of time in seconds during which Amazon SQS prevents other consumers
+                               from receiving and processing the message
+    :param message_attributes: Select what attributes to include in the messages (default: All)
+    :param max_wait: longest to wait in seconds before assuming queue is empty (default: 1)
+    :param messages_per_request: maximum number of messages per receive call, up to 10 (default: 1)
+    :param kw: any other arguments are passed to the boto3 ``receive_messages()`` call
+
+    :return: Iterator of sqs messages
+    """
+    if message_attributes is None:
+        message_attributes = ["All"]
+
+    messages = _sqs_message_stream(
+        queue,
+        VisibilityTimeout=visibility_timeout,
+        MaxNumberOfMessages=messages_per_request,
+        WaitTimeSeconds=max_wait,
+        MessageAttributeNames=message_attributes,
+        **kw,
+    )
+
+    if limit is None:
+        return messages
+
+    if limit < 1:
+        raise ODCSQSError(f"Limit {limit} is not valid.")
+
+    return itertools.islice(messages, limit)
+
+
+def capture_attributes(action: str, stac: dict) -> dict:
+    """Determine SNS message attributes"""
+    product = dicttoolz.get_in(["properties", "odc:product"], stac)
+    date_time = dicttoolz.get_in(["properties", "datetime"], stac)
+    maturity = dicttoolz.get_in(["properties", "dea:dataset_maturity"], stac)
+
+    if not product:
+        product = stac.get("collection")
+
+    return {
+        "action": {"DataType": "String", "StringValue": action},
+        "product": {"DataType": "String", "StringValue": product},
+        "datetime": {"DataType": "String", "StringValue": date_time},
+        **(
+            {"maturity": {"DataType": "String", "StringValue": maturity}}
+            if maturity
+            else {}
+        ),
+    }
+
+
+def publish_to_topic(arn: str, action: str, stac: dict) -> None:
+    """
+    Publish 'added' or 'archived' action to the provided sns topic
+    """
+    boto3.client("sns").publish(
+        TopicArn=arn,
+        Message=json.dumps(stac),
+        MessageAttributes=capture_attributes(action, stac),
+    )
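
A usage sketch for the queue helpers; the queue names are illustrative and
the queues are assumed to already exist:

    import json

    from datacube.utils.aws.queue import (
        get_messages,
        get_queue,
        publish_message,
        redrive_queue,
    )

    queue = get_queue("my-processing-queue")
    publish_message(queue, json.dumps({"id": "example-dataset"}))

    # Lazily pull up to 10 messages, waiting at most 1 second on an empty queue.
    for message in get_messages(queue, limit=10, max_wait=1):
        print(message.body)
        message.delete()

    # Move anything stuck on the deadletter queue back to the processing queue.
    n = redrive_queue("my-processing-queue-deadletter", "my-processing-queue")
    print(f"redrove {n} messages")
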
diff --git a/tests/test_utils_aws.py b/tests/test_utils_aws.py
index 572a74d64..027fd67e2 100644
--- a/tests/test_utils_aws.py
+++ b/tests/test_utils_aws.py
@@ -3,9 +3,11 @@
 # Copyright (c) 2015-2025 ODC Contributors
 # SPDX-License-Identifier: Apache-2.0
 import json
+import os
 from typing import Any
 from unittest import mock
 
+import boto3
 import botocore
 import moto
 import pytest
@@ -28,6 +30,10 @@
     s3_head_object,
     s3_url_parse,
 )
+from datacube.utils.aws.queue import get_queues, redrive_queue
+
+ALIVE_QUEUE_NAME = "mock-alive-queue"
+DEAD_QUEUE_NAME = "mock-dead-queue"
 
 
 def _json(**kw):
@@ -42,6 +48,106 @@ def mock_urlopen(text: str, code: int = 200):
     return m
 
 
+def get_n_messages(queue) -> int:
+    return int(queue.attributes.get("ApproximateNumberOfMessages"))
+
+
+@pytest.fixture
+def aws_env(monkeypatch) -> None:
+    if "AWS_DEFAULT_REGION" not in os.environ:
+        monkeypatch.setenv("AWS_DEFAULT_REGION", "us-west-2")
+
+
+@moto.mock_aws
+def test_redrive_to_queue(aws_env: None) -> None:
+    resource = boto3.resource("sqs")
+
+    dead_queue = resource.create_queue(QueueName=DEAD_QUEUE_NAME)
+    alive_queue = resource.create_queue(
+        QueueName=ALIVE_QUEUE_NAME,
+        Attributes={
+            "RedrivePolicy": json.dumps(
+                {
+                    "deadLetterTargetArn": dead_queue.attributes.get("QueueArn"),
+                    "maxReceiveCount": 2,
+                }
+            ),
+        },
+    )
+
+    # Test redriving to a queue without an alive queue specified
+    dead_queue.send_message(MessageBody=json.dumps({"test": 1}))
+    assert get_n_messages(dead_queue) == 1
+
+    count = redrive_queue(DEAD_QUEUE_NAME, max_wait=0)
+    assert count == 1
+
+    # Test redriving to a queue that is specified
+    dead_queue.send_message(MessageBody=json.dumps({"test": 2}))
+    assert get_n_messages(dead_queue) == 1
+
+    redrive_queue(DEAD_QUEUE_NAME, ALIVE_QUEUE_NAME, max_wait=0)
+    assert get_n_messages(dead_queue) == 1
+    assert get_n_messages(alive_queue) == 2
+
+    # Test lots of messages:
+    for i in range(35):
+        dead_queue.send_message(MessageBody=json.dumps({"content": f"Something {i}"}))
+
+    count = redrive_queue(DEAD_QUEUE_NAME, ALIVE_QUEUE_NAME, max_wait=0)
+    assert count == 35
+
+    assert get_n_messages(dead_queue) == 0
+
+
+@moto.mock_aws
+def test_get_queues(aws_env: None) -> None:
+    resource = boto3.resource("sqs")
+
+    resource.create_queue(QueueName="a_queue1")
+    resource.create_queue(QueueName="b_queue2")
+    resource.create_queue(QueueName="c_queue3")
+    resource.create_queue(QueueName="d_queue4")
+
+    queues = get_queues()
+
+    assert len(list(queues)) == 4
+
+    # Test prefix
+    queues = get_queues(prefix="a_queue1")
+    assert "queue1" in next(iter(queues)).url
+
+    # Test contains
+    queues = get_queues(contains="2")
+    assert "b_queue2" in next(iter(queues)).url
+
+    # Test prefix and contains
+    queues = get_queues(prefix="c", contains="3")
+    assert "c_queue3" in next(iter(queues)).url
+
+    # Test prefix and not contains
+    queues = get_queues(prefix="d", contains="5")
+    assert len(list(queues)) == 0
+
+    # Test contains and not prefix
+    queues = get_queues(prefix="q", contains="2")
+    assert len(list(queues)) == 0
+
+    # Test not found prefix
+    queues = get_queues(prefix="fake_start")
+    assert len(list(queues)) == 0
+
+    # Test not found contains
+    queues = get_queues(contains="not_there")
+    assert len(list(queues)) == 0
+
+
+@moto.mock_aws
+def test_get_queues_empty(aws_env: None) -> None:
+    queues = get_queues()
+    assert list(queues) == []
+
+
 def test_ec2_current_region() -> None:
     tests = [
         (None, None),