21 changes: 21 additions & 0 deletions datacube/utils/aws/__init__.py
@@ -9,6 +9,7 @@
import functools
import os
import time
from collections.abc import Generator
from typing import IO, Any, TypeAlias
from urllib.parse import urlparse
from urllib.request import urlopen
@@ -417,6 +418,26 @@ def s3_dump(data: bytes | str | IO, url: str, s3: MaybeS3 = None, **kwargs):
return 200 <= code < 300


def s3_ls_dir(uri: str, s3: BaseClient | None = None, **kw) -> Generator[str]:
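    """
    List the immediate contents of an S3 "directory": yields ``s3://`` urls for
    sub-directories (common prefixes) and for objects directly under ``uri``.
    """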
bucket, prefix = s3_url_parse(uri)

if len(prefix) > 0 and not prefix.endswith("/"):
prefix = prefix + "/"

s3 = s3 or s3_client()
paginator = s3.get_paginator("list_objects_v2")

for page in paginator.paginate(Bucket=bucket, Prefix=prefix, Delimiter="/", **kw):
sub_dirs = page.get("CommonPrefixes", [])
files = page.get("Contents", [])

for p in sub_dirs:
yield f"s3://{bucket}/{p['Prefix']}"

for o in files:
yield f"s3://{bucket}/{o['Key']}"


def get_aws_settings(
profile: str | None = None,
region_name: str = "auto",
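For reference, a minimal usage sketch of the new s3_ls_dir helper (illustrative only; the bucket and prefix are hypothetical examples):

from datacube.utils.aws import s3_ls_dir

# Yields s3:// urls for sub-directories (common prefixes) and for objects
# directly under the given prefix, one page of results at a time.
for uri in s3_ls_dir("s3://example-bucket/inventory/"):
    print(uri)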
115 changes: 115 additions & 0 deletions datacube/utils/aws/inventory.py
@@ -0,0 +1,115 @@
# This file is part of the Open Data Cube, see https://opendatacube.org for more information
#
# Copyright (c) 2015-2025 ODC Contributors
# SPDX-License-Identifier: Apache-2.0
import csv
import json
from collections.abc import Generator, Iterable
from concurrent.futures import ThreadPoolExecutor, as_completed
from gzip import GzipFile
from io import BytesIO
from types import SimpleNamespace

from botocore.client import BaseClient

from . import s3_client, s3_fetch, s3_ls_dir


def find_latest_manifest(prefix: str, s3: BaseClient | None, **kw) -> str:
"""
Find latest manifest
"""
manifest_dirs = sorted(s3_ls_dir(prefix, s3=s3, **kw), reverse=True)
for d in manifest_dirs:
if d.endswith("/"):
leaf = d.split("/")[-2]
if leaf.endswith("Z"):
return d + "manifest.json"
return ""


def retrieve_manifest_files(
key: str, s3: BaseClient | None, schema: Iterable, **kw
) -> Generator[SimpleNamespace]:
"""
Retrieve manifest file and return a namespace

namespace(
Bucket=<bucket_name>,
Key=<key_path>,
LastModifiedDate=<date>,
Size=<size>
)
"""
bb = s3_fetch(key, s3=s3, **kw)
gz = GzipFile(fileobj=BytesIO(bb), mode="r")
csv_rdr = csv.reader(line.decode("utf8") for line in gz)
for rec in csv_rdr:
yield SimpleNamespace(**dict(zip(schema, rec)))


def list_inventory(
manifest: str,
s3: BaseClient | None = None,
prefix: str = "",
suffix: str = "",
contains: str = "",
n_threads: int | None = None,
**kw,
) -> Generator[SimpleNamespace]:
"""
Returns a generator of inventory records

manifest -- s3:// url to manifest.json or a folder in which case latest one is chosen.

:param manifest:
:param s3:
:param prefix:
:param suffix:
:param contains:
:param n_threads: number of threads, if not sent does not use threads
:return: SimpleNamespace
"""
# TODO: refactor parallel execution part out of this function
# pylint: disable=too-many-locals
s3 = s3 or s3_client()

if manifest.endswith("/"):
manifest = find_latest_manifest(manifest, s3, **kw)

info = json.loads(s3_fetch(manifest, s3=s3, **kw))

must_have_keys = {"fileFormat", "fileSchema", "files", "destinationBucket"}
missing_keys = must_have_keys - set(info)
if missing_keys:
raise ValueError("Manifest file haven't parsed correctly")

if info["fileFormat"].upper() != "CSV":
raise ValueError("Data is not in CSV format")

s3_prefix = "s3://" + info["destinationBucket"].split(":")[-1] + "/"
data_urls = [s3_prefix + f["key"] for f in info["files"]]
schema = tuple(info["fileSchema"].split(", "))

if n_threads:
with ThreadPoolExecutor(max_workers=n_threads) as executor:
tasks = [
executor.submit(retrieve_manifest_files, key, s3, schema)
for key in data_urls
]

for future in as_completed(tasks):
for namespace in future.result():
key = namespace.Key
if (
key.startswith(prefix)
and key.endswith(suffix)
and contains in key
):
yield namespace
else:
for u in data_urls:
for namespace in retrieve_manifest_files(u, s3, schema):
key = namespace.Key
if key.startswith(prefix) and key.endswith(suffix) and contains in key:
yield namespace
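As an illustration, one way list_inventory might be driven (a sketch only; the inventory location and filters are hypothetical):

from datacube.utils.aws.inventory import list_inventory

# Point at the inventory folder so the latest manifest.json is picked up, then
# stream records whose keys end in ".tif", fetching data files with 4 threads.
for rec in list_inventory(
    "s3://example-inventory-bucket/example-bucket/daily-inventory/",
    suffix=".tif",
    n_threads=4,
):
    print(rec.Bucket, rec.Key, rec.Size)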
224 changes: 224 additions & 0 deletions datacube/utils/aws/queue.py
@@ -0,0 +1,224 @@
# This file is part of the Open Data Cube, see https://opendatacube.org for more information
#
# Copyright (c) 2015-2025 ODC Contributors
# SPDX-License-Identifier: Apache-2.0
import itertools
import json
import logging
from collections.abc import Generator, Iterable, Iterator, Mapping
from typing import Any

import boto3
from toolz import dicttoolz

_LOG: logging.Logger = logging.getLogger(__name__)


class ODCSQSError(Exception):
"""Something wrong with ODC/AWS SQS handling"""


def redrive_queue(
queue_name: str,
to_queue_name: str | None = None,
limit: int | None = None,
dryrun: bool = False,
max_wait: int = 5,
messages_per_request: int = 10,
) -> int:
"""
Redrive messages from one queue to another. Default usage is to define
a "deadletter" queue, and pick its "alive" counterpart, and redrive
messages to that queue.
"""

def post_messages(to_queue, messages) -> list:
message_bodies = [
{"Id": str(n), "MessageBody": m.body} for n, m in enumerate(messages)
]
to_queue.send_messages(Entries=message_bodies)
# Delete after sending, not before
for message in messages:
message.delete()
return []

dead_queue = get_queue(queue_name)

if to_queue_name is not None:
alive_queue = get_queue(to_queue_name)
else:
source_queues = list(dead_queue.dead_letter_source_queues.all())
if len(source_queues) == 0:
raise ODCSQSError(
"No alive queue found for the deadletter queue, please check your configuration."
)
if len(source_queues) > 1:
raise ODCSQSError(
"Deadletter queue has more than one source, please specify the target queue name."
)
alive_queue = source_queues[0]

messages = get_messages(
dead_queue,
limit=limit,
max_wait=max_wait,
messages_per_request=messages_per_request,
)
count_messages = 0
approx_n_messages = dead_queue.attributes.get("ApproximateNumberOfMessages")
try:
count_messages = int(approx_n_messages)
except TypeError:
_LOG.warning("Couldn't get approximate number of messages, setting to 0")

# If there are no messages then there's no work to do. If it's a dryrun, we
# don't do anything either.
if count_messages == 0 or dryrun:
return count_messages

count = 0
message_group = []
for message in messages:
# Assume this works. Exception handling elsewhere.
message_group.append(message)
count += 1

if count % 10 == 0:
message_group = post_messages(alive_queue, message_group)

# Post the last few messages
if len(message_group) > 0:
_ = post_messages(alive_queue, message_group)

# Return the number of messages that were re-driven.
return count


def get_queue(queue_name: str):
"""
Return a queue resource by name, e.g., alex-really-secret-queue
"""
return boto3.resource("sqs").get_queue_by_name(QueueName=queue_name)


def get_queues(prefix: str | None = None, contains: str | None = None) -> Generator:
"""
Return a list of sqs queues which the user is allowed to see and filtered by
the parameters provided
"""
queues = boto3.resource("sqs").queues.all()
if prefix is not None:
queues = queues.filter(QueueNamePrefix=prefix)

if contains is not None:
for queue in queues:
if contains in queue.attributes.get("QueueArn").split(":")[-1]:
yield queue
else:
yield from queues


def publish_message(
queue, message: str, message_attributes: Mapping[str, Any] | None = None
) -> None:
"""
Publish a message to a queue resource. Message should be a JSON object dumped as a
string.
"""
if message_attributes is None:
message_attributes = {}
queue.send_message(
QueueUrl=queue.url, MessageBody=message, MessageAttributes=message_attributes
)


def publish_messages(queue, messages) -> None:
"""
Publish messages to a queue resource.
"""
queue.send_messages(Entries=messages)


def _sqs_message_stream(queue, **kw) -> Generator:
while True:
messages = queue.receive_messages(**kw)
if len(messages) == 0:
return

yield from messages


def get_messages(
queue,
limit: int | None = None,
visibility_timeout: int = 60,
message_attributes: Iterable[str] | None = None,
max_wait: int = 1,
messages_per_request: int = 1,
**kw,
) -> Iterator:
"""
Get messages from SQS queue resource. Returns a lazy sequence of message objects.

:queue: queue URL
:param limit: the maximum number of messages to return from the queue (default to all)
:param visibility_timeout: A period of time in seconds during which Amazon SQS prevents other consumers
from receiving and processing the message
:param message_attributes: Select what attributes to include in the messages, default All
:param max_wait: Longest to wait in seconds before assuming queue is empty (default: 10)
:param messages_per_request:
:**kw: Any other arguments are passed to ``.receive_messages()`` boto3 call

:return: Iterator of sqs messages
"""
if message_attributes is None:
message_attributes = ["All"]

messages = _sqs_message_stream(
queue,
VisibilityTimeout=visibility_timeout,
MaxNumberOfMessages=messages_per_request,
WaitTimeSeconds=max_wait,
MessageAttributeNames=message_attributes,
**kw,
)

if limit is None:
return messages

if limit < 1:
raise ODCSQSError(f"Limit {limit} is not valid.")

return itertools.islice(messages, limit)


def capture_attributes(action: str, stac: dict) -> dict:
"""Determine SNS message attributes"""
product = dicttoolz.get_in(["properties", "odc:product"], stac)
date_time = dicttoolz.get_in(["properties", "datetime"], stac)
maturity = dicttoolz.get_in(["properties", "dea:dataset_maturity"], stac)

if not product:
product = stac.get("collection")

return {
"action": {"DataType": "String", "StringValue": action},
"product": {"DataType": "String", "StringValue": product},
"datetime": {"DataType": "String", "StringValue": date_time},
**(
{"maturity": {"DataType": "String", "StringValue": maturity}}
if maturity
else {}
),
}


def publish_to_topic(arn: str, action: str, stac: dict) -> None:
"""
    Publish an 'added' or 'archived' action to the provided SNS topic.
"""
boto3.client("sns").publish(
TopicArn=arn,
Message=json.dumps(stac),
MessageAttributes=capture_attributes(action, stac),
)
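A short sketch of the intended queue helpers in use (illustrative only; the queue names are hypothetical):

from datacube.utils.aws.queue import get_queue, get_messages, redrive_queue

# Move everything from a deadletter queue back to its single source queue.
n = redrive_queue("example-queue-deadletter")
print(f"redrove {n} messages")

# Consume up to 100 messages from the live queue, deleting each one once handled.
queue = get_queue("example-queue")
for message in get_messages(queue, limit=100, visibility_timeout=300):
    payload = message.body
    # ... process payload here ...
    message.delete()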