diff --git a/odc/geo/_interop.py b/odc/geo/_interop.py
index e7c61d49..f25ea5d6 100644
--- a/odc/geo/_interop.py
+++ b/odc/geo/_interop.py
@@ -43,6 +43,14 @@ def datacube(self) -> bool:
     def tifffile(self) -> bool:
         return self._check("tifffile")
 
+    @property
+    def azure(self) -> bool:
+        return self._check("azure.storage.blob")
+
+    @property
+    def botocore(self) -> bool:
+        return self._check("botocore")
+
     @staticmethod
     def _check(lib_name: str) -> bool:
         return importlib.util.find_spec(lib_name) is not None
diff --git a/odc/geo/cog/_az.py b/odc/geo/cog/_az.py
new file mode 100644
index 00000000..78c5e212
--- /dev/null
+++ b/odc/geo/cog/_az.py
@@ -0,0 +1,168 @@
+import base64
+from typing import Any
+
+from ._multipart import MultiPartUploadBase
+
+
+class AzureLimits:
+    """
+    Common Azure writer settings.
+    """
+
+    @property
+    def min_write_sz(self) -> int:
+        # Azure minimum write size for blocks (default is 4 MiB)
+        return 4 * (1 << 20)
+
+    @property
+    def max_write_sz(self) -> int:
+        # Azure maximum write size for blocks (default is 100 MiB)
+        return 100 * (1 << 20)
+
+    @property
+    def min_part(self) -> int:
+        return 1
+
+    @property
+    def max_part(self) -> int:
+        # Azure supports up to 50,000 blocks per blob
+        return 50_000
+
+
+class AzMultiPartUpload(AzureLimits, MultiPartUploadBase):
+    """
+    Azure Blob Storage multipart upload.
+    """
+
+    # pylint: disable=too-many-instance-attributes
+
+    def __init__(
+        self, account_url: str, container: str, blob: str, credential: Any = None
+    ):
+        """
+        Initialise Azure multipart upload.
+
+        :param account_url: URL of the Azure storage account.
+        :param container: Name of the container.
+        :param blob: Name of the blob.
+        :param credential: Authentication credentials (e.g., SAS token or key).
+        """
+        self.account_url = account_url
+        self.container = container
+        self.blob = blob
+        self.credential = credential
+
+        # Initialise Azure Blob service client
+        # pylint: disable=import-outside-toplevel,import-error
+        from azure.storage.blob import BlobServiceClient
+
+        self.blob_service_client = BlobServiceClient(
+            account_url=account_url, credential=credential
+        )
+        self.container_client = self.blob_service_client.get_container_client(container)
+        self.blob_client = self.container_client.get_blob_client(blob)
+
+        self.block_ids: list[str] = []
+
+    def initiate(self, **kwargs) -> str:
+        """
+        Initialise the upload. No-op for Azure.
+        """
+        return "azure-block-upload"
+
+    def write_part(self, part: int, data: bytes) -> dict[str, Any]:
+        """
+        Stage a block in Azure.
+
+        :param part: Part number (unique).
+        :param data: Data for this part.
+        :return: A dictionary containing part information.
+        """
+        block_id = base64.b64encode(f"block-{part}".encode()).decode()
+        self.blob_client.stage_block(block_id=block_id, data=data)
+        self.block_ids.append(block_id)
+        return {"PartNumber": part, "BlockId": block_id}
+
+    def finalise(self, parts: list[dict[str, Any]]) -> str:
+        """
+        Commit the block list to finalise the upload.
+
+        :param parts: List of uploaded parts metadata.
+        :return: The ETag of the finalised blob.
+        """
+        # pylint: disable=import-outside-toplevel,import-error
+        from azure.storage.blob import BlobBlock
+
+        block_list = [BlobBlock(block_id=part["BlockId"]) for part in parts]
+        self.blob_client.commit_block_list(block_list)
+        return self.blob_client.get_blob_properties().etag
+
+    def cancel(self, other: str = ""):
+        """
+        Cancel the upload by clearing the block list.
+ """ + assert other == "" + self.block_ids.clear() + + @property + def url(self) -> str: + """ + Get the Azure blob URL. + + :return: The full URL of the blob. + """ + return self.blob_client.url + + @property + def started(self) -> bool: + """ + Check if any blocks have been staged. + + :return: True if blocks have been staged, False otherwise. + """ + return bool(self.block_ids) + + def writer(self, kw: dict[str, Any], *, client: Any = None): + """ + Return a stateless writer compatible with Dask. + """ + return DelayedAzureWriter(self, kw) + + def dask_name_prefix(self) -> str: + """Return the Dask name prefix for Azure.""" + return "azure-finalise" + + +class DelayedAzureWriter(AzureLimits): + """ + Dask-compatible writer for Azure Blob Storage multipart uploads. + """ + + def __init__(self, mpu: AzMultiPartUpload, kw: dict[str, Any]): + """ + Initialise the Azure writer. + + :param mpu: AzMultiPartUpload instance. + :param kw: Additional parameters for the writer. + """ + self.mpu = mpu + self.kw = kw # Additional metadata like ContentType + + def __call__(self, part: int, data: bytes) -> dict[str, Any]: + """ + Write a single part to Azure Blob Storage. + + :param part: Part number. + :param data: Chunk data. + :return: Metadata for the written part. + """ + return self.mpu.write_part(part, data) + + def finalise(self, parts: list[dict[str, Any]]) -> str: + """ + Finalise the upload by committing the block list. + + :param parts: List of uploaded parts metadata. + :return: ETag of the finalised blob. + """ + return self.mpu.finalise(parts) diff --git a/odc/geo/cog/_mpu.py b/odc/geo/cog/_mpu.py index f1776d8d..ddc453a5 100644 --- a/odc/geo/cog/_mpu.py +++ b/odc/geo/cog/_mpu.py @@ -495,3 +495,49 @@ def _finalizer_dask_op( _, rr = _root.flush(write, leftPartId=1, finalise=True) return rr + + +def get_mpu_kwargs( + mk_header=None, + mk_footer=None, + user_kw=None, + writes_per_chunk=1, + spill_sz=20 * (1 << 20), + client=None, +) -> dict: + """ + Construct shared keyword arguments for multipart uploads. + """ + return { + "mk_header": mk_header, + "mk_footer": mk_footer, + "user_kw": user_kw, + "writes_per_chunk": writes_per_chunk, + "spill_sz": spill_sz, + "client": client, + } + + +def mpu_upload( + chunks: Union[dask.bag.Bag, list[dask.bag.Bag]], + *, + writer: Any, + dask_name_prefix: str, + **kw, +) -> "Delayed": + """Shared logic for multipart uploads to storage services.""" + client = kw.pop("client", None) + writer_kw = dict(kw) + if client is not None: + writer_kw["client"] = client + spill_sz = kw.get("spill_sz", 20 * (1 << 20)) + if spill_sz: + write = writer(writer_kw) + else: + write = None + return mpu_write( + chunks, + write, + dask_name_prefix=dask_name_prefix, + **kw, # everything else remains + ) diff --git a/odc/geo/cog/_multipart.py b/odc/geo/cog/_multipart.py new file mode 100644 index 00000000..c7376bfc --- /dev/null +++ b/odc/geo/cog/_multipart.py @@ -0,0 +1,87 @@ +""" +Multipart upload interface. + +Defines the `MultiPartUploadBase` class for implementing multipart upload functionality. +This interface standardises methods for initiating, uploading, and finalising +multipart uploads across storage backends. 
+""" + +from abc import ABC, abstractmethod +from typing import Any, Union, TYPE_CHECKING + +from dask.delayed import Delayed +from ._mpu import get_mpu_kwargs, mpu_upload + +if TYPE_CHECKING: + # pylint: disable=import-outside-toplevel,import-error + import dask.bag + + +class MultiPartUploadBase(ABC): + """Abstract base class for multipart upload.""" + + @abstractmethod + def initiate(self, **kwargs) -> str: + """Initiate a multipart upload and return an identifier.""" + + @abstractmethod + def write_part(self, part: int, data: bytes) -> dict[str, Any]: + """Upload a single part.""" + + @abstractmethod + def finalise(self, parts: list[dict[str, Any]]) -> str: + """Finalise the upload with a list of parts.""" + + @abstractmethod + def cancel(self, other: str = ""): + """Cancel the multipart upload.""" + + @property + @abstractmethod + def url(self) -> str: + """Return the URL of the upload target.""" + + @property + @abstractmethod + def started(self) -> bool: + """Check if the multipart upload has been initiated.""" + + @abstractmethod + def writer(self, kw: dict[str, Any], *, client: Any = None) -> Any: + """ + Return a Dask-compatible writer for multipart uploads. + + :param kw: Additional parameters for the writer. + :param client: Dask client for distributed execution. + """ + + @abstractmethod + def dask_name_prefix(self) -> str: + """Return the dask name prefix specific to the backend.""" + + def upload( + self, + chunks: Union["dask.bag.Bag", list["dask.bag.Bag"]], + *, + mk_header: Any = None, + mk_footer: Any = None, + user_kw: dict[str, Any] | None = None, + writes_per_chunk: int = 1, + spill_sz: int = 20 * (1 << 20), + client: Any = None, + ) -> Delayed: + """High-level upload that calls mpu_upload under the hood.""" + kwargs = get_mpu_kwargs( + mk_header=mk_header, + mk_footer=mk_footer, + user_kw=user_kw, + writes_per_chunk=writes_per_chunk, + spill_sz=spill_sz, + client=client, + ) + return mpu_upload( + chunks, + writer=self.writer, + dask_name_prefix=self.dask_name_prefix(), + **kwargs, + ) diff --git a/odc/geo/cog/_s3.py b/odc/geo/cog/_s3.py index 6da5cf85..0ae6981f 100644 --- a/odc/geo/cog/_s3.py +++ b/odc/geo/cog/_s3.py @@ -9,14 +9,16 @@ from cachetools import cached -from ._mpu import PartsWriter, SomeData, mpu_write +from ._mpu import PartsWriter, SomeData +from ._multipart import MultiPartUploadBase if TYPE_CHECKING: import dask.bag - import distributed from botocore.credentials import ReadOnlyCredentials from dask.delayed import Delayed + import distributed + _state: dict[str, Any] = {} @@ -68,7 +70,7 @@ def max_part(self) -> int: return 10_000 -class MultiPartUpload(S3Limits): +class S3MultiPartUpload(S3Limits, MultiPartUploadBase): """ Dask to S3 dumper. 
""" @@ -195,31 +197,9 @@ def writer(self, kw, *, client: Any = None) -> PartsWriter: writer.prep_client(client) return writer - # pylint: disable=too-many-arguments - def upload( - self, - chunks: "dask.bag.Bag" | list["dask.bag.Bag"], - *, - mk_header: Any = None, - mk_footer: Any = None, - user_kw: dict[str, Any] | None = None, - writes_per_chunk: int = 1, - spill_sz: int = 20 * (1 << 20), - client: Any = None, - **kw, - ) -> "Delayed": - """Upload chunks to S3 with multipart uploads.""" - write = self.writer(kw, client=client) if spill_sz else None - return mpu_write( - chunks, - write, - mk_header=mk_header, - mk_footer=mk_footer, - user_kw=user_kw, - writes_per_chunk=writes_per_chunk, - spill_sz=spill_sz, - dask_name_prefix="s3finalise", - ) + def dask_name_prefix(self) -> str: + """Return the Dask name prefix for S3.""" + return "s3finalise" def _safe_get(v, timeout=0.1): @@ -236,7 +216,7 @@ class DelayedS3Writer(S3Limits): # pylint: disable=import-outside-toplevel,import-error - def __init__(self, mpu: MultiPartUpload, kw: dict[str, Any]): + def __init__(self, mpu: S3MultiPartUpload, kw: dict[str, Any]): self.mpu = mpu self.kw = kw # mostly ContentType= kinda thing self._shared_var: Optional["distributed.Variable"] = None @@ -262,7 +242,7 @@ def _shared(self, client: "distributed.Client") -> "distributed.Variable": self._shared_var = Variable(self._build_name("MPUpload"), client) return self._shared_var - def _ensure_init(self, final_write: bool = False) -> MultiPartUpload: + def _ensure_init(self, final_write: bool = False) -> S3MultiPartUpload: # pylint: disable=too-many-return-statements mpu = self.mpu if mpu.started: diff --git a/odc/geo/cog/_tifffile.py b/odc/geo/cog/_tifffile.py index 34066293..3ab86e21 100644 --- a/odc/geo/cog/_tifffile.py +++ b/odc/geo/cog/_tifffile.py @@ -11,6 +11,7 @@ from functools import partial from io import BytesIO from typing import TYPE_CHECKING, Any, Callable, Optional, Union +from urllib.parse import urlparse from xml.sax.saxutils import escape as xml_escape import numpy as np @@ -23,7 +24,8 @@ from ..types import Shape2d, SomeNodata, Unset, shape_ from ._mpu import mpu_write from ._mpu_fs import MPUFileSink -from ._s3 import MultiPartUpload, s3_parse_url +from ._multipart import MultiPartUploadBase + from ._shared import ( GDAL_COMP, GEOTIFF_TAGS, @@ -626,15 +628,16 @@ def save_cog_with_dask( bigtiff: bool = True, overview_resampling: Union[int, str] = "nearest", aws: Optional[dict[str, Any]] = None, + azure: Optional[dict[str, Any]] = None, client: Any = None, stats: bool | int = True, **kw, ) -> Any: """ - Save a Cloud Optimized GeoTIFF to S3 or file with Dask. + Save a Cloud Optimized GeoTIFF to S3, Azure Blob Storage, or file with Dask. 
 
     :param xx: Pixels as :py:class:`xarray.DataArray` backed by Dask
-    :param dst: S3 url or a file path on shared storage
+    :param dst: S3 or Azure URL, or a file path on shared storage
     :param compression: Compression to use, default is ``DEFLATE``
     :param level: Compression "level", depends on chosen compression
     :param predictor: TIFF predictor setting
@@ -643,6 +646,7 @@ def save_cog_with_dask(
     :param blocksize: Configure blocksizes for main and overview images
     :param bigtiff: Generate BigTIFF by default, set to ``False`` to disable
     :param aws: Configure AWS write access
+    :param azure: Azure write access (``account_url`` and ``credential``)
     :param client: Dask client
     :param stats: Set to ``False`` to disable stats computation
 
@@ -653,13 +657,12 @@ def save_cog_with_dask(
 
     from ..xr import ODCExtensionDa
 
-    if aws is None:
-        aws = {}
+    aws = aws or {}
+    azure = azure or {}
 
-    upload_params = {k: kw.pop(k) for k in ["writes_per_chunk", "spill_sz"] if k in kw}
-    upload_params.update(
-        {k: aws.pop(k) for k in ["writes_per_chunk", "spill_sz"] if k in aws}
-    )
+    upload_params = {
+        k: kw.pop(k, None) for k in ["writes_per_chunk", "spill_sz"] if k in kw
+    }
     parts_base = kw.pop("parts_base", None)
 
     # Normalise compression settings and remove GDAL compat options from kw
@@ -699,7 +702,7 @@ def save_cog_with_dask(
 
     if band_names and len(band_names) != meta.nsamples:
         raise ValueError(
-            f"Found {len(band_names)} band names ({band_names}) but there are {meta.nsamples} bands."
+            f"Found {len(band_names)} band names ({band_names}), expected {meta.nsamples} bands."
         )
 
     layers = _pyramids_from_cog_metadata(xx, meta, resampling=overview_resampling)
@@ -731,19 +734,37 @@ def save_cog_with_dask(
         "_stats": _stats,
     }
 
-    tiles_write_order = _tiles[::-1]
-    if len(tiles_write_order) > 4:
-        tiles_write_order = [
-            dask.bag.concat(tiles_write_order[:4]),
-            *tiles_write_order[4:],
-        ]
+    # Determine output type and initiate uploader
+    parsed_url = urlparse(dst)
+    if parsed_url.scheme == "s3":
+        if have.botocore:
+            from ._s3 import S3MultiPartUpload, s3_parse_url
 
-    bucket, key = s3_parse_url(dst)
-    if not bucket:
-        # assume disk output
+            bucket, key = s3_parse_url(dst)
+            uploader: MultiPartUploadBase = S3MultiPartUpload(bucket, key, **aws)
+        else:
+            raise RuntimeError("Please install `boto3` to use S3")
+    elif parsed_url.scheme == "az":
+        if have.azure:
+            from ._az import AzMultiPartUpload
+
+            assert azure is not None
+            assert "account_url" in azure
+            assert "credential" in azure
+
+            uploader = AzMultiPartUpload(
+                account_url=azure["account_url"],
+                container=parsed_url.netloc,
+                blob=parsed_url.path.lstrip("/"),
+                credential=azure["credential"],
+            )
+        else:
+            raise RuntimeError("Please install `azure-storage-blob` to use Azure")
+    else:
+        # Assume local disk
         write = MPUFileSink(dst, parts_base=parts_base)
         return mpu_write(
-            tiles_write_order,
+            _tiles[::-1],
             write,
             mk_header=_patch_hdr,
             user_kw={
@@ -755,15 +776,15 @@ def save_cog_with_dask(
             **upload_params,
         )
 
-    upload_params["ContentType"] = (
-        "image/tiff;application=geotiff;profile=cloud-optimized"
-    )
+    # Upload tiles
+    tiles_write_order = _tiles[::-1]  # Reverse tiles for writing
+    if len(tiles_write_order) > 4:  # Optimize for larger datasets
+        tiles_write_order = [
+            dask.bag.concat(tiles_write_order[:4]),
+            *tiles_write_order[4:],
+        ]
 
-    cleanup = aws.pop("cleanup", False)
-    s3_sink = MultiPartUpload(bucket, key, **aws)
-    if cleanup:
-        s3_sink.cancel("all")
-    return s3_sink.upload(
+    return uploader.upload(
         tiles_write_order,
         mk_header=_patch_hdr,
         user_kw={
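With the scheme dispatch above, the same `save_cog_with_dask` call now routes to S3, Azure Blob Storage, or the local filesystem. A usage sketch follows; the bucket/container names, account URL and token are placeholders, not values taken from this change.

from odc.geo.cog import save_cog_with_dask
from odc.geo.geobox import GeoBox
from odc.geo.xr import xr_zeros

gbox = GeoBox.from_bbox(
    (-2_000_000, -5_000_000, 2_250_000, -1_000_000), "epsg:3577", resolution=1000
)
xx = xr_zeros(gbox, chunks=(2048, 2048), dtype="uint16")

# s3:// destinations keep going through S3MultiPartUpload (needs boto3)
# save_cog_with_dask(xx, "s3://my-bucket/cogs/demo.tif", aws={...})

# az:// destinations need azure-storage-blob and the `azure` dict
rr = save_cog_with_dask(
    xx,
    "az://my-container/cogs/demo.tif",
    azure={
        "account_url": "https://myaccount.blob.core.windows.net",
        "credential": "<sas-token>",  # SAS token or account key
    },
)
rr.compute()  # plain file paths still go through MPUFileSink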
diff --git a/odc/geo/gcp.py b/odc/geo/gcp.py
index 56fade35..f72f7dd4 100644
--- a/odc/geo/gcp.py
+++ b/odc/geo/gcp.py
@@ -102,9 +102,16 @@ def resolution(self) -> Resolution:
 
     def points(self) -> Tuple[Geometry, Geometry]:
         """Return multipoint geometries for (Pixel, World)."""
+        pix_points: list[tuple[float, float]] = [
+            (float(p[0]), float(p[1])) for p in self._pix.tolist()
+        ]
+        wld_points: list[tuple[float, float]] = [
+            (float(p[0]), float(p[1])) for p in self._wld.tolist()
+        ]
+
         return (
-            multipoint(self._pix.tolist(), None),
-            multipoint(self._wld.tolist(), self.crs),
+            multipoint(pix_points, None),
+            multipoint(wld_points, self.crs),
         )
 
     def __dask_tokenize__(self):
diff --git a/odc/geo/geom.py b/odc/geo/geom.py
index 043ac6c7..f1e89850 100644
--- a/odc/geo/geom.py
+++ b/odc/geo/geom.py
@@ -350,13 +350,16 @@ def qr2sample(
         ny = y1 - y0
         pts = quasi_random_r2(n, offset=offset)
         s = numpy.asarray([nx, ny], dtype="float32")
-        edge_pts = []
+        edge_pts: list[tuple[float, float]] = []
 
         if with_edges:
             sample_density = numpy.sqrt(n / (nx * ny))
             n_side = int(numpy.round(sample_density * min(nx, ny))) + 1
             n_side = max(2, n_side)
-            edge_pts = self.boundary(n_side).coords[:-1]
+            edge_pts = [
+                (float(ep[0]), float(ep[1]))
+                for ep in list(self.boundary(n_side).coords[:-1])
+            ]
             if padding is None:
                 padding = 0.3 * min(nx, ny) / (n_side - 1)
@@ -368,7 +371,11 @@ def qr2sample(
         pts[:, 0] += x0
         pts[:, 1] += y0
 
-        return multipoint(pts.tolist() + edge_pts, self.crs)
+        coords: list[tuple[float, float]] = [
+            (float(p[0]), float(p[1])) for p in pts.tolist()
+        ] + edge_pts
+
+        return multipoint(coords, self.crs)
 
 
 def wrap_shapely(method):
diff --git a/odc/geo/roi.py b/odc/geo/roi.py
index 7e4ee423..50759ed3 100644
--- a/odc/geo/roi.py
+++ b/odc/geo/roi.py
@@ -284,7 +284,7 @@ def base(self) -> Shape2d:
     @property
     def chunks(self) -> Tuple[Tuple[int, ...], Tuple[int, ...]]:
         """Dask compatible chunk rerpesentation."""
-        y, x = (tuple(np.diff(idx).tolist()) for idx in self._offsets)
+        y, x = (tuple(map(int, np.diff(idx))) for idx in self._offsets)
         return (y, x)
 
     def locate(self, pix: SomeIndex2d) -> Tuple[int, int]:
diff --git a/setup.cfg b/setup.cfg
index 4f09ae56..a19ab783 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -54,10 +54,14 @@ tiff =
 s3 =
     boto3
 
+az =
+    azure-storage-blob
+
 all =
     %(warp)s
     %(tiff)s
     %(s3)s
+    %(az)s
 
 test =
     pytest
diff --git a/tests/test_az.py b/tests/test_az.py
new file mode 100644
index 00000000..9462f091
--- /dev/null
+++ b/tests/test_az.py
@@ -0,0 +1,80 @@
+"""Tests for the Azure AzMultiPartUpload class."""
+
+import base64
+from unittest.mock import MagicMock, patch
+
+import pytest
+
+pytest.importorskip("azure.storage.blob")
+from odc.geo.cog._az import AzMultiPartUpload  # noqa: E402
+
+
+@pytest.fixture
+def azure_mpu():
+    """Fixture for initializing AzMultiPartUpload."""
+    account_url = "https://account_name.blob.core.windows.net"
+    return AzMultiPartUpload(account_url, "container", "some.blob", None)
+
+
+def test_mpu_init(azure_mpu):
+    """Basic test for AzMultiPartUpload initialization."""
+    assert azure_mpu.account_url == "https://account_name.blob.core.windows.net"
+    assert azure_mpu.container == "container"
+    assert azure_mpu.blob == "some.blob"
+    assert azure_mpu.credential is None
+
+
+@patch("azure.storage.blob.BlobServiceClient")
+def test_azure_multipart_upload(mock_blob_service_client):
+    """Test the full Azure AzMultiPartUpload functionality."""
+    # Mock Azure Blob SDK client structure
+    mock_blob_client = MagicMock()
+    mock_container_client = MagicMock()
+    mock_blob_service_client.return_value.get_container_client.return_value = (
+        mock_container_client
+    )
+    mock_container_client.get_blob_client.return_value = mock_blob_client
+
+    # Simulate return values for Azure Blob SDK methods
+    mock_blob_client.get_blob_properties.return_value.etag = "mock-etag"
+
+    # Test parameters
+    account_url = "https://mockaccount.blob.core.windows.net"
+    container = "mock-container"
+    blob = "mock-blob"
+    credential = "mock-sas-token"
+
+    # Create an instance of AzMultiPartUpload and call its methods
+    azure_upload = AzMultiPartUpload(account_url, container, blob, credential)
+    upload_id = azure_upload.initiate()
+    part1 = azure_upload.write_part(1, b"first chunk of data")
+    part2 = azure_upload.write_part(2, b"second chunk of data")
+    etag = azure_upload.finalise([part1, part2])
+
+    # Define block IDs
+    block_id1 = base64.b64encode(b"block-1").decode("utf-8")
+    block_id2 = base64.b64encode(b"block-2").decode("utf-8")
+
+    # Verify the results
+    assert upload_id == "azure-block-upload"
+    assert etag == "mock-etag"
+
+    # Verify BlobServiceClient instantiation
+    mock_blob_service_client.assert_called_once_with(
+        account_url=account_url, credential=credential
+    )
+
+    # Verify stage_block calls
+    mock_blob_client.stage_block.assert_any_call(
+        block_id=block_id1, data=b"first chunk of data"
+    )
+    mock_blob_client.stage_block.assert_any_call(
+        block_id=block_id2, data=b"second chunk of data"
+    )
+
+    # Verify commit_block_list was called correctly
+    block_list = mock_blob_client.commit_block_list.call_args[0][0]
+    assert len(block_list) == 2
+    assert block_list[0].id == block_id1
+    assert block_list[1].id == block_id2
+    mock_blob_client.commit_block_list.assert_called_once()
diff --git a/tests/test_s3.py b/tests/test_s3.py
index 8349bd81..d04a462a 100644
--- a/tests/test_s3.py
+++ b/tests/test_s3.py
@@ -1,9 +1,25 @@
-from odc.geo.cog._s3 import MultiPartUpload
+"""Tests for odc.geo.cog._s3."""
 
-# TODO: moto
+import unittest
+
+# Conditional import for S3 support
+try:
+    from odc.geo.cog._s3 import S3MultiPartUpload
 
+    HAVE_S3 = True
+except ImportError:
+    S3MultiPartUpload = None
+    HAVE_S3 = False
+
+
+def require_s3(test_func):
+    """Decorator to skip tests if s3 dependencies are not installed."""
+    return unittest.skipUnless(HAVE_S3, "s3 dependencies are not installed")(test_func)
 
+
+@require_s3
 def test_s3_mpu():
-    mpu = MultiPartUpload("bucket", "file.dat")
+    """Test S3MultiPartUpload class initialization."""
+    mpu = S3MultiPartUpload("bucket", "file.dat")
     assert mpu.bucket == "bucket"
     assert mpu.key == "file.dat"